Browse Source

add inputtoattr prim to white list

tags/v0.6.0-beta
huangdongrun 5 years ago
parent
commit
024b52d6e6
3 changed files with 26 additions and 12 deletions
  1. +1
    -0
      mindspore/ccsrc/operator/ops.cc
  2. +1
    -0
      mindspore/ccsrc/operator/ops.h
  3. +24
    -12
      mindspore/ccsrc/optimizer/irpass/branch_culling.cc

+ 1
- 0
mindspore/ccsrc/operator/ops.cc View File

@@ -231,6 +231,7 @@ const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut"); const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer"); const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel"); const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");


// Other miscellaneous // Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity"); const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");


+ 1
- 0
mindspore/ccsrc/operator/ops.h View File

@@ -242,6 +242,7 @@ extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut; extern const PrimitivePtr kPrimBpropCut;
extern const PrimitivePtr kPrimFakeQuantPerLayer; extern const PrimitivePtr kPrimFakeQuantPerLayer;
extern const PrimitivePtr kPrimFakeQuantPerChannel; extern const PrimitivePtr kPrimFakeQuantPerChannel;
extern const PrimitivePtr kPrimApplyRMSProp;


// Other Miscellaneous // Other Miscellaneous
extern const PrimitivePtr kPrimIdentity; extern const PrimitivePtr kPrimIdentity;


+ 24
- 12
mindspore/ccsrc/optimizer/irpass/branch_culling.cc View File

@@ -51,18 +51,30 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
// node because it is attribute or ge specific reason. // node because it is attribute or ge specific reason.
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be
// converted to switch guarded. // converted to switch guarded.
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list(
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}},
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}},
{prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}},
{prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}},
{prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}},
{prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}},
{prim::kPrimHistogramSummary, {1}}});
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}},
{prim::kPrimMomentum, {2, 3}},
{prim::kPrimStateSetItem, {1}},
{prim::kPrimTupleGetItem, {2}},
{prim::kPrimEnvGetItem, {1}},
{prim::kPrimEnvSetItem, {1}},
{prim::kPrimReduceSum, {2}},
{prim::kPrimReduceMean, {2}},
{prim::kPrimReduceAll, {2}},
{prim::kPrimCast, {2}},
{prim::kPrimTranspose, {2}},
{prim::kPrimOneHot, {2}},
{prim::kPrimGatherV2, {3}},
{prim::kPrimReshape, {2}},
{prim::kPrimAssign, {1}},
{prim::kPrimAssignAdd, {1}},
{prim::kPrimAssignSub, {1}},
{prim::kPrimTensorSummary, {1}},
{prim::kPrimImageSummary, {1}},
{prim::kPrimScalarSummary, {1}},
{prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}},
{prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) { for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
return IsPrimitiveCNode(node, item.first) && idx == index; return IsPrimitiveCNode(node, item.first) && idx == index;


Loading…
Cancel
Save