Source: https://blog.csdn.net/taoqick/article/details/147144149

GQA
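In grouped-query attention (GQA), the nheads query heads are divided into ngroups groups, and all heads_per_group = nheads // ngroups heads in a group share one key/value head. The K and V projections therefore only output ngroups * head_dim = dim // heads_per_group channels instead of dim. With the settings used in the test below (dim=256, nheads=8, ngroups=4), each pair of query heads shares one 32-dimensional K/V head, so k_proj and v_proj map 256 -> 128 channels. The two classes below implement the same computation, once with einops.rearrange and once with plain reshape/transpose, and the __main__ block checks that their outputs match.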

import torch
import torch.nn as nn
import math
from einops import rearrange

class MyGQA(nn.Module):
    """Grouped-query attention written with einops.rearrange."""
    def __init__(self, nheads, dim, ngroups):
        super().__init__()
        self.head_dim = dim // nheads
        self.nheads = nheads
        self.dim = dim
        self.ngroups = ngroups
        self.heads_per_group = nheads // ngroups
        # dim = self.head_dim * nheads
        # dim = self.head_dim * self.heads_per_group * ngroups
        self.q_proj = nn.Linear(dim, dim)
        # K/V are projected to ngroups * head_dim = dim // heads_per_group channels
        self.k_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.v_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.o_proj = nn.Linear(dim, dim)
        self.ln = nn.LayerNorm(dim)  # defined but not used in forward

    def forward(self, query, key, value, attn_mask=None):
        bs, q_len, dim = query.shape
        # Equivalent reshape/transpose versions (see MyGQA2 below):
        # q = self.q_proj(query).reshape(bs, q_len, self.nheads, self.head_dim).transpose(1, 2).reshape(bs, self.nheads, q_len, self.head_dim)
        # k = self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim)
        # v = self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim)
        q = rearrange(self.q_proj(query), 'b l (head k) -> b head l k', head=self.nheads)
        # repeat the grouped K/V along the batch dim, then fold the repeats back into the head dim
        k = rearrange(self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0),
                      '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k',
                      heads_per_group=self.heads_per_group, ngroups=self.ngroups)
        v = rearrange(self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0),
                      '(b heads_per_group) l (ngroups k) -> b (heads_per_group ngroups) l k',
                      heads_per_group=self.heads_per_group, ngroups=self.ngroups)
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v)  # bs, nheads, q_len, head_dim
        output = self.o_proj(rearrange(output, 'b head l k -> b l (head k)'))
        # output = self.o_proj(output.transpose(1, 2).reshape(bs, q_len, self.nheads * self.head_dim))
        return output, attn

class MyGQA2(nn.Module):
    """The same grouped-query attention written with reshape/transpose only."""
    def __init__(self, nheads, dim, ngroups):
        super().__init__()
        self.head_dim = dim // nheads
        self.nheads = nheads
        self.dim = dim
        self.ngroups = ngroups
        self.heads_per_group = nheads // ngroups
        # dim = self.head_dim * nheads
        # dim = self.head_dim * self.heads_per_group * ngroups
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.v_proj = nn.Linear(dim, dim // self.heads_per_group)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, query, key, value, attn_mask=None):
        bs, q_len, dim = query.shape
        q = self.q_proj(query).reshape(bs, q_len, self.nheads, self.head_dim).transpose(1, 2).reshape(bs, self.nheads, q_len, self.head_dim)
        k = self.k_proj(key).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim)
        v = self.v_proj(value).repeat_interleave(self.heads_per_group, dim=0).reshape(bs, self.heads_per_group, q_len, self.ngroups, self.head_dim).transpose(2, 3).reshape(bs, self.nheads, q_len, self.head_dim)
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        output = torch.matmul(attn, v)  # bs, nheads, q_len, head_dim
        output = self.o_proj(output.transpose(1, 2).reshape(bs, q_len, self.nheads * self.head_dim))
        return output, attn

if __name__ == '__main__':
    embed_dim, num_heads, num_groups = 256, 8, 4
    q_len, bs = 2, 3
    query = torch.randn(bs, q_len, embed_dim)
    key = torch.randn(bs, q_len, embed_dim)
    value = torch.randn(bs, q_len, embed_dim)
    my_multihead_attn = MyGQA(num_heads, embed_dim, num_groups)
    for param in my_multihead_attn.parameters():
        param.data.fill_(0.1)
    my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
    print('my_attn_output={}'.format(my_attn_output))
    my_multihead_attn2 = MyGQA2(num_heads, embed_dim, num_groups)
    for param in my_multihead_attn2.parameters():
        param.data.fill_(0.1)
    my_attn_output2, my_attn_output_weights2 = my_multihead_attn2(query, key, value)
    print('my_attn_output2={}'.format(my_attn_output2))
    max_diff = torch.max(torch.abs(my_attn_output - my_attn_output2)).item()
    print(torch.equal(my_attn_output_weights, my_attn_output_weights2))
    print('max_diff={}'.format(max_diff))
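A side note not in the original post: the same grouping can also be expressed without routing the repeat through the batch dimension, by keeping K/V at ngroups heads and repeating them along the head dimension. Below is a minimal sketch assuming q has shape (bs, nheads, q_len, head_dim) and k, v have shape (bs, ngroups, kv_len, head_dim); the gqa_attention name and the consecutive head-to-group assignment (repeat along dim=1) are my own choices, and the latter differs from the interleaved assignment the classes above end up with, so this illustrates the idea rather than being a numerically identical drop-in.

import torch
import math

def gqa_attention(q, k, v, ngroups, attn_mask=None):
    # q: (bs, nheads, q_len, head_dim); k, v: (bs, ngroups, kv_len, head_dim)
    bs, nheads, q_len, head_dim = q.shape
    heads_per_group = nheads // ngroups
    # share each K/V head across the heads_per_group query heads of its group
    k = k.repeat_interleave(heads_per_group, dim=1)  # (bs, nheads, kv_len, head_dim)
    v = v.repeat_interleave(heads_per_group, dim=1)
    attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn = attn.masked_fill(attn_mask == 0, float('-inf'))
    attn = attn.softmax(dim=-1)
    return torch.matmul(attn, v), attn

Recent PyTorch versions (2.5+, if I remember correctly) also accept grouped K/V directly in F.scaled_dot_product_attention via enable_gqa=True, in which case the explicit repeat is not needed at all.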

References

  • https://einops.rocks/pytorch-examples.html