|
|
|
@@ -354,6 +354,16 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & |
|
|
|
return operator_info; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsFindWrong(const OperatorInfoPtr current_op_ptr, const std::string &prim_name) { |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(prim_name + "Info") == std::string::npos); |
|
|
|
if (prim_name == GATHERV2) { |
|
|
|
is_find_wrong = is_find_wrong && (current_op_ptr->name().find(prim_name + "PInfo") == std::string::npos); |
|
|
|
} |
|
|
|
return is_find_wrong; |
|
|
|
} |
|
|
|
|
|
|
|
// Using CNode's UniqueIds to construct nodes |
|
|
|
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { |
|
|
|
MS_LOG(INFO) << "Constructing nodes for cost graph begins."; |
|
|
|
@@ -399,10 +409,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node |
|
|
|
const auto single_loop = CostModelContext::GetInstance()->dp_algo_single_loop(); |
|
|
|
if (single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size())) { |
|
|
|
const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(prim->name()) == std::string::npos); |
|
|
|
if (is_find_wrong) { |
|
|
|
if (IsFindWrong(current_op_ptr, prim->name())) { |
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() |
|
|
|
<< " does not match the Prim: " << prim->name() |
|
|
|
<< ". The fullname_with_scope: " << cnode->fullname_with_scope(); |
|
|
|
@@ -456,10 +463,7 @@ void SetOperatorToCNode(const OperatorInfoPtr ¤t_op_ptr, const PrimitivePt |
|
|
|
if (current_op_ptr == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; |
|
|
|
} else { |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(prim->name()) == std::string::npos); |
|
|
|
if (is_find_wrong) { |
|
|
|
if (IsFindWrong(current_op_ptr, prim->name())) { |
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() |
|
|
|
<< " does not match the Prim: " << prim->name(); |
|
|
|
} |
|
|
|
@@ -518,10 +522,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
bool is_op_created = single_loop && is_in_loop && (loop_to_ops[loop_index] < operators_in_forloop.size()); |
|
|
|
if (is_op_created) { |
|
|
|
const auto ¤t_op_ptr = operators_in_forloop[loop_to_ops[loop_index]]; |
|
|
|
bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && |
|
|
|
(current_op_ptr->name().find(prim->name()) == std::string::npos); |
|
|
|
if (is_find_wrong) { |
|
|
|
if (IsFindWrong(current_op_ptr, prim->name())) { |
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() |
|
|
|
<< " does not match the Prim: " << prim->name() |
|
|
|
<< ". The fullname_with_scope: " << cnode->fullname_with_scope(); |
|
|
|
|