XATTN 是 “Cross Attention” 的缩写,表示交叉注意力机制。这是一种在多模态模型中常用的机制,用于在不同模态(例如,视觉和文本)之间建立联系和融合信息。
交叉注意力机制(Cross Attention)
交叉注意力机制是 Transformer 中的一种变体,通常用于多模态任务,例如视觉问答、图像字幕生成等。它的主要作用是让一个模态(如文本)关注并融合另一个模态(如图像)的信息,从而实现更好的理解和生成。
基本概念
-
Query、Key、Value:
- Query(查询):来自一个模态的输入向量。
- Key(键)和 Value(值):来自另一个模态的输入向量。
-
计算注意力权重:
- 使用 Query 和 Key 计算注意力权重,表示 Query 对每个 Key 的相关性。
- 常用的注意力函数是点积注意力(Scaled Dot-Product Attention):
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
其中, Q Q Q 是 Query, K K K 是 Key, V V V 是 Value, d k d_k dk 是 Key 的维度。
-
加权求和:
- 使用计算出的注意力权重对 Value 进行加权求和,得到融合后的表示。
交叉注意力的应用
在多模态任务中,交叉注意力机制允许模型在处理文本时参考图像信息,或者在处理图像时参考文本信息。例如:
-
图像字幕生成:
- 图像特征作为 Key 和 Value,文本特征作为 Query,通过交叉注意力机制生成描述图像的文本。
-
视觉问答:
- 问题文本特征作为 Query,图像特征作为 Key 和 Value,通过交叉注意力机制生成答案。
代码实现
import torch
import torch.nn.functional as F# 假设文本特征 T 和图像特征 I
T = torch.randn(32, 10, 512) # (batch_size, text_seq_len, feature_dim)
I = torch.randn(32, 20, 512) # (batch_size, image_seq_len, feature_dim)# 计算 Query, Key, Value
Q = T # Query 来自文本特征,形状 (batch_size, text_seq_len, d_k)
K = I # Key 来自图像特征,形状 (batch_size, image_seq_len, d_k)
V = I # Value 来自图像特征,形状 (batch_size, image_seq_len, d_v)# 获取特征维度
d_k = Q.size(-1) # d_k 是 Query 和 Key 的特征维度# 计算注意力得分
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))# 计算注意力权重
attention_weights = F.softmax(scores, dim=-1)# 加权求和值
cross_attention_output = torch.matmul(attention_weights, V)# 输出形状
print(cross_attention_output.shape) # 输出形状为 (batch_size, text_seq_len, d_v)
文本特征作为 Query,图像特征作为 Key 和 Value,通过交叉注意力机制计算得到融合后的表示。
Query来自文本特征,Key和Value来自图像特征。让我们逐步分析为什么输出的形状是 (batch_size, text_seq_len, d_v)
。
代码分析
-
输入张量:
T
是文本特征,形状为(batch_size, text_seq_len, feature_dim)
。I
是图像特征,形状为(batch_size, image_seq_len, feature_dim)
。
-
Query, Key, Value 的选择:
Q = T
:Query来自文本特征,其形状为(batch_size, text_seq_len, d_k)
。K = I
:Key来自图像特征,其形状为(batch_size, image_seq_len, d_k)
。V = I
:Value来自图像特征,其形状为(batch_size, image_seq_len, d_v)
。
-
计算注意力得分:
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
scores
的形状为(batch_size, text_seq_len, image_seq_len)
,因为它是通过将(batch_size, text_seq_len, d_k)
的Q
与(batch_size, d_k, image_seq_len)
的K
转置相乘得到的。
-
计算注意力权重:
attention_weights = F.softmax(scores, dim=-1)
attention_weights
的形状为(batch_size, text_seq_len, image_seq_len)
,因为对image_seq_len
维度进行了 softmax 计算。
-
加权求和值(输出):
cross_attention_output = torch.matmul(attention_weights, V)
- 这里,
attention_weights
的形状是(batch_size, text_seq_len, image_seq_len)
,V
的形状是(batch_size, image_seq_len, d_v)
。 - 矩阵乘法后,
cross_attention_output
的形状是(batch_size, text_seq_len, d_v)
。 - 这意味着对于每个文本序列中的每个词,您都计算了来自图像序列中所有元素的加权和,因而输出的序列长度是
text_seq_len
。
- 这里,
总结
输出的形状是 (batch_size, text_seq_len, d_v)
是因为在跨模态注意力机制中,文本特征的每个词(Query)通过注意力机制与图像特征(Key和Value)进行交互,得到加权求和的结果,因此输出的序列长度保持为 text_seq_len
。