Embedding¶
- class Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=None, initial_weight=None, freeze=False, **kwargs)[源代码]¶
一个简单的查询表,存储具有固定大小的词向量(embedding)于固定的词典中。
该模块通常用于存储词向量(word embeddings),并使用索引来检索。输入索引列表到模块中,则输出对应的词向量。索引值应小于num_embeddings。
- 参数
实际案例
>>> import numpy as np >>> weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32)) >>> data = mge.tensor(np.array([(0,0)], dtype=np.int32)) >>> embedding = M.Embedding(1, 5, initial_weight=weight) >>> output = embedding(data) >>> with np.printoptions(precision=6): ... print(output.numpy()) [[[1.2 2.3 3.4 4.5 5.6] [1.2 2.3 3.4 4.5 5.6]]]
- classmethod from_pretrained(embeddings, freeze=True, padding_idx=None, max_norm=None, norm_type=None)[源代码]¶
从给定的2维FloatTensor创建词向量实例。
- 参数
实际案例
>>> import numpy as np >>> weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32)) >>> data = mge.tensor(np.array([(0,0)], dtype=np.int32)) >>> embedding = M.Embedding.from_pretrained(weight, freeze=False) >>> output = embedding(data) >>> output.numpy() array([[[1.2, 2.3, 3.4, 4.5, 5.6], [1.2, 2.3, 3.4, 4.5, 5.6]]], dtype=float32)