!26170 Fix cyclomatic complexity

Merge pull request !26170 from hewei/fix_ccn
tags/v1.6.0
i-robot (Gitee) · 4 years ago
parent commit 9916e5844d
10 changed files with 317 additions and 291 deletions
  1. +8  -13  mindspore/ccsrc/backend/optimizer/gpu/adjust_depend_for_parallel_optimizer_recompute_all_gather_fusion.cc
  2. +8  -13  mindspore/ccsrc/backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.cc
  3. +16 -0   mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  4. +31 -38  mindspore/ccsrc/debug/anf_ir_utils.cc
  5. +4  -1   mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h
  6. +41 -32  mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h
  7. +64 -60  mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  8. +45 -23  mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
  9. +63 -84  mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
  10. +37 -27  mindspore/core/ir/anf.cc

+8 -13  mindspore/ccsrc/backend/optimizer/gpu/adjust_depend_for_parallel_optimizer_recompute_all_gather_fusion.cc

@@ -39,13 +39,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGatherFusion::Run(const FuncGra
continue;
}
auto cnode = node->cast<CNodePtr>();
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
auto instance_name = primitive->instance_name();
bool is_allgather = AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName;
bool is_fusion = AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0;
bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(cnode->GetAttr(kAttrDuplicated));
bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos;
if (is_allgather && is_fusion && is_recompute && is_from_parallel_optimizer) {
if (!AnfAlgo::IsAllgather(cnode) || !AnfAlgo::IsFusion(cnode) || !AnfAlgo::IsFromParallelOptimizer(cnode)) {
continue;
}
if (AnfAlgo::IsRecompute(cnode)) {
int64_t fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
if (std::find(parallel_optimizer_recompute_allgather_fusion_ids.begin(),
parallel_optimizer_recompute_allgather_fusion_ids.end(),
@@ -58,16 +55,14 @@ bool AdjustDependForParallelOptimizerRecomputeAllGatherFusion::Run(const FuncGra
} else {
parallel_optimizer_recompute_allgathers.push_back(node);
}
}
if (!is_recompute && is_fusion && is_allgather && is_from_parallel_optimizer) {
} else {
int64_t unrecompute_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
unrecompute_max_fusion_id = std::max(unrecompute_fusion_id, unrecompute_max_fusion_id);
bool would_be_recomputed =
AnfAlgo::HasNodeAttr(kAttrRecompute, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrRecompute);
if (forward_allgather_recompute_value_in_fusion_group.find(unrecompute_fusion_id) ==
forward_allgather_recompute_value_in_fusion_group.end()) {
forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] = would_be_recomputed;
} else if (forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] != would_be_recomputed) {
auto [iter, inserted] =
forward_allgather_recompute_value_in_fusion_group.emplace(unrecompute_fusion_id, would_be_recomputed);
if (!inserted && iter->second != would_be_recomputed) {
MS_LOG(EXCEPTION) << "In same fusion group, the allgather recompute attribute should be equal. "
"The normal node is:"
<< cnode->fullname_with_scope();


+8 -13  mindspore/ccsrc/backend/optimizer/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.cc

@@ -35,13 +35,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr
continue;
}
auto cnode = node->cast<CNodePtr>();
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
auto instance_name = primitive->instance_name();
bool is_allgather = AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName;
bool is_fusion = AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0;
bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(cnode->GetAttr(kAttrDuplicated));
bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos;
if (is_allgather && is_fusion && is_recompute && is_from_parallel_optimizer) {
if (!AnfAlgo::IsAllgather(cnode) || !AnfAlgo::IsFusion(cnode) || !AnfAlgo::IsFromParallelOptimizer(cnode)) {
continue;
}
if (AnfAlgo::IsRecompute(cnode)) {
int64_t fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
if (std::find(parallel_optimizer_recompute_allgather_fusion_ids.begin(),
parallel_optimizer_recompute_allgather_fusion_ids.end(),
@@ -54,16 +51,14 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr
} else {
parallel_optimizer_recompute_allgathers.push_back(node);
}
}
if (!is_recompute && is_fusion && is_allgather && is_from_parallel_optimizer) {
} else {
int64_t unrecompute_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
unrecompute_max_fusion_id = std::max(unrecompute_fusion_id, unrecompute_max_fusion_id);
bool would_be_recomputed =
AnfAlgo::HasNodeAttr(kAttrRecompute, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrRecompute);
if (forward_allgather_recompute_value_in_fusion_group.find(unrecompute_fusion_id) ==
forward_allgather_recompute_value_in_fusion_group.end()) {
forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] = would_be_recomputed;
} else if (forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] != would_be_recomputed) {
auto [iter, inserted] =
forward_allgather_recompute_value_in_fusion_group.emplace(unrecompute_fusion_id, would_be_recomputed);
if (!inserted && iter->second != would_be_recomputed) {
MS_LOG(EXCEPTION) << "In same fusion group, the allgather recompute attribute should be equal. "
"The normal node is:"
<< cnode->fullname_with_scope();
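
Both hunks above replace the earlier find / insert / compare sequence on the fusion-group map with a single std::map::emplace call whose return value says whether the key was new. A minimal, self-contained sketch of that idiom follows; it is not part of the commit, and the identifiers only mirror the ones used above.

#include <cstdint>
#include <iostream>
#include <map>

int main() {
  // Stand-in for forward_allgather_recompute_value_in_fusion_group.
  std::map<std::int64_t, bool> recompute_in_group;

  auto record = [&recompute_in_group](std::int64_t fusion_id, bool would_be_recomputed) {
    // emplace inserts only if the key is absent and reports whether it did, so one
    // call replaces the previous find / insert / compare sequence.
    auto [iter, inserted] = recompute_in_group.emplace(fusion_id, would_be_recomputed);
    if (!inserted && iter->second != would_be_recomputed) {
      std::cout << "conflicting recompute attribute in fusion group " << fusion_id << "\n";
    }
  };

  record(1, true);   // first entry for group 1, inserted
  record(1, true);   // consistent with the stored value, nothing reported
  record(1, false);  // conflict detected, mirrors the MS_LOG(EXCEPTION) branch
  return 0;
}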


+16 -0  mindspore/ccsrc/backend/session/anf_runtime_algorithm.h

@@ -347,6 +347,22 @@ class AnfRuntimeAlgorithm {
static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
// Fetch all outputs of call node.
static std::vector<KernelWithIndex> GetAllOutputByCallNode(const KernelWithIndex &output_with_index);

static inline bool IsAllgather(const CNodePtr &cnode) { return GetCNodeName(cnode) == kAllGatherOpName; }

static inline bool IsFusion(const CNodePtr &cnode) {
return HasNodeAttr(kAttrFusion, cnode) && GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0;
}

static inline bool IsFromParallelOptimizer(const CNodePtr &cnode) {
auto primitive = GetCNodePrimitive(cnode);
return (primitive != nullptr) && primitive->instance_name().find("parallel_optimizer") != std::string::npos;
}

static inline bool IsRecompute(const CNodePtr &cnode) {
auto attr_dup = cnode->GetAttr(kAttrDuplicated);
return attr_dup != nullptr && GetValue<bool>(attr_dup);
}
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;


+31 -38  mindspore/ccsrc/debug/anf_ir_utils.cc

@@ -340,49 +340,42 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr
return oss.str();
}

static bool CanUseDumpText(const ValuePtr &value) {
return (value->isa<RefKey>() || value->isa<Scalar>() || value->isa<StringImm>() || value->isa<tensor::Tensor>() ||
value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>() || value->isa<ValueSlice>() ||
value->isa<Type>() || value->isa<KeywordArg>());
}

std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) {
std::ostringstream oss;
bool is_null_ptr = (func_graph == nullptr || value == nullptr);
if (is_null_ptr) {
return oss.str();
if (func_graph == nullptr || value == nullptr) {
return "";
}

if (value->isa<Primitive>()) {
oss << GetPrimitiveText(value->cast<PrimitivePtr>());
} else if (value->isa<MetaFuncGraph>()) {
return GetPrimitiveText(value->cast<PrimitivePtr>());
}
if (value->isa<MetaFuncGraph>()) {
MetaFuncGraphPtr meta_func_graph = value->cast<MetaFuncGraphPtr>();
oss << GetMetaFuncGraphText(meta_func_graph);
} else if (value->isa<SymbolicKeyInstance>()) {
oss << GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>());
} else if (value->isa<RefKey>()) {
oss << value->DumpText();
} else if (value->isa<Scalar>() || value->isa<StringImm>()) {
oss << value->DumpText();
} else if (value->isa<tensor::Tensor>()) {
oss << value->DumpText();
} else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) {
oss << value->DumpText();
} else if (value->isa<ValueSequeue>()) {
oss << GetSequenceText(func_graph, value);
} else if (value->isa<ValueDictionary>()) {
oss << GetDictText(func_graph, value);
} else if (value->isa<ValueSlice>()) {
ValueSlicePtr slice = value->cast<ValueSlicePtr>();
oss << slice->DumpText();
} else if (value->isa<Type>()) {
oss << value->DumpText();
} else if (value->isa<parse::NameSpace>()) {
oss << GetNameSpaceText(value->cast<parse::NameSpacePtr>());
} else if (value->isa<parse::PyObjectWrapper>()) {
oss << value->type_name();
} else if (value->isa<KeywordArg>()) {
KeywordArgPtr keyword_arg = value->cast<KeywordArgPtr>();
oss << keyword_arg->DumpText();
} else {
return GetOtherValueText(func_graph, value);
return GetMetaFuncGraphText(meta_func_graph);
}

return oss.str();
if (value->isa<SymbolicKeyInstance>()) {
return GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>());
}
if (value->isa<ValueSequeue>()) {
return GetSequenceText(func_graph, value);
}
if (value->isa<ValueDictionary>()) {
return GetDictText(func_graph, value);
}
if (value->isa<parse::NameSpace>()) {
return GetNameSpaceText(value->cast<parse::NameSpacePtr>());
}
if (value->isa<parse::PyObjectWrapper>()) {
return value->type_name();
}
if (CanUseDumpText(value)) {
return value->DumpText();
}
return GetOtherValueText(func_graph, value);
}

// This function is used to output node in CNode's inputs
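
The rewrite above turns the long if/else ladder into early returns and folds every branch whose body was just value->DumpText() behind the new CanUseDumpText predicate. A standalone sketch of that grouping pattern, with made-up stand-in types rather than the real Value hierarchy:

#include <iostream>
#include <memory>
#include <string>

// Illustrative stand-ins only; the real code dispatches on RefKey, Scalar, Tensor, etc.
struct Value { virtual ~Value() = default; virtual std::string DumpText() const { return "<value>"; } };
struct Scalar : Value { std::string DumpText() const override { return "Scalar(...)"; } };
struct Tensor : Value { std::string DumpText() const override { return "Tensor(...)"; } };
struct Dict : Value {};

// Every type whose branch only called DumpText() goes behind one predicate,
// so the caller keeps a single early return instead of many identical branches.
static bool CanUseDumpText(const Value &v) {
  return dynamic_cast<const Scalar *>(&v) != nullptr || dynamic_cast<const Tensor *>(&v) != nullptr;
}

std::string GetValueText(const std::shared_ptr<Value> &value) {
  if (value == nullptr) {
    return "";
  }
  if (CanUseDumpText(*value)) {
    return value->DumpText();
  }
  return "<other>";  // fallback, playing the role of GetOtherValueText above
}

int main() {
  std::cout << GetValueText(std::make_shared<Scalar>()) << "\n";  // Scalar(...)
  std::cout << GetValueText(std::make_shared<Dict>()) << "\n";    // <other>
  return 0;
}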


+4 -1  mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -121,7 +121,10 @@ class IncorporateCall : public AnfVisitor {
(void)args.insert(args.end(), Xs_.begin(), Xs_.end());
}
}
return MakeNewNode(node, args);
}

AnfNodePtr MakeNewNode(const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
auto new_node = node->func_graph()->NewCNode(args);
new_node->set_abstract(node->abstract());
// Check if the another only usage of {G, Xs} is UpdateState{s, {G, Xs}}, if yes, replace


+41 -32  mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h

@@ -181,43 +181,12 @@ class ChoicePartialEliminater : public AnfVisitor {
AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list,
const std::vector<AnfNodePtrList> args_list) {
std::vector<size_t> inputs_index_list[args_list.size()];
size_t extra_input_counter = 0;
AnfNodePtrList extra_inputs;
const auto &anchor_args = args_list[anchor_index];
size_t anchor_args_size = anchor_args.size();
auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]);
MS_EXCEPTION_IF_NULL(anchor_fg);
// Find the new location of the old_inputs except Zs;
for (size_t i = 0; i < args_list.size(); ++i) {
if (i == anchor_index) {
continue;
}
const auto &another_args = args_list[i];
auto &curr_inputs_index = inputs_index_list[i];
for (size_t j = 0; j < another_args.size(); ++j) {
size_t k;
for (k = 0; k < anchor_args_size; ++k) {
if (another_args[j] == anchor_args[k]) {
curr_inputs_index.push_back(k);
break;
}
}
if (k == anchor_args_size) {
// check if used by another func_graph;
for (k = 0; k < extra_input_counter; ++k) {
if (another_args[j] == extra_inputs[k]) {
curr_inputs_index.push_back(anchor_args_size + k);
break;
}
}
if (k == extra_input_counter) {
extra_inputs.push_back(another_args[j]);
curr_inputs_index.push_back(anchor_args_size + extra_input_counter);
extra_input_counter++;
}
}
}
}
size_t extra_input_counter = FindNewLocation(args_list, anchor_index, inputs_index_list, &extra_inputs);

auto manager = anchor_fg->manager();
MS_EXCEPTION_IF_NULL(manager);
@@ -284,6 +253,46 @@ class ChoicePartialEliminater : public AnfVisitor {

return extra_inputs;
}

// Find the new location of the old_inputs except Zs.
size_t FindNewLocation(const std::vector<AnfNodePtrList> &args_list, size_t anchor_index,
std::vector<size_t> *inputs_index_list, AnfNodePtrList *extra_inputs_ptr) {
const auto &anchor_args = args_list[anchor_index];
auto &extra_inputs = *extra_inputs_ptr;
size_t extra_input_counter = 0;
size_t anchor_args_size = anchor_args.size();
for (size_t i = 0; i < args_list.size(); ++i) {
if (i == anchor_index) {
continue;
}
const auto &another_args = args_list[i];
auto &curr_inputs_index = inputs_index_list[i];
for (size_t j = 0; j < another_args.size(); ++j) {
size_t k;
for (k = 0; k < anchor_args_size; ++k) {
if (another_args[j] == anchor_args[k]) {
curr_inputs_index.push_back(k);
break;
}
}
if (k == anchor_args_size) {
// check if used by another func_graph;
for (k = 0; k < extra_input_counter; ++k) {
if (another_args[j] == extra_inputs[k]) {
curr_inputs_index.push_back(anchor_args_size + k);
break;
}
}
if (k == extra_input_counter) {
extra_inputs.push_back(another_args[j]);
curr_inputs_index.push_back(anchor_args_size + extra_input_counter);
extra_input_counter++;
}
}
}
}
return extra_input_counter;
}
};

// {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} ->


+64 -60  mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -673,6 +673,68 @@ void CheckAndApplyApproximation() {
}
}

static void ConstructCNodeCostGraphEdges(const mindspore::CNodePtr &cnode) {
auto &inputs = cnode->inputs();
ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
size_t edge_count = 0;
auto node_op_info = cnode->user_data<OperatorInfo>();

for (size_t i = 1; i < inputs.size(); ++i) {
auto prev_cnode = inputs[i]->cast<CNodePtr>();
bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
if (bool_result_prev_cnode) {
continue;
}
ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
size_t output_index = 0;

while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
(prev_prim->name() == DEPEND)) {
if (IsAutoParallelCareNode(prev_cnode)) {
auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
&edge_count);
break;
} else if (prev_prim->name() == prim::kTupleGetItem) {
// In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
// this 'tuple_getitem'
MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
if (bool_result_tuple) {
break;
}
prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
if (!IsAutoParallelCareNode(prev_cnode)) {
MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
}
MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
<< "and creating an edge between the Operator before "
<< "'tuple_getitem' and the Operator after 'tuple_getitem'.";
} else if (prev_prim->name() == DEPEND) {
// In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
// this 'depend'
MS_LOG(INFO) << "Jumping the 'depend' operator.";
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
if (bool_result_depend) {
break;
}
prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
MS_LOG(INFO) << "Jumped the 'depend' operator, "
<< "and creating an edge between the Operator before "
<< "'depend' and the Operator after 'depend'.";
}
}
}
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
}

void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
// Step 2
MS_LOG(INFO) << "Constructing edges for cost graph begins.";
@@ -681,68 +743,10 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
auto &inputs = cnode->inputs();
ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
if (!IsAutoParallelCareNode(cnode)) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
size_t edge_count = 0;
auto node_op_info = cnode->user_data<OperatorInfo>();

for (size_t i = 1; i < inputs.size(); ++i) {
auto prev_cnode = inputs[i]->cast<CNodePtr>();
bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
if (bool_result_prev_cnode) {
continue;
}
ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
size_t output_index = 0;

while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) ||
(prev_prim->name() == DEPEND)) {
if (IsAutoParallelCareNode(prev_cnode)) {
auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i,
&edge_count);
break;
} else if (prev_prim->name() == prim::kTupleGetItem) {
// In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before
// this 'tuple_getitem'
MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2))));
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
if (bool_result_tuple) {
break;
}
prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
if (!IsAutoParallelCareNode(prev_cnode)) {
MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name();
}
MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
<< "and creating an edge between the Operator before "
<< "'tuple_getitem' and the Operator after 'tuple_getitem'.";
} else if (prev_prim->name() == DEPEND) {
// In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before
// this 'depend'
MS_LOG(INFO) << "Jumping the 'depend' operator.";
prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
if (bool_result_depend) {
break;
}
prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
MS_LOG(INFO) << "Jumped the 'depend' operator, "
<< "and creating an edge between the Operator before "
<< "'depend' and the Operator after 'depend'.";
}
}
}
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
ConstructCNodeCostGraphEdges(cnode);
}
CheckAndApplyApproximation();



+45 -23  mindspore/ccsrc/transform/express_ir/mindir_exporter.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -115,7 +115,9 @@ class IrExportBuilder {
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProtoForInt_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProtoForInt_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
@@ -831,7 +833,27 @@ bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
int64_t attr_value = GetValue<bool>(value) ? 1 : 0;
attr_proto->set_i(attr_value);
} else if (value->isa<Int8Imm>()) {
} else if (SetScalarToAttributeProtoForInt_ir(value, attr_proto)) {
return true;
} else if (value->isa<FP32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
attr_proto->set_f(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
attr_proto->set_d(GetValue<double>(value));
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
return SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
return false;
}
return true;
}

bool IrExportBuilder::SetScalarToAttributeProtoForInt_ir(const ValuePtr &value,
mind_ir::AttributeProto *const attr_proto) {
if (value->isa<Int8Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
} else if (value->isa<Int16Imm>()) {
@@ -855,17 +877,7 @@ bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i
} else if (value->isa<UInt64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
attr_proto->set_i(UlongToLong(value->cast<UInt64ImmPtr>()->value()));
} else if (value->isa<FP32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
attr_proto->set_f(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
attr_proto->set_d(GetValue<double>(value));
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
return SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
return false;
}
return true;
@@ -899,7 +911,27 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
} else if (value->isa<BoolImm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
attr_proto->add_ints(GetValue<bool>(value));
} else if (value->isa<Int8Imm>()) {
} else if (SetScalarToAttributeProtoForInt_irs(value, attr_proto)) {
return true;
} else if (value->isa<FP32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
attr_proto->add_floats(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
attr_proto->add_doubles(GetValue<double>(value));
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
return SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
return false;
}
return true;
}

bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value,
mind_ir::AttributeProto *const attr_proto) {
if (value->isa<Int8Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
} else if (value->isa<Int16Imm>()) {
@@ -923,17 +955,7 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
} else if (value->isa<UInt64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
attr_proto->add_ints(SizeToInt(value->cast<UInt64ImmPtr>()->value()));
} else if (value->isa<FP32Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
attr_proto->add_floats(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
attr_proto->add_doubles(GetValue<double>(value));
} else if (value->isa<tensor::Tensor>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
return SetTensorToAttributeProto(value, attr_proto);
} else {
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
return false;
}
return true;
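
The integer cases move into SetScalarToAttributeProtoForInt_ir / _irs, and the caller tries the helper inside its else-if chain, treating the boolean return as "handled". A simplified, self-contained sketch of that delegation pattern; in this sketch the helper declines silently instead of logging, and the types are invented for illustration.

#include <iostream>

// Invented stand-ins; the real code dispatches on Int8Imm..UInt64Imm, FP32Imm, etc.
struct Value { virtual ~Value() = default; };
struct IntValue : Value { long v = 0; };
struct FloatValue : Value { double v = 0.0; };

// Handles only the integer subset and reports via its return value whether it did,
// so the caller keeps one flat chain without re-listing every integer type.
bool SetScalarForInt(const Value &value) {
  if (auto p = dynamic_cast<const IntValue *>(&value)) {
    std::cout << "int attribute " << p->v << "\n";
    return true;
  }
  return false;
}

bool SetScalar(const Value &value) {
  if (SetScalarForInt(value)) {
    return true;
  }
  if (auto p = dynamic_cast<const FloatValue *>(&value)) {
    std::cout << "float attribute " << p->v << "\n";
    return true;
  }
  return false;  // unsupported type
}

int main() {
  IntValue i;
  i.v = 7;
  FloatValue f;
  f.v = 2.5;
  return SetScalar(i) && SetScalar(f) ? 0 : 1;
}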


+63 -84  mindspore/ccsrc/transform/express_ir/onnx_exporter.cc

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,9 +16,11 @@

#include <map>
#include <memory>
#include <vector>
#include <unordered_map>
#include <utility>
#include <functional>
#include <algorithm>

#include "ir/tensor.h"
#include "ir/param_info.h"
@@ -384,6 +386,7 @@ class OnnxExporter {

void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr);
void MatchAndMarkCNode(const CNodePtr &cnode, std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr);
void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
onnx::GraphProto *graph_proto);

@@ -600,7 +603,7 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr &param, onnx::TensorPro

void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr;
auto &op_merged_infos = *op_merged_infos_ptr;

for (auto &node : nodes) {
if (!node->isa<CNode>()) {
@@ -623,36 +626,41 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto
// if the key `input` does not exist, just create a new one
op_merged_infos[input].referred_count += 1;
}
// MindSpore Conv + BiasAdd --> ONNX Conv
if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) &&
IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) {
op_merged_infos[cnode].mode = OP_MERGE_CONV;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) &&
IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) {
op_merged_infos[cnode].mode = OP_MERGE_GEMM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("BatchNorm")) &&
GetInt64Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
GetInt64Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("LayerNorm")) &&
GetInt64Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_LAYER_NORM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
}
MatchAndMarkCNode(cnode, op_merged_infos_ptr);
}
}

void OnnxExporter::MatchAndMarkCNode(const CNodePtr &cnode,
std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
auto &op_merged_infos = *op_merged_infos_ptr;
// MindSpore Conv + BiasAdd --> ONNX Conv
if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) && IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) {
op_merged_infos[cnode].mode = OP_MERGE_CONV;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) &&
IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) {
op_merged_infos[cnode].mode = OP_MERGE_GEMM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("BatchNorm")) &&
GetInt64Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
GetInt64Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
} else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("LayerNorm")) &&
GetInt64Value(cnode->input(2)) == 0) {
op_merged_infos[cnode].mode = OP_MERGE_LAYER_NORM;
op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
op_merged_infos[cnode->input(1)].referred_count -= 1;
}
}

@@ -1571,59 +1579,30 @@ void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &, const CNodePtr &node

void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert
if (node->IsApply(prim::kPrimReshape)) {
return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) {
return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimTranspose)) {
return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimStridedSlice)) {
return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimResizeNearestNeighbor)) {
return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto);
}
if (node->IsApply(prim::kPrimConcat)) {
return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto);
}

// MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
if (node->IsApply(prim::kPrimCast)) {
return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto);
}

// ONNX PRelu requires unidirectional broadcasting, here need some process
if (node->IsApply(std::make_shared<Primitive>("PReLU"))) {
return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto);
}

// MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x)
if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) {
return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto);
}

// MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w))
if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) {
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
}

// MindSpore Tile(x) --> ONNX Tile(x, repeat)
if (node->IsApply(prim::kPrimTile)) {
return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto);
}

// MindSpore Square(x) --> ONNX Pow(x, 2)
if (node->IsApply(prim::kPrimSquare)) {
return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto);
}
using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
std::map<AnfNodePtr, size_t> *, onnx::GraphProto *const)>;
static std::vector<std::pair<PrimitivePtr, ExportFunc>> export_table = {
{prim::kPrimReshape, &OnnxExporter::ExportPrimReshape},
{prim::kPrimReduceMean, &OnnxExporter::ExportPrimReduce},
{prim::kPrimReduceSum, &OnnxExporter::ExportPrimReduce},
{prim::kPrimTranspose, &OnnxExporter::ExportPrimTranspose},
{prim::kPrimStridedSlice, &OnnxExporter::ExportPrimStridedSlice},
{prim::kPrimResizeNearestNeighbor, &OnnxExporter::ExportPrimResizeNearestNeighbor},
{prim::kPrimConcat, &OnnxExporter::ExportPrimConcat},
{prim::kPrimCast, &OnnxExporter::ExportPrimCast},
{prim::kPrimPRelu, &OnnxExporter::ExportPrimPReLU},
{prim::kPrimRelu6, &OnnxExporter::ExportPrimReLU6},
{prim::kPrimDepthwiseConv2dNative, &OnnxExporter::ExportPrimDepthwiseConv2d},
{prim::kPrimTile, &OnnxExporter::ExportPrimTile},
{prim::kPrimSquare, &OnnxExporter::ExportPrimSquare},
{prim::kPrimGather, &OnnxExporter::ExportPrimGatherV2},
};

// MindSpore GatherV2(x, indices, axis) --> ONNX Gather(x, indices)
if (node->IsApply(prim::kPrimGather)) {
return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
auto iter = std::find_if(export_table.begin(), export_table.end(),
[&node](const auto &item) { return node->IsApply(item.first); });
if (iter != export_table.end()) {
iter->second(this, func_graph, node, node_map_ptr, graph_proto);
return;
}

auto inputs = node->inputs();
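
ExportCNode now looks the primitive up in a static table of (primitive, handler) pairs and dispatches through std::find_if instead of walking a chain of node->IsApply checks. A self-contained sketch of that table-dispatch pattern, using strings and a toy class in place of PrimitivePtr and the real exporter:

#include <algorithm>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

class Exporter {
 public:
  void Export(const std::string &op) {
    using ExportFunc = std::function<void(Exporter *, const std::string &)>;
    // One row per special-cased op; adding support means adding a row, not another branch.
    static const std::vector<std::pair<std::string, ExportFunc>> export_table = {
        {"Reshape", &Exporter::ExportReshape},
        {"Cast", &Exporter::ExportCast},
    };
    auto iter = std::find_if(export_table.begin(), export_table.end(),
                             [&op](const auto &item) { return item.first == op; });
    if (iter != export_table.end()) {
      iter->second(this, op);  // std::function wraps the member pointer, so pass the object first
      return;
    }
    std::cout << "no special handler for " << op << "\n";  // generic path
  }

 private:
  void ExportReshape(const std::string &op) { std::cout << "export " << op << "\n"; }
  void ExportCast(const std::string &op) { std::cout << "export " << op << "\n"; }
};

int main() {
  Exporter exporter;
  exporter.Export("Reshape");  // dispatched through the table
  exporter.Export("ReLU");     // falls through to the generic path
  return 0;
}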


+37 -27  mindspore/core/ir/anf.cc

@@ -404,14 +404,48 @@ PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
return value->cast<PrimitivePtr>();
}

static std::string GetNodeTargetForVarInputNode(const CNodePtr &cnode) {
auto &inputs = cnode->inputs();
std::vector<AnfNodePtr> real_inputs;
const size_t update_state_valid_input_index = 2;
const size_t make_tuple_valid_input_index = 1;
if (cnode->IsApply(prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
(void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(),
std::back_inserter(real_inputs));
} else if (cnode->IsApply(prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
(void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(),
std::back_inserter(real_inputs));
}
std::string first_input_target = kTargetUnDefined;
bool has_diff_target =
std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodePtr &n) {
auto target = GetOriginNodeTarget(n);
if (target == kTargetUnDefined) {
return false;
}
if (first_input_target == kTargetUnDefined) {
first_input_target = target;
}
return target != first_input_target;
});
if (!has_diff_target) {
return first_input_target;
}
return kTargetUnDefined;
}

static inline bool IsSummaryPrimitiveCNode(const AnfNodePtr &node) {
return IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary);
}

std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
#ifndef ENABLE_SECURITY
if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
if (IsSummaryPrimitiveCNode(node)) {
if (inputs.size() > 1) {
return GetOriginNodeTarget(inputs[1]);
}
@@ -428,31 +462,7 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
return GetOriginNodeTarget(inputs[use_index]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
std::vector<AnfNodePtr> real_inputs;
const size_t update_state_valid_input_index = 2;
const size_t make_tuple_valid_input_index = 1;
if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
(void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(),
std::back_inserter(real_inputs));
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
(void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(),
std::back_inserter(real_inputs));
}
std::string first_input_target = kTargetUnDefined;
bool has_diff_target =
std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodePtr &n) {
auto target = GetOriginNodeTarget(n);
if (target == kTargetUnDefined) {
return false;
}
if (first_input_target == kTargetUnDefined) {
first_input_target = target;
}
return target != first_input_target;
});
if (!has_diff_target) {
return first_input_target;
}
return GetNodeTargetForVarInputNode(node->cast<CNodePtr>());
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return GetOriginNodeTarget(cnode->input(1));
}
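
GetNodeTargetForVarInputNode keeps the original std::any_of idiom: the lambda carries state, so one reverse pass both remembers the first defined target and reports whether any later input disagrees. A standalone sketch of that idiom with made-up data:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string kTargetUnDefined = "Undefined";
  // Targets of the real inputs of a MakeTuple/UpdateState node, invented for illustration.
  std::vector<std::string> targets = {kTargetUnDefined, "GPU", "GPU"};

  std::string first_target = kTargetUnDefined;
  bool has_diff_target =
      std::any_of(targets.rbegin(), targets.rend(),
                  [&first_target, &kTargetUnDefined](const std::string &target) {
                    if (target == kTargetUnDefined) {
                      return false;  // undefined entries never count as a difference
                    }
                    if (first_target == kTargetUnDefined) {
                      first_target = target;  // remember the first defined target
                    }
                    return target != first_target;
                  });

  // All defined targets agree, so the node inherits "GPU"; otherwise it stays undefined.
  std::cout << (has_diff_target ? kTargetUnDefined : first_target) << "\n";
  return 0;
}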

