| @@ -92,14 +92,17 @@ class IrExportBuilder { | |||||
| void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); | void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); | ||||
| void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); | void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); | ||||
| void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); | void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); | ||||
| void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, | |||||
| std::string suffix = "0"); | |||||
| void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string); | |||||
| void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | ||||
| void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | ||||
| void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | ||||
| void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | ||||
| void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); | |||||
| void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); | |||||
| void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); | |||||
| void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string); | |||||
| void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string); | |||||
| onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | ||||
| onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); | onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); | ||||
| @@ -107,8 +110,10 @@ class IrExportBuilder { | |||||
| std::string GetNodeName(const AnfNodePtr &node); | std::string GetNodeName(const AnfNodePtr &node); | ||||
| std::string GetUniqueNodeName(const AnfNodePtr &node); | std::string GetUniqueNodeName(const AnfNodePtr &node); | ||||
| std::string GetOpTypeName(const AnfNodePtr &node); | std::string GetOpTypeName(const AnfNodePtr &node); | ||||
| size_t AllocateIndex() { return ++node_index_; } | |||||
| void ResetIndex() { node_index_ = 0; } | |||||
| size_t GetNodeIndex() { return ++node_index_; } | |||||
| void ResetNodeIndex() { node_index_ = 0; } | |||||
| size_t GetTupleIndex() { return ++shape_index_; } | |||||
| void ResetTupleIndex() { shape_index_ = 0; } | |||||
| private: | private: | ||||
| onnx::ModelProto model_; | onnx::ModelProto model_; | ||||
| @@ -116,6 +121,7 @@ class IrExportBuilder { | |||||
| std::list<FuncGraphPtr> todo_; | std::list<FuncGraphPtr> todo_; | ||||
| std::map<AnfNodePtr, size_t> node_index_map_; | std::map<AnfNodePtr, size_t> node_index_map_; | ||||
| size_t node_index_{0}; | size_t node_index_{0}; | ||||
| size_t shape_index_{0}; | |||||
| }; | }; | ||||
| using IrExporterPtr = std::shared_ptr<IrExporter>; | using IrExporterPtr = std::shared_ptr<IrExporter>; | ||||
| @@ -148,7 +154,7 @@ void IrExportBuilder::BuildModelInfo() { | |||||
| void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { | void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { | ||||
| onnx::GraphProto *graph_proto = model_.mutable_graph(); | onnx::GraphProto *graph_proto = model_.mutable_graph(); | ||||
| graph_proto->set_name(func_graph->ToString()); | graph_proto->set_name(func_graph->ToString()); | ||||
| ResetIndex(); | |||||
| ResetNodeIndex(); | |||||
| todo_.clear(); | todo_.clear(); | ||||
| todo_.push_back(func_graph); | todo_.push_back(func_graph); | ||||
| while (!todo_.empty()) { | while (!todo_.empty()) { | ||||
| @@ -179,7 +185,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap | |||||
| input_proto->set_name(param_name); | input_proto->set_name(param_name); | ||||
| SetValueInfoProto(param, input_proto); | SetValueInfoProto(param, input_proto); | ||||
| if (!param->has_default()) { | if (!param->has_default()) { | ||||
| MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; | |||||
| MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -234,13 +240,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr | |||||
| auto elem_type = tensor->element(); | auto elem_type = tensor->element(); | ||||
| const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | ||||
| type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); | type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); | ||||
| for (const auto &dim : dims) { | |||||
| MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | |||||
| if (dims.size() == 0) { | |||||
| MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); | |||||
| } else { | |||||
| for (const auto &dim : dims) { | |||||
| MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; | |||||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | |||||
| } | |||||
| } | } | ||||
| } else if (type->isa<Tuple>()) { | } else if (type->isa<Tuple>()) { | ||||
| auto tup_shape = shape->cast<abstract::TupleShapePtr>(); | auto tup_shape = shape->cast<abstract::TupleShapePtr>(); | ||||
| type_proto->set_denotation(std::to_string(tup_shape->shape().size())); | |||||
| type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); | |||||
| } else if (type->isa<Number>() || type->isa<String>()) { | |||||
| type_proto->set_denotation(type->type_name()); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; | MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; | ||||
| } | } | ||||
| @@ -250,9 +263,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| attr_proto->set_ref_attr_name("tensor"); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||||
| attr_proto->set_ref_attr_name("tensor:value0"); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| tensor_proto->set_name("value0"); | |||||
| auto data = value->cast<tensor::TensorPtr>(); | auto data = value->cast<tensor::TensorPtr>(); | ||||
| tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | ||||
| auto dtype = data->data_type(); | auto dtype = data->data_type(); | ||||
| @@ -286,6 +300,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten | |||||
| void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | ||||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | ||||
| bool is_only_return = true; | |||||
| for (const AnfNodePtr &node : nodes) { | for (const AnfNodePtr &node : nodes) { | ||||
| if (!node->isa<CNode>()) { | if (!node->isa<CNode>()) { | ||||
| MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; | MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; | ||||
| @@ -293,9 +308,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| if (cnode == func_graph->get_return()) { | if (cnode == func_graph->get_return()) { | ||||
| if (is_only_return) { | |||||
| MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!"; | |||||
| } | |||||
| BuildOutput(cnode, graph_proto); | BuildOutput(cnode, graph_proto); | ||||
| } else { | } else { | ||||
| BuildCNode(cnode, graph_proto); | BuildCNode(cnode, graph_proto); | ||||
| is_only_return = false; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -305,24 +324,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const | |||||
| MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | ||||
| } | } | ||||
| AnfNodePtr arg = node->input(1); | AnfNodePtr arg = node->input(1); | ||||
| // Using make_tuple to set multi-output | |||||
| if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { | |||||
| auto tuple_node = arg->cast<CNodePtr>(); | |||||
| for (size_t i = 1; i < tuple_node->size(); i++) { | |||||
| auto input_node = arg->cast<CNodePtr>()->input(i); | |||||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||||
| auto output_name = GetUniqueNodeName(tuple_node->input(i)); | |||||
| output_proto->set_name(output_name); | |||||
| last_node_->add_output(output_name); | |||||
| SetValueInfoProto(tuple_node->input(i), output_proto); | |||||
| } | |||||
| } else { | |||||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||||
| std::string output_name = GetUniqueNodeName(node); | |||||
| output_proto->set_name(output_name); | |||||
| last_node_->add_output(output_name); | |||||
| SetValueInfoProto(arg, output_proto); | |||||
| } | |||||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||||
| std::string output_name = GetUniqueNodeName(node); | |||||
| output_proto->set_name(output_name); | |||||
| last_node_->set_output(0, output_name); | |||||
| SetValueInfoProto(arg, output_proto); | |||||
| } | } | ||||
| std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { | std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { | ||||
| @@ -345,45 +351,43 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { | |||||
| } | } | ||||
| void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, | void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, | ||||
| onnx::NodeProto *const node_proto, std::string suffix) { | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| attr_proto->set_ref_attr_name("shape"); | |||||
| if (suffix.compare("0") != 0) { | |||||
| attr_proto->set_name("shape" + suffix); | |||||
| } else { | |||||
| attr_proto->set_name("shape"); | |||||
| } | |||||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||||
| SetTensorProto(type, shape, tensor_proto); | |||||
| } | |||||
| void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { | |||||
| // Get shape of cnode | |||||
| // 1. prim ArgMaxWithValue need to get shape from tuple element | |||||
| // 2. some cnode doesn't has shape, such as LayerNorm | |||||
| // 3. other cnodes have shape | |||||
| if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { | |||||
| auto type = node->Type(); | |||||
| auto shape = node->Shape(); | |||||
| if (!type->isa<Tuple>()) { | |||||
| MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); | |||||
| } | |||||
| onnx::AttributeProto *const attr_proto, std::string *const seq_string) { | |||||
| if (type->isa<Tuple>() && seq_string != nullptr) { | |||||
| *seq_string += "Tuple["; | |||||
| auto elements = type->cast<TuplePtr>()->elements(); | auto elements = type->cast<TuplePtr>()->elements(); | ||||
| auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape(); | auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape(); | ||||
| for (size_t i = 0; i < elements.size(); i++) { | for (size_t i = 0; i < elements.size(); i++) { | ||||
| SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); | |||||
| SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); | |||||
| } | } | ||||
| *seq_string += "],"; | |||||
| } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) { | |||||
| string shape_name = "shape" + std::to_string(GetTupleIndex()); | |||||
| *seq_string += shape_name + ","; | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| tensor_proto->set_name(shape_name); | |||||
| SetTensorProto(type, shape, tensor_proto); | |||||
| } else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) { | |||||
| *seq_string += type->type_name() + ","; | |||||
| } else { | } else { | ||||
| auto type = node->Type(); | |||||
| auto shape = node->Shape(); | |||||
| if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | |||||
| MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); | |||||
| return; | |||||
| } | |||||
| SetShapeToNodeProto(type, shape, node_proto); | |||||
| MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); | |||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { | |||||
| // Get shape of cnode | |||||
| // 1. need to get shape from tuple element | |||||
| // 2. save shape in TensorProto | |||||
| // 3. save tuple string in ref_attr_name | |||||
| auto type = node->Type(); | |||||
| auto shape = node->Shape(); | |||||
| ResetTupleIndex(); | |||||
| std::string seq_string = "shape:"; | |||||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||||
| SetShapeToNodeProto(type, shape, attr_proto, &seq_string); | |||||
| attr_proto->set_ref_attr_name(seq_string); | |||||
| MS_LOG(DEBUG) << "CNode shape: " << seq_string; | |||||
| } | |||||
| void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { | void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { | ||||
| auto inputs_size = node->size(); | auto inputs_size = node->size(); | ||||
| if (inputs_size < 1) { | if (inputs_size < 1) { | ||||
| @@ -445,15 +449,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { | |||||
| std::string node_name = ""; | std::string node_name = ""; | ||||
| if (node->isa<Parameter>()) { | if (node->isa<Parameter>()) { | ||||
| node_name = GetNodeName(node); | node_name = GetNodeName(node); | ||||
| } else if (node->isa<CNode>() || node->isa<ValueNode>()) { | |||||
| } else if (node->isa<CNode>()) { | |||||
| auto iter = node_index_map_.find(node); | auto iter = node_index_map_.find(node); | ||||
| if (iter != node_index_map_.end()) { | if (iter != node_index_map_.end()) { | ||||
| node_name = GetNodeName(node) + ":" + std::to_string(iter->second); | node_name = GetNodeName(node) + ":" + std::to_string(iter->second); | ||||
| } else { | } else { | ||||
| auto node_idx = AllocateIndex(); | |||||
| auto node_idx = GetNodeIndex(); | |||||
| node_index_map_[node] = node_idx; | node_index_map_[node] = node_idx; | ||||
| node_name = GetNodeName(node) + ":" + std::to_string(node_idx); | node_name = GetNodeName(node) + ":" + std::to_string(node_idx); | ||||
| } | } | ||||
| } else if (node->isa<ValueNode>()) { | |||||
| auto node_idx = GetNodeIndex(); | |||||
| node_index_map_[node] = node_idx; | |||||
| node_name = GetNodeName(node) + ":" + std::to_string(node_idx); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); | MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); | ||||
| } | } | ||||
| @@ -487,17 +495,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| attr_proto->set_ref_attr_name("type"); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| if (value->isa<Int>()) { | if (value->isa<Int>()) { | ||||
| attr_proto->set_ref_attr_name("type:value0"); | |||||
| tensor_proto->set_name("value0"); | |||||
| auto int_value = value->cast<IntPtr>(); | auto int_value = value->cast<IntPtr>(); | ||||
| tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); | tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); | ||||
| } else if (value->isa<Float>()) { | } else if (value->isa<Float>()) { | ||||
| attr_proto->set_ref_attr_name("type:value0"); | |||||
| tensor_proto->set_name("value0"); | |||||
| auto float_value = value->cast<FloatPtr>(); | auto float_value = value->cast<FloatPtr>(); | ||||
| tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); | tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); | ||||
| } else if (value->isa<TensorType>()) { | } else if (value->isa<TensorType>()) { | ||||
| tensor_proto->set_name("tensor"); | |||||
| attr_proto->set_ref_attr_name("type:tensor0"); | |||||
| tensor_proto->set_name("tensor0"); | |||||
| auto elem_type = value->cast<TensorTypePtr>()->element(); | auto elem_type = value->cast<TensorTypePtr>()->element(); | ||||
| if (elem_type->isa<Int>()) { | if (elem_type->isa<Int>()) { | ||||
| auto int_value = elem_type->cast<IntPtr>(); | auto int_value = elem_type->cast<IntPtr>(); | ||||
| @@ -521,10 +533,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr | |||||
| SetScalarToAttributeProto(value, attr_proto); | SetScalarToAttributeProto(value, attr_proto); | ||||
| } else if (value->isa<Number>() || value->isa<TensorType>()) { | } else if (value->isa<Number>() || value->isa<TensorType>()) { | ||||
| SetTypeToAttributeProto(value, attr_proto); | SetTypeToAttributeProto(value, attr_proto); | ||||
| } else if (value->isa<ValueSequeue>()) { | |||||
| SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto); | |||||
| } else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) { | |||||
| ResetTupleIndex(); | |||||
| std::string seq_string = "scalar:"; | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string); | |||||
| attr_proto->set_ref_attr_name(seq_string); | |||||
| MS_LOG(DEBUG) << "Attr string: " << seq_string; | |||||
| } else if (value->isa<tensor::Tensor>()) { | } else if (value->isa<tensor::Tensor>()) { | ||||
| SetTensorToAttributeProto(value, attr_proto); | SetTensorToAttributeProto(value, attr_proto); | ||||
| } else if (value->isa<None>()) { | |||||
| attr_proto->set_ref_attr_name("none"); | |||||
| MS_LOG(DEBUG) << "Attr string: " << value->type_name(); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); | MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); | ||||
| } | } | ||||
| @@ -534,16 +554,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| attr_proto->set_ref_attr_name("scalar"); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||||
| SetScalarToProto(value, tensor_proto); | |||||
| attr_proto->set_ref_attr_name("scalar:value0"); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| SetScalarToProto(value, tensor_proto, "value0"); | |||||
| } | } | ||||
| void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { | |||||
| void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, | |||||
| const std::string &value_name) { | |||||
| if (value == nullptr || tensor_proto == nullptr) { | if (value == nullptr || tensor_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; | MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; | ||||
| } | } | ||||
| tensor_proto->set_name(value_name); | |||||
| if (value->isa<StringImm>()) { | if (value->isa<StringImm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); | tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); | ||||
| tensor_proto->add_string_data(GetValue<std::string>(value)); | tensor_proto->add_string_data(GetValue<std::string>(value)); | ||||
| @@ -562,44 +584,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto | |||||
| } else if (value->isa<Int64Imm>()) { | } else if (value->isa<Int64Imm>()) { | ||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); | tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); | ||||
| tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value()); | tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value()); | ||||
| } else if (value->isa<FloatImm>()) { | |||||
| } else if (value->isa<UInt8Imm>()) { | |||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); | |||||
| tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt16Imm>()) { | |||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); | |||||
| tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt32Imm>()) { | |||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); | |||||
| tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value()); | |||||
| } else if (value->isa<UInt64Imm>()) { | |||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); | |||||
| tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value()); | |||||
| } else if (value->isa<FP32Imm>()) { | |||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); | tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); | ||||
| tensor_proto->add_float_data(GetValue<float>(value)); | tensor_proto->add_float_data(GetValue<float>(value)); | ||||
| } else if (value->isa<FP64Imm>()) { | |||||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); | |||||
| tensor_proto->add_double_data(GetValue<double>(value)); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); | MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); | ||||
| } | } | ||||
| } | } | ||||
| void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, | |||||
| onnx::AttributeProto *const attr_proto) { | |||||
| void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string) { | |||||
| string value_name = "value" + std::to_string(GetTupleIndex()); | |||||
| if (seq_string != nullptr) { | |||||
| *seq_string += value_name + ","; | |||||
| } | |||||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||||
| SetScalarToProto(value, tensor_proto, value_name); | |||||
| } | |||||
| void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, | |||||
| std::string *const seq_string) { | |||||
| if (value == nullptr || attr_proto == nullptr) { | if (value == nullptr || attr_proto == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; | MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; | ||||
| } | } | ||||
| attr_proto->set_ref_attr_name("scalar"); | |||||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||||
| if (value->isa<ValueTuple>()) { | |||||
| if (value->isa<ValueTuple>() && seq_string != nullptr) { | |||||
| *seq_string += "Tuple["; | |||||
| const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>(); | const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>(); | ||||
| if (tuple_value->value().size() == 0) { | if (tuple_value->value().size() == 0) { | ||||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; | MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; | ||||
| return; | return; | ||||
| } | } | ||||
| auto type_id = tuple_value->value()[0]->type()->type_id(); | |||||
| tensor_proto->set_data_type(GetOnnxDataType(type_id)); | |||||
| for (const auto &item : tuple_value->value()) { | for (const auto &item : tuple_value->value()) { | ||||
| SetScalarToProto(item, tensor_proto); | |||||
| if (item->isa<ValueTuple>()) { | |||||
| SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string); | |||||
| } else { | |||||
| SetSeqElemToAttributeProto(item, attr_proto, seq_string); | |||||
| } | |||||
| } | } | ||||
| } else if (value->isa<ValueList>()) { | |||||
| *seq_string += "],"; | |||||
| } else if (value->isa<ValueList>() && seq_string != nullptr) { | |||||
| *seq_string += "List["; | |||||
| const ValueListPtr &list_value = value->cast<ValueListPtr>(); | const ValueListPtr &list_value = value->cast<ValueListPtr>(); | ||||
| if (list_value->value().size() == 0) { | if (list_value->value().size() == 0) { | ||||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; | |||||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; | |||||
| return; | return; | ||||
| } | } | ||||
| auto type_id = list_value->value()[0]->type()->type_id(); | |||||
| tensor_proto->set_data_type(GetOnnxDataType(type_id)); | |||||
| for (const auto &item : list_value->value()) { | for (const auto &item : list_value->value()) { | ||||
| SetScalarToProto(item, tensor_proto); | |||||
| if (item->isa<ValueList>()) { | |||||
| SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string); | |||||
| } else { | |||||
| SetSeqElemToAttributeProto(item, attr_proto, seq_string); | |||||
| } | |||||
| } | } | ||||
| *seq_string += "],"; | |||||
| } | } | ||||
| } | } | ||||