|
|
|
@@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer |
|
|
|
from ..cell import Cell |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
|
|
|
|
__all__ = ['Embedding'] |
|
|
|
__all__ = ['Embedding', 'EmbeddingLookup'] |
|
|
|
|
|
|
|
class Embedding(Cell): |
|
|
|
r""" |
|
|
|
@@ -147,7 +147,7 @@ class EmbeddingLookup(Cell): |
|
|
|
|
|
|
|
def construct(self, params, indices): |
|
|
|
if self.target == "CPU": |
|
|
|
out = self.embeddinglookup(params, ids, 0) |
|
|
|
out = self.embeddinglookup(params, indices, 0) |
|
|
|
else: |
|
|
|
out = self.gatherv2(param, ids, 0) |
|
|
|
out = self.gatherv2(params, indices, 0) |
|
|
|
return out |