From 024b52d6e6cad49e5297549a330c200d1790ccca Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Wed, 24 Jun 2020 10:51:52 +0800 Subject: [PATCH] add inputtoattr prim to white list --- mindspore/ccsrc/operator/ops.cc | 1 + mindspore/ccsrc/operator/ops.h | 1 + .../ccsrc/optimizer/irpass/branch_culling.cc | 36 ++++++++++++------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index e6545d311c..88001bf63f 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -231,6 +231,7 @@ const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); +const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); // Other miscellaneous const PrimitivePtr kPrimIdentity = std::make_shared("identity"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 01812a5529..ac35a8e2bd 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -242,6 +242,7 @@ extern const PrimitivePtr kPrimFakeBprop; extern const PrimitivePtr kPrimBpropCut; extern const PrimitivePtr kPrimFakeQuantPerLayer; extern const PrimitivePtr kPrimFakeQuantPerChannel; +extern const PrimitivePtr kPrimApplyRMSProp; // Other Miscellaneous extern const PrimitivePtr kPrimIdentity; diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc index 0253cd2b39..a7254c6e32 100644 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc +++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc @@ -51,18 +51,30 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { // node because it is attribute or ge specific reason. // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be // converted to switch guarded. - std::vector>> white_list( - {{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, - {prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}}, - {prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, - {prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}}, - {prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}}, - {prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, - {prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}}, - {prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}}, - {prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, - {prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}}, - {prim::kPrimHistogramSummary, {1}}}); + std::vector>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, + {prim::kPrimMomentum, {2, 3}}, + {prim::kPrimStateSetItem, {1}}, + {prim::kPrimTupleGetItem, {2}}, + {prim::kPrimEnvGetItem, {1}}, + {prim::kPrimEnvSetItem, {1}}, + {prim::kPrimReduceSum, {2}}, + {prim::kPrimReduceMean, {2}}, + {prim::kPrimReduceAll, {2}}, + {prim::kPrimCast, {2}}, + {prim::kPrimTranspose, {2}}, + {prim::kPrimOneHot, {2}}, + {prim::kPrimGatherV2, {3}}, + {prim::kPrimReshape, {2}}, + {prim::kPrimAssign, {1}}, + {prim::kPrimAssignAdd, {1}}, + {prim::kPrimAssignSub, {1}}, + {prim::kPrimTensorSummary, {1}}, + {prim::kPrimImageSummary, {1}}, + {prim::kPrimScalarSummary, {1}}, + {prim::kPrimApplyRMSProp, {6, 7, 8}}, + {prim::kPrimCumSum, {2}}, + {prim::kPrimTile, {2}}, + {prim::kPrimHistogramSummary, {1}}}); for (auto &item : white_list) { auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { return IsPrimitiveCNode(node, item.first) && idx == index;