Browse Source

!4280 adapt input to attr between cpu and aicpu embeddinglookup

Merge pull request !4280 from wuxuejian/embedding_input_adapt
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
196fdbe16e
2 changed files with 7 additions and 1 deletions
  1. +6
    -0
      mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc
  2. +1
    -1
      tests/st/ops/cpu/test_embedding_look_up_op.py

+ 6
- 0
mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc View File

@@ -50,6 +50,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) { if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
continue; continue;
} }
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
continue;
}
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
} }
return node; return node;


+ 1
- 1
tests/st/ops/cpu/test_embedding_look_up_op.py View File

@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, offset): def __init__(self, offset):
super(Net, self).__init__() super(Net, self).__init__()
self.embedding = P.EmbeddingLookup()
self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU")
self.offset = offset self.offset = offset


def construct(self, param, index): def construct(self, param, index):


Loading…
Cancel
Save