diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/onnx/ir_exporter.cc index 6feb4db1be..78858eea8a 100644 --- a/mindspore/ccsrc/transform/onnx/ir_exporter.cc +++ b/mindspore/ccsrc/transform/onnx/ir_exporter.cc @@ -92,17 +92,14 @@ class IrExportBuilder { void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, - std::string *const seq_string); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, + std::string suffix = "0"); void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); - void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, - std::string *const seq_string); - void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, - std::string *const seq_string); + void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); @@ -110,10 +107,8 @@ class IrExportBuilder { std::string GetNodeName(const AnfNodePtr &node); std::string GetUniqueNodeName(const AnfNodePtr &node); std::string GetOpTypeName(const AnfNodePtr &node); - size_t GetNodeIndex() { return ++node_index_; } - void ResetNodeIndex() { node_index_ = 0; } - size_t GetTupleIndex() { return ++shape_index_; } - void ResetTupleIndex() { shape_index_ = 0; } + size_t AllocateIndex() { return ++node_index_; } + void ResetIndex() { node_index_ = 0; } private: onnx::ModelProto model_; @@ -121,7 +116,6 @@ class IrExportBuilder { std::list todo_; std::map node_index_map_; size_t node_index_{0}; - size_t shape_index_{0}; }; using IrExporterPtr = std::shared_ptr; @@ -154,7 +148,7 @@ void IrExportBuilder::BuildModelInfo() { void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { onnx::GraphProto *graph_proto = model_.mutable_graph(); graph_proto->set_name(func_graph->ToString()); - ResetNodeIndex(); + ResetIndex(); todo_.clear(); todo_.push_back(func_graph); while (!todo_.empty()) { @@ -185,7 +179,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap input_proto->set_name(param_name); SetValueInfoProto(param, input_proto); if (!param->has_default()) { - MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; + MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; continue; } @@ -240,20 +234,13 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr auto elem_type = tensor->element(); const auto &dims = shape->cast()->shape(); type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); - if (dims.size() == 0) { - MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); - } else { - for (const auto &dim : dims) { - MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } + for (const auto &dim : dims) { + MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); } } else if (type->isa()) { auto tup_shape = shape->cast(); - type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); - } else if (type->isa() || type->isa()) { - type_proto->set_denotation(type->type_name()); + type_proto->set_denotation(std::to_string(tup_shape->shape().size())); } else { MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; } @@ -263,10 +250,9 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("tensor:value0"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); - tensor_proto->set_name("value0"); + attr_proto->set_ref_attr_name("tensor"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); auto data = value->cast(); tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); auto dtype = data->data_type(); @@ -300,7 +286,6 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); - bool is_only_return = true; for (const AnfNodePtr &node : nodes) { if (!node->isa()) { MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; @@ -308,13 +293,9 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt } auto cnode = node->cast(); if (cnode == func_graph->get_return()) { - if (is_only_return) { - MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!"; - } BuildOutput(cnode, graph_proto); } else { BuildCNode(cnode, graph_proto); - is_only_return = false; } } } @@ -324,11 +305,24 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - std::string output_name = GetUniqueNodeName(node); - output_proto->set_name(output_name); - last_node_->set_output(0, output_name); - SetValueInfoProto(arg, output_proto); + // Using make_tuple to set multi-output + if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { + auto tuple_node = arg->cast(); + for (size_t i = 1; i < tuple_node->size(); i++) { + auto input_node = arg->cast()->input(i); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + auto output_name = GetUniqueNodeName(tuple_node->input(i)); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(tuple_node->input(i), output_proto); + } + } else { + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + std::string output_name = GetUniqueNodeName(node); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(arg, output_proto); + } } std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { @@ -351,41 +345,43 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { } void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::AttributeProto *const attr_proto, std::string *const seq_string) { - if (type->isa() && seq_string != nullptr) { - *seq_string += "Tuple["; - auto elements = type->cast()->elements(); - auto tuple_shape = shape->cast()->shape(); - for (size_t i = 0; i < elements.size(); i++) { - SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); - } - *seq_string += "],"; - } else if (type->isa() && shape->isa() && seq_string != nullptr) { - string shape_name = "shape" + std::to_string(GetTupleIndex()); - *seq_string += shape_name + ","; - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); - tensor_proto->set_name(shape_name); - SetTensorProto(type, shape, tensor_proto); - } else if ((type->isa() || type->isa()) && seq_string != nullptr) { - *seq_string += type->type_name() + ","; + onnx::NodeProto *const node_proto, std::string suffix) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_ref_attr_name("shape"); + if (suffix.compare("0") != 0) { + attr_proto->set_name("shape" + suffix); } else { - MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); + attr_proto->set_name("shape"); } + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetTensorProto(type, shape, tensor_proto); } void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { // Get shape of cnode - // 1. need to get shape from tuple element - // 2. save shape in TensorProto - // 3. save tuple string in ref_attr_name - auto type = node->Type(); - auto shape = node->Shape(); - ResetTupleIndex(); - std::string seq_string = "shape:"; - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - SetShapeToNodeProto(type, shape, attr_proto, &seq_string); - attr_proto->set_ref_attr_name(seq_string); - MS_LOG(DEBUG) << "CNode shape: " << seq_string; + // 1. prim ArgMaxWithValue need to get shape from tuple element + // 2. some cnode doesn't has shape, such as LayerNorm + // 3. other cnodes have shape + if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa()) { + MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); + } + auto elements = type->cast()->elements(); + auto tuple_shape = shape->cast()->shape(); + for (size_t i = 0; i < elements.size(); i++) { + SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); + } + } else { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa() || !shape->isa()) { + MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); + return; + } + SetShapeToNodeProto(type, shape, node_proto); + } } void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { @@ -449,19 +445,15 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { std::string node_name = ""; if (node->isa()) { node_name = GetNodeName(node); - } else if (node->isa()) { + } else if (node->isa() || node->isa()) { auto iter = node_index_map_.find(node); if (iter != node_index_map_.end()) { node_name = GetNodeName(node) + ":" + std::to_string(iter->second); } else { - auto node_idx = GetNodeIndex(); + auto node_idx = AllocateIndex(); node_index_map_[node] = node_idx; node_name = GetNodeName(node) + ":" + std::to_string(node_idx); } - } else if (node->isa()) { - auto node_idx = GetNodeIndex(); - node_index_map_[node] = node_idx; - node_name = GetNodeName(node) + ":" + std::to_string(node_idx); } else { MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); } @@ -495,21 +487,17 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); + attr_proto->set_ref_attr_name("type"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); if (value->isa()) { - attr_proto->set_ref_attr_name("type:value0"); - tensor_proto->set_name("value0"); auto int_value = value->cast(); tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); } else if (value->isa()) { - attr_proto->set_ref_attr_name("type:value0"); - tensor_proto->set_name("value0"); auto float_value = value->cast(); tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); } else if (value->isa()) { - attr_proto->set_ref_attr_name("type:tensor0"); - tensor_proto->set_name("tensor0"); + tensor_proto->set_name("tensor"); auto elem_type = value->cast()->element(); if (elem_type->isa()) { auto int_value = elem_type->cast(); @@ -533,18 +521,10 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr SetScalarToAttributeProto(value, attr_proto); } else if (value->isa() || value->isa()) { SetTypeToAttributeProto(value, attr_proto); - } else if (value->isa() || value->isa()) { - ResetTupleIndex(); - std::string seq_string = "scalar:"; - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - SetSequenceToAttributeProto(value->cast(), attr_proto, &seq_string); - attr_proto->set_ref_attr_name(seq_string); - MS_LOG(DEBUG) << "Attr string: " << seq_string; + } else if (value->isa()) { + SetSequenceToAttributeProto(value->cast(), attr_proto); } else if (value->isa()) { SetTensorToAttributeProto(value, attr_proto); - } else if (value->isa()) { - attr_proto->set_ref_attr_name("none"); - MS_LOG(DEBUG) << "Attr string: " << value->type_name(); } else { MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); } @@ -554,18 +534,16 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; } - attr_proto->set_ref_attr_name("scalar:value0"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); - SetScalarToProto(value, tensor_proto, "value0"); + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetScalarToProto(value, tensor_proto); } -void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, - const std::string &value_name) { +void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { if (value == nullptr || tensor_proto == nullptr) { MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; } - tensor_proto->set_name(value_name); if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); tensor_proto->add_string_data(GetValue(value)); @@ -584,74 +562,44 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto } else if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); tensor_proto->add_int64_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); - tensor_proto->add_int32_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); - tensor_proto->add_int32_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); - tensor_proto->add_uint64_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); - tensor_proto->add_uint64_data(value->cast()->value()); - } else if (value->isa()) { + } else if (value->isa()) { tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); tensor_proto->add_float_data(GetValue(value)); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); - tensor_proto->add_double_data(GetValue(value)); } else { MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); } } -void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, - std::string *const seq_string) { - string value_name = "value" + std::to_string(GetTupleIndex()); - if (seq_string != nullptr) { - *seq_string += value_name + ","; - } - onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); - SetScalarToProto(value, tensor_proto, value_name); -} - -void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, - std::string *const seq_string) { +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, + onnx::AttributeProto *const attr_proto) { if (value == nullptr || attr_proto == nullptr) { MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; } - if (value->isa() && seq_string != nullptr) { - *seq_string += "Tuple["; + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { const ValueTuplePtr &tuple_value = value->cast(); if (tuple_value->value().size() == 0) { MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; return; } + auto type_id = tuple_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); for (const auto &item : tuple_value->value()) { - if (item->isa()) { - SetSequenceToAttributeProto(item->cast(), attr_proto, seq_string); - } else { - SetSeqElemToAttributeProto(item, attr_proto, seq_string); - } + SetScalarToProto(item, tensor_proto); } - *seq_string += "],"; - } else if (value->isa() && seq_string != nullptr) { - *seq_string += "List["; + } else if (value->isa()) { const ValueListPtr &list_value = value->cast(); if (list_value->value().size() == 0) { - MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; + MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; return; } + auto type_id = list_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); for (const auto &item : list_value->value()) { - if (item->isa()) { - SetSequenceToAttributeProto(item->cast(), attr_proto, seq_string); - } else { - SetSeqElemToAttributeProto(item, attr_proto, seq_string); - } + SetScalarToProto(item, tensor_proto); } - *seq_string += "],"; } }