|
|
|
@@ -38,10 +38,10 @@ class Embedding(Module): |
|
|
|
import numpy as np |
|
|
|
import megengine as mge |
|
|
|
import megengine.module as M |
|
|
|
weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) |
|
|
|
data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) |
|
|
|
weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32)) |
|
|
|
data = mge.tensor(np.array([(0,0)], dtype=np.int32)) |
|
|
|
|
|
|
|
embedding = M.Embedding(2, 5, initial_weight=weight) |
|
|
|
embedding = M.Embedding(1, 5, initial_weight=weight) |
|
|
|
output = embedding(data) |
|
|
|
with np.printoptions(precision=6): |
|
|
|
print(output.numpy()) |
|
|
|
@@ -51,16 +51,7 @@ class Embedding(Module): |
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
[[[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[0.1 1.1 2.1 3.1 4.1] |
|
|
|
[0.1 1.1 2.1 3.1 4.1]] |
|
|
|
|
|
|
|
[[0.1 1.1 2.1 3.1 4.1] |
|
|
|
[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[0.1 1.1 2.1 3.1 4.1]] |
|
|
|
|
|
|
|
[[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[0.1 1.1 2.1 3.1 4.1]]] |
|
|
|
[1.2 2.3 3.4 4.5 5.6]]] |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
@@ -134,8 +125,8 @@ class Embedding(Module): |
|
|
|
import numpy as np |
|
|
|
import megengine as mge |
|
|
|
import megengine.module as M |
|
|
|
weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6),(0.1,1.1,2.1,3.1,4.1)], dtype=np.float32)) |
|
|
|
data = mge.tensor(np.array([(0,1,1),(1,0,1),(0,0,1)], dtype=np.int32)) |
|
|
|
weight = mge.tensor(np.array([(1.2,2.3,3.4,4.5,5.6)], dtype=np.float32)) |
|
|
|
data = mge.tensor(np.array([(0,0)], dtype=np.int32)) |
|
|
|
|
|
|
|
embedding = M.Embedding.from_pretrained(weight, freeze=False) |
|
|
|
output = embedding(data) |
|
|
|
@@ -146,17 +137,7 @@ class Embedding(Module): |
|
|
|
.. testoutput:: |
|
|
|
|
|
|
|
[[[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[0.1 1.1 2.1 3.1 4.1] |
|
|
|
[0.1 1.1 2.1 3.1 4.1]] |
|
|
|
|
|
|
|
[[0.1 1.1 2.1 3.1 4.1] |
|
|
|
[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[0.1 1.1 2.1 3.1 4.1]] |
|
|
|
|
|
|
|
[[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[1.2 2.3 3.4 4.5 5.6] |
|
|
|
[0.1 1.1 2.1 3.1 4.1]]] |
|
|
|
|
|
|
|
[1.2 2.3 3.4 4.5 5.6]]] |
|
|
|
|
|
|
|
""" |
|
|
|
embeddings_shape = embeddings.shape |
|
|
|
|