diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index 2c24687c9e..153faea927 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -50,6 +50,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) { continue; } + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() || + AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) { + if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) { + continue; + } + } ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); } return node; diff --git a/tests/st/ops/cpu/test_embedding_look_up_op.py b/tests/st/ops/cpu/test_embedding_look_up_op.py index e7fb713fe5..911c0654d8 100644 --- a/tests/st/ops/cpu/test_embedding_look_up_op.py +++ b/tests/st/ops/cpu/test_embedding_look_up_op.py @@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU") class Net(nn.Cell): def __init__(self, offset): super(Net, self).__init__() - self.embedding = P.EmbeddingLookup() + self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU") self.offset = offset def construct(self, param, index):