当前位置: 首页> 教育> 高考 > 【HuggingFace Transformers】input_ids与inputs_embeds的区别

【HuggingFace Transformers】input_ids与inputs_embeds的区别

时间:2025/7/8 20:34:46来源:https://blog.csdn.net/weixin_47936614/article/details/140356422 浏览次数:0次

input_ids与inputs_embeds的区别

  • input_ids
  • inputs_embeds
  • 代码示例
  • 总结

在使用 BERT 模型时,input_idsinputs_embeds 都是用于表示输入数据的,但它们有不同的用途和数据格式。以下是它们的区别和详细解释:

input_ids

定义input_ids 是词的索引序列。

类型torch.Tensor,包含整数索引。

用途:直接将文本分词后得到的词索引作为模型的输入。

示例

from transformers import BertTokenizer, BertModel# 加载预训练的 BERT 分词器
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)# 示例文本
text = "我爱学习"
inputs = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
input_ids = inputs['input_ids']
print(input_ids)  # shape为(1, 序列长度)的张量,包含每个词在词表中的索引

输出:

tensor([[ 101, 2769, 4263, 2110,  739,  102]])

inputs_embeds

定义inputs_embeds 是词的嵌入表示。

类型torch.Tensor,包含浮点数的向量。

用途:当已经有了词的嵌入表示时,可以直接将这些嵌入作为模型的输入,而不是使用 input_ids 由模型内部的嵌入层进行转换。

示例

# -*- coding: utf-8 -*-
import torch
from transformers import BertTokenizer, BertModel# 加载预训练的 BERT 模型和分词器
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)# 示例文本
text = "我爱学习"
inputs = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
input_ids = inputs['input_ids']# 获取输入的嵌入表示
with torch.no_grad():inputs_embeds = model.embeddings(input_ids)print(inputs_embeds)
print(inputs_embeds.shape)  # shape为(1, 序列长度, 隐藏维度) 的张量,包含每个词的嵌入表示

输出:

tensor([[[ 0.0588,  0.0704, -0.2139,  ..., -0.0237, -0.2234, -0.1116],[ 0.4997, -0.2951,  1.0791,  ..., -0.9039,  0.2250, -0.4692],[ 0.4052, -0.3516, -0.7161,  ..., -0.6516,  0.4880, -0.0765],[-0.7654, -0.7289,  0.5831,  ..., -0.4602,  0.4385,  1.4989],[ 0.4207, -0.5057,  0.6421,  ..., -0.7269,  0.1094,  0.0440],[-0.1846,  0.3785,  0.4089,  ..., -0.7294,  0.0194,  0.5945]]])
torch.Size([1, 6, 768])

代码示例

以下是一个示例,展示如何分别使用 input_idsinputs_embeds 作为 BERT 模型的输入:

# -*- coding: utf-8 -*-
import torch
from transformers import BertTokenizer, BertModel# 加载预训练的 BERT 模型和分词器
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)# 示例文本
text = "我爱学习"
inputs = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
input_ids = inputs['input_ids']# 使用 input_ids 作为输入
with torch.no_grad():outputs_using_ids = model(input_ids=input_ids)
sequence_output_using_ids = outputs_using_ids.last_hidden_state# 使用 inputs_embeds 作为输入
with torch.no_grad():inputs_embeds = model.embeddings(input_ids)outputs_using_embeds = model(inputs_embeds=inputs_embeds)
sequence_output_using_embeds = outputs_using_embeds.last_hidden_state# 打印输出形状
print("使用 input_ids 的序列输出shape:", sequence_output_using_ids.shape)
print("使用 inputs_embeds 的序列输出shape:", sequence_output_using_embeds.shape)

输出:

使用 input_ids 的序列输出shape:torch.Size([1, 6, 768])
使用 inputs_embeds 的序列输出shape:torch.Size([1, 6, 768])

注意sequence_output_using_ids sequence_output_using_embedsshape是相同的,但是它们的tensor值并不相同,打印如下:

sequence_output_using_ids为:

tensor([[[ 0.0707,  0.1193, -0.1171,  ...,  0.8360,  0.1781, -0.3330],[ 0.3077,  0.0733, -0.0056,  ..., -0.9890, -0.0871, -0.1517],[ 1.0312, -0.3254, -1.0671,  ...,  0.1756,  0.2145, -0.1055],[ 0.1788, -0.1532, -0.9967,  ...,  0.4146,  0.1664, -0.4200],[ 0.8302, -0.5292, -0.7375,  ..., -0.1464,  0.2384,  0.2083],[ 0.3939,  0.1830, -0.2468,  ...,  0.8480,  0.1541,  0.1683]]])

sequence_output_using_embeds为:

tensor([[[ 0.0083,  0.1537, -0.1859,  ...,  0.8392,  0.2139, -0.3466],[ 0.2784,  0.1942, -0.0893,  ..., -1.0140, -0.0617, -0.2342],[ 1.0391, -0.3018, -1.0898,  ...,  0.0209,  0.2510, -0.1496],[ 0.1980, -0.2273, -1.1935,  ...,  0.2233,  0.2545, -0.4895],[ 0.7549, -0.4358, -0.4421,  ..., -0.0880,  0.3952,  0.3079],[ 0.5254,  0.2023, -0.1211,  ...,  0.8243,  0.0460,  0.2076]]])

总结

  • input_ids 是词的索引序列,用于标准的文本输入。
  • inputs_embeds 是词的嵌入表示,用于自定义嵌入或高级处理场景。
  • 两者都可以作为 BERT 模型的输入,但不能同时使用。
关键字:【HuggingFace Transformers】input_ids与inputs_embeds的区别

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: