|
|
@@ -51,18 +51,30 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { |
|
|
// node because it is attribute or ge specific reason. |
|
|
// node because it is attribute or ge specific reason. |
|
|
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be |
|
|
// Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be |
|
|
// converted to switch guarded. |
|
|
// converted to switch guarded. |
|
|
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list( |
|
|
|
|
|
{{prim::kPrimApplyMomentum, {1, 2}}, {prim::kPrimMomentum, {2, 3}}, |
|
|
|
|
|
{prim::kPrimStateSetItem, {1}}, {prim::kPrimTupleGetItem, {2}}, |
|
|
|
|
|
{prim::kPrimEnvGetItem, {1}}, {prim::kPrimEnvSetItem, {1}}, |
|
|
|
|
|
{prim::kPrimReduceSum, {2}}, {prim::kPrimReduceMean, {2}}, |
|
|
|
|
|
{prim::kPrimReduceAll, {2}}, {prim::kPrimCast, {2}}, |
|
|
|
|
|
{prim::kPrimTranspose, {2}}, {prim::kPrimOneHot, {2}}, |
|
|
|
|
|
{prim::kPrimGatherV2, {3}}, {prim::kPrimReshape, {2}}, |
|
|
|
|
|
{prim::kPrimAssign, {1}}, {prim::kPrimAssignAdd, {1}}, |
|
|
|
|
|
{prim::kPrimAssignSub, {1}}, {prim::kPrimTensorSummary, {1}}, |
|
|
|
|
|
{prim::kPrimImageSummary, {1}}, {prim::kPrimScalarSummary, {1}}, |
|
|
|
|
|
{prim::kPrimHistogramSummary, {1}}}); |
|
|
|
|
|
|
|
|
std::vector<std::pair<PrimitivePtr, std::vector<size_t>>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, |
|
|
|
|
|
{prim::kPrimMomentum, {2, 3}}, |
|
|
|
|
|
{prim::kPrimStateSetItem, {1}}, |
|
|
|
|
|
{prim::kPrimTupleGetItem, {2}}, |
|
|
|
|
|
{prim::kPrimEnvGetItem, {1}}, |
|
|
|
|
|
{prim::kPrimEnvSetItem, {1}}, |
|
|
|
|
|
{prim::kPrimReduceSum, {2}}, |
|
|
|
|
|
{prim::kPrimReduceMean, {2}}, |
|
|
|
|
|
{prim::kPrimReduceAll, {2}}, |
|
|
|
|
|
{prim::kPrimCast, {2}}, |
|
|
|
|
|
{prim::kPrimTranspose, {2}}, |
|
|
|
|
|
{prim::kPrimOneHot, {2}}, |
|
|
|
|
|
{prim::kPrimGatherV2, {3}}, |
|
|
|
|
|
{prim::kPrimReshape, {2}}, |
|
|
|
|
|
{prim::kPrimAssign, {1}}, |
|
|
|
|
|
{prim::kPrimAssignAdd, {1}}, |
|
|
|
|
|
{prim::kPrimAssignSub, {1}}, |
|
|
|
|
|
{prim::kPrimTensorSummary, {1}}, |
|
|
|
|
|
{prim::kPrimImageSummary, {1}}, |
|
|
|
|
|
{prim::kPrimScalarSummary, {1}}, |
|
|
|
|
|
{prim::kPrimApplyRMSProp, {6, 7, 8}}, |
|
|
|
|
|
{prim::kPrimCumSum, {2}}, |
|
|
|
|
|
{prim::kPrimTile, {2}}, |
|
|
|
|
|
{prim::kPrimHistogramSummary, {1}}}); |
|
|
for (auto &item : white_list) { |
|
|
for (auto &item : white_list) { |
|
|
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { |
|
|
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { |
|
|
return IsPrimitiveCNode(node, item.first) && idx == index; |
|
|
return IsPrimitiveCNode(node, item.first) && idx == index; |
|
|
|