What do the embedding elements stand for in the HuggingFace BERT model?
Question
Before passing my tokens through the encoder in a BERT model, I would like to perform some processing on their embeddings. I extracted the embedding weights using:
from transformers import TFBertModel
# Load a pre-trained BERT model
model = TFBertModel.from_pretrained('bert-base-uncased')
# Get the embedding layer of the model
embedding_layer = model.get_layer('bert').get_input_embeddings()
# Extract the embedding weights
embedding_weights = embedding_layer.get_weights()
I found that it contains five elements, as shown in the figure below.
[Figure: screenshot of the five arrays contained in embedding_weights]
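For reference, printing the shapes of those five arrays gives something like the following (the shapes in the comments are what I would expect for bert-base-uncased, listed in the order word / token type / position / the two unknown arrays):

for i, w in enumerate(embedding_weights):
    print(i, w.shape)
# 0 (30522, 768)  vocab_size x hidden_size
# 1 (2, 768)      type_vocab_size x hidden_size
# 2 (512, 768)    max_position_embeddings x hidden_size
# 3 (768,)        ? (the two arrays this question is about)
# 4 (768,)        ?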
In my understanding, the first three elements are the word embedding weights, the token type embedding weights, and the positional embedding weights. My question is: what do the last two elements stand for?
I dug into the source code of the BERT model, but I cannot figure out the meaning of the last two elements.
Answer 1
Score: 0
In the BERT model, there is a post-processing step on the embedding tensor that applies layer normalization followed by dropout:
https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/modeling.py#L362
I think those two arrays are the gamma and beta of the layer normalization layer, https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization
They are learned parameters, and they span the input axes specified by the "axis" argument, which defaults to -1 (corresponding to the hidden dimension of 768 in the embedding tensor).
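As a rough sketch (not the actual HuggingFace implementation), you can reproduce that post-processing from the extracted weights with plain NumPy. The unpacking order and the example token ids below are assumptions (the order is the one described in the question; the ids are taken from the bert-base-uncased vocabulary), and the epsilon matches BERT's default of 1e-12:

import numpy as np
# Assumed order: word, token type, position, LayerNorm gamma, LayerNorm beta
word_emb, type_emb, pos_emb, gamma, beta = embedding_weights
# Hypothetical token ids for a short single-sentence input
input_ids = np.array([101, 7592, 2088, 102])
seq_len = len(input_ids)
# Sum the three embedding lookups (token type 0 for a single sentence)
x = word_emb[input_ids] + type_emb[0] + pos_emb[:seq_len]
# Layer normalization over the last axis (hidden size 768):
# this is where the 4th (gamma) and 5th (beta) arrays are used
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x_post = gamma * (x - mean) / np.sqrt(var + 1e-12) + beta
# Dropout follows during training; it is a no-op at inference time

If those assumptions hold, x_post should be close to the tensor the embedding layer feeds into the encoder at inference time.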