Merge pull request !3068 from lichen/fix_embeddinglookuptags/v0.6.0-beta
| @@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||||
| ScopePtr scope = node->scope(); | ScopePtr scope = node->scope(); | ||||
| MS_EXCEPTION_IF_NULL(scope); | MS_EXCEPTION_IF_NULL(scope); | ||||
| replace_node->set_scope(scope); | replace_node->set_scope(scope); | ||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0)); | |||||
| if (prim->name() == EMBEDDING_LOOKUP) { | |||||
| auto attrs = prim->attrs(); | |||||
| attrs[TARGET] = MakeValue(CPU); | |||||
| (void)prim->SetAttrs(attrs); | |||||
| } | |||||
| if (index == replace_op.size() - 1) { | if (index == replace_op.size() - 1) { | ||||
| (void)replace_node->set_operator_info(node->operator_info()); | (void)replace_node->set_operator_info(node->operator_info()); | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer | |||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| __all__ = ['Embedding'] | |||||
| __all__ = ['Embedding', 'EmbeddingLookup'] | |||||
| class Embedding(Cell): | class Embedding(Cell): | ||||
| r""" | r""" | ||||
| @@ -147,7 +147,7 @@ class EmbeddingLookup(Cell): | |||||
| def construct(self, params, indices): | def construct(self, params, indices): | ||||
| if self.target == "CPU": | if self.target == "CPU": | ||||
| out = self.embeddinglookup(params, ids, 0) | |||||
| out = self.embeddinglookup(params, indices, 0) | |||||
| else: | else: | ||||
| out = self.gatherv2(param, ids, 0) | |||||
| out = self.gatherv2(params, indices, 0) | |||||
| return out | return out | ||||