Browse Source

!3068 [AutoParallel]Fix EmbeddingLookup bug

Merge pull request !3068 from lichen/fix_embeddinglookup
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
7e5e868d97
2 changed files with 9 additions and 3 deletions
  1. +6
    -0
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  2. +3
    -3
      mindspore/nn/layer/embedding.py

+ 6
- 0
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
ScopePtr scope = node->scope(); ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope); MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope); replace_node->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
if (prim->name() == EMBEDDING_LOOKUP) {
auto attrs = prim->attrs();
attrs[TARGET] = MakeValue(CPU);
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) { if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info()); (void)replace_node->set_operator_info(node->operator_info());
} }


+ 3
- 3
mindspore/nn/layer/embedding.py View File

@@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator


__all__ = ['Embedding']
__all__ = ['Embedding', 'EmbeddingLookup']


class Embedding(Cell): class Embedding(Cell):
r""" r"""
@@ -147,7 +147,7 @@ class EmbeddingLookup(Cell):


def construct(self, params, indices): def construct(self, params, indices):
if self.target == "CPU": if self.target == "CPU":
out = self.embeddinglookup(params, ids, 0)
out = self.embeddinglookup(params, indices, 0)
else: else:
out = self.gatherv2(param, ids, 0)
out = self.gatherv2(params, indices, 0)
return out return out

Loading…
Cancel
Save