| @@ -92,14 +92,17 @@ class IrExportBuilder { | |||
| void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); | |||
| void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); | |||
| void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); | |||
| void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, | |||
| std::string suffix = "0"); | |||
| void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto, | |||
| std::string *const seq_string); | |||
| void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||
| void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||
| void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||
| void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); | |||
| void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); | |||
| void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); | |||
| void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name); | |||
| void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, | |||
| std::string *const seq_string); | |||
| void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, | |||
| std::string *const seq_string); | |||
| onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); | |||
| onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); | |||
| @@ -107,8 +110,10 @@ class IrExportBuilder { | |||
| std::string GetNodeName(const AnfNodePtr &node); | |||
| std::string GetUniqueNodeName(const AnfNodePtr &node); | |||
| std::string GetOpTypeName(const AnfNodePtr &node); | |||
| size_t AllocateIndex() { return ++node_index_; } | |||
| void ResetIndex() { node_index_ = 0; } | |||
| size_t GetNodeIndex() { return ++node_index_; } | |||
| void ResetNodeIndex() { node_index_ = 0; } | |||
| size_t GetTupleIndex() { return ++shape_index_; } | |||
| void ResetTupleIndex() { shape_index_ = 0; } | |||
| private: | |||
| onnx::ModelProto model_; | |||
| @@ -116,6 +121,7 @@ class IrExportBuilder { | |||
| std::list<FuncGraphPtr> todo_; | |||
| std::map<AnfNodePtr, size_t> node_index_map_; | |||
| size_t node_index_{0}; | |||
| size_t shape_index_{0}; | |||
| }; | |||
| using IrExporterPtr = std::shared_ptr<IrExporter>; | |||
| @@ -148,7 +154,7 @@ void IrExportBuilder::BuildModelInfo() { | |||
| void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { | |||
| onnx::GraphProto *graph_proto = model_.mutable_graph(); | |||
| graph_proto->set_name(func_graph->ToString()); | |||
| ResetIndex(); | |||
| ResetNodeIndex(); | |||
| todo_.clear(); | |||
| todo_.push_back(func_graph); | |||
| while (!todo_.empty()) { | |||
| @@ -179,7 +185,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap | |||
| input_proto->set_name(param_name); | |||
| SetValueInfoProto(param, input_proto); | |||
| if (!param->has_default()) { | |||
| MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; | |||
| MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default."; | |||
| continue; | |||
| } | |||
| @@ -234,13 +240,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr | |||
| auto elem_type = tensor->element(); | |||
| const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); | |||
| type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); | |||
| for (const auto &dim : dims) { | |||
| MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | |||
| if (dims.size() == 0) { | |||
| MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1."; | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); | |||
| } else { | |||
| for (const auto &dim : dims) { | |||
| MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; | |||
| type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); | |||
| } | |||
| } | |||
| } else if (type->isa<Tuple>()) { | |||
| auto tup_shape = shape->cast<abstract::TupleShapePtr>(); | |||
| type_proto->set_denotation(std::to_string(tup_shape->shape().size())); | |||
| type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size())); | |||
| } else if (type->isa<Number>() || type->isa<String>()) { | |||
| type_proto->set_denotation(type->type_name()); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; | |||
| } | |||
| @@ -250,9 +263,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att | |||
| if (value == nullptr || attr_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | |||
| } | |||
| attr_proto->set_ref_attr_name("tensor"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| attr_proto->set_ref_attr_name("tensor:value0"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||
| tensor_proto->set_name("value0"); | |||
| auto data = value->cast<tensor::TensorPtr>(); | |||
| tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); | |||
| auto dtype = data->data_type(); | |||
| @@ -286,6 +300,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::Ten | |||
| void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { | |||
| std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); | |||
| bool is_only_return = true; | |||
| for (const AnfNodePtr &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; | |||
| @@ -293,9 +308,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == func_graph->get_return()) { | |||
| if (is_only_return) { | |||
| MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!"; | |||
| } | |||
| BuildOutput(cnode, graph_proto); | |||
| } else { | |||
| BuildCNode(cnode, graph_proto); | |||
| is_only_return = false; | |||
| } | |||
| } | |||
| } | |||
| @@ -305,24 +324,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const | |||
| MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; | |||
| } | |||
| AnfNodePtr arg = node->input(1); | |||
| // Using make_tuple to set multi-output | |||
| if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { | |||
| auto tuple_node = arg->cast<CNodePtr>(); | |||
| for (size_t i = 1; i < tuple_node->size(); i++) { | |||
| auto input_node = arg->cast<CNodePtr>()->input(i); | |||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||
| auto output_name = GetUniqueNodeName(tuple_node->input(i)); | |||
| output_proto->set_name(output_name); | |||
| last_node_->add_output(output_name); | |||
| SetValueInfoProto(tuple_node->input(i), output_proto); | |||
| } | |||
| } else { | |||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||
| std::string output_name = GetUniqueNodeName(node); | |||
| output_proto->set_name(output_name); | |||
| last_node_->add_output(output_name); | |||
| SetValueInfoProto(arg, output_proto); | |||
| } | |||
| onnx::ValueInfoProto *output_proto = graph_proto->add_output(); | |||
| std::string output_name = GetUniqueNodeName(node); | |||
| output_proto->set_name(output_name); | |||
| last_node_->set_output(0, output_name); | |||
| SetValueInfoProto(arg, output_proto); | |||
| } | |||
| std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { | |||
| @@ -345,45 +351,43 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { | |||
| } | |||
| void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, | |||
| onnx::NodeProto *const node_proto, std::string suffix) { | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| attr_proto->set_ref_attr_name("shape"); | |||
| if (suffix.compare("0") != 0) { | |||
| attr_proto->set_name("shape" + suffix); | |||
| } else { | |||
| attr_proto->set_name("shape"); | |||
| } | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| SetTensorProto(type, shape, tensor_proto); | |||
| } | |||
| void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { | |||
| // Get shape of cnode | |||
| // 1. prim ArgMaxWithValue need to get shape from tuple element | |||
| // 2. some cnode doesn't has shape, such as LayerNorm | |||
| // 3. other cnodes have shape | |||
| if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { | |||
| auto type = node->Type(); | |||
| auto shape = node->Shape(); | |||
| if (!type->isa<Tuple>()) { | |||
| MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); | |||
| } | |||
| onnx::AttributeProto *const attr_proto, std::string *const seq_string) { | |||
| if (type->isa<Tuple>() && seq_string != nullptr) { | |||
| *seq_string += "Tuple["; | |||
| auto elements = type->cast<TuplePtr>()->elements(); | |||
| auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape(); | |||
| for (size_t i = 0; i < elements.size(); i++) { | |||
| SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); | |||
| SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); | |||
| } | |||
| *seq_string += "],"; | |||
| } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) { | |||
| string shape_name = "shape" + std::to_string(GetTupleIndex()); | |||
| *seq_string += shape_name + ","; | |||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||
| tensor_proto->set_name(shape_name); | |||
| SetTensorProto(type, shape, tensor_proto); | |||
| } else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) { | |||
| *seq_string += type->type_name() + ","; | |||
| } else { | |||
| auto type = node->Type(); | |||
| auto shape = node->Shape(); | |||
| if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) { | |||
| MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); | |||
| return; | |||
| } | |||
| SetShapeToNodeProto(type, shape, node_proto); | |||
| MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); | |||
| } | |||
| } | |||
| void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { | |||
| // Get shape of cnode | |||
| // 1. need to get shape from tuple element | |||
| // 2. save shape in TensorProto | |||
| // 3. save tuple string in ref_attr_name | |||
| auto type = node->Type(); | |||
| auto shape = node->Shape(); | |||
| ResetTupleIndex(); | |||
| std::string seq_string = "shape:"; | |||
| onnx::AttributeProto *attr_proto = node_proto->add_attribute(); | |||
| SetShapeToNodeProto(type, shape, attr_proto, &seq_string); | |||
| attr_proto->set_ref_attr_name(seq_string); | |||
| MS_LOG(DEBUG) << "CNode shape: " << seq_string; | |||
| } | |||
| void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { | |||
| auto inputs_size = node->size(); | |||
| if (inputs_size < 1) { | |||
| @@ -445,15 +449,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { | |||
| std::string node_name = ""; | |||
| if (node->isa<Parameter>()) { | |||
| node_name = GetNodeName(node); | |||
| } else if (node->isa<CNode>() || node->isa<ValueNode>()) { | |||
| } else if (node->isa<CNode>()) { | |||
| auto iter = node_index_map_.find(node); | |||
| if (iter != node_index_map_.end()) { | |||
| node_name = GetNodeName(node) + ":" + std::to_string(iter->second); | |||
| } else { | |||
| auto node_idx = AllocateIndex(); | |||
| auto node_idx = GetNodeIndex(); | |||
| node_index_map_[node] = node_idx; | |||
| node_name = GetNodeName(node) + ":" + std::to_string(node_idx); | |||
| } | |||
| } else if (node->isa<ValueNode>()) { | |||
| auto node_idx = GetNodeIndex(); | |||
| node_index_map_[node] = node_idx; | |||
| node_name = GetNodeName(node) + ":" + std::to_string(node_idx); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); | |||
| } | |||
| @@ -487,17 +495,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri | |||
| if (value == nullptr || attr_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | |||
| } | |||
| attr_proto->set_ref_attr_name("type"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||
| if (value->isa<Int>()) { | |||
| attr_proto->set_ref_attr_name("type:value0"); | |||
| tensor_proto->set_name("value0"); | |||
| auto int_value = value->cast<IntPtr>(); | |||
| tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); | |||
| } else if (value->isa<Float>()) { | |||
| attr_proto->set_ref_attr_name("type:value0"); | |||
| tensor_proto->set_name("value0"); | |||
| auto float_value = value->cast<FloatPtr>(); | |||
| tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); | |||
| } else if (value->isa<TensorType>()) { | |||
| tensor_proto->set_name("tensor"); | |||
| attr_proto->set_ref_attr_name("type:tensor0"); | |||
| tensor_proto->set_name("tensor0"); | |||
| auto elem_type = value->cast<TensorTypePtr>()->element(); | |||
| if (elem_type->isa<Int>()) { | |||
| auto int_value = elem_type->cast<IntPtr>(); | |||
| @@ -521,10 +533,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr | |||
| SetScalarToAttributeProto(value, attr_proto); | |||
| } else if (value->isa<Number>() || value->isa<TensorType>()) { | |||
| SetTypeToAttributeProto(value, attr_proto); | |||
| } else if (value->isa<ValueSequeue>()) { | |||
| SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto); | |||
| } else if (value->isa<ValueSequeue>() || value->isa<ValueSequeue>()) { | |||
| ResetTupleIndex(); | |||
| std::string seq_string = "scalar:"; | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||
| SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string); | |||
| attr_proto->set_ref_attr_name(seq_string); | |||
| MS_LOG(DEBUG) << "Attr string: " << seq_string; | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| SetTensorToAttributeProto(value, attr_proto); | |||
| } else if (value->isa<None>()) { | |||
| attr_proto->set_ref_attr_name("none"); | |||
| MS_LOG(DEBUG) << "Attr string: " << value->type_name(); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); | |||
| } | |||
| @@ -534,16 +554,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att | |||
| if (value == nullptr || attr_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; | |||
| } | |||
| attr_proto->set_ref_attr_name("scalar"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| SetScalarToProto(value, tensor_proto); | |||
| attr_proto->set_ref_attr_name("scalar:value0"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS); | |||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||
| SetScalarToProto(value, tensor_proto, "value0"); | |||
| } | |||
| void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { | |||
| void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, | |||
| const std::string &value_name) { | |||
| if (value == nullptr || tensor_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; | |||
| } | |||
| tensor_proto->set_name(value_name); | |||
| if (value->isa<StringImm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); | |||
| tensor_proto->add_string_data(GetValue<std::string>(value)); | |||
| @@ -562,44 +584,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto | |||
| } else if (value->isa<Int64Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); | |||
| tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value()); | |||
| } else if (value->isa<FloatImm>()) { | |||
| } else if (value->isa<UInt8Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8); | |||
| tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value()); | |||
| } else if (value->isa<UInt16Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16); | |||
| tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value()); | |||
| } else if (value->isa<UInt32Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32); | |||
| tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value()); | |||
| } else if (value->isa<UInt64Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64); | |||
| tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value()); | |||
| } else if (value->isa<FP32Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); | |||
| tensor_proto->add_float_data(GetValue<float>(value)); | |||
| } else if (value->isa<FP64Imm>()) { | |||
| tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE); | |||
| tensor_proto->add_double_data(GetValue<double>(value)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); | |||
| } | |||
| } | |||
| void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, | |||
| onnx::AttributeProto *const attr_proto) { | |||
| void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto, | |||
| std::string *const seq_string) { | |||
| string value_name = "value" + std::to_string(GetTupleIndex()); | |||
| if (seq_string != nullptr) { | |||
| *seq_string += value_name + ","; | |||
| } | |||
| onnx::TensorProto *tensor_proto = attr_proto->add_tensors(); | |||
| SetScalarToProto(value, tensor_proto, value_name); | |||
| } | |||
| void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto, | |||
| std::string *const seq_string) { | |||
| if (value == nullptr || attr_proto == nullptr) { | |||
| MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; | |||
| } | |||
| attr_proto->set_ref_attr_name("scalar"); | |||
| attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); | |||
| onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); | |||
| if (value->isa<ValueTuple>()) { | |||
| if (value->isa<ValueTuple>() && seq_string != nullptr) { | |||
| *seq_string += "Tuple["; | |||
| const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>(); | |||
| if (tuple_value->value().size() == 0) { | |||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; | |||
| return; | |||
| } | |||
| auto type_id = tuple_value->value()[0]->type()->type_id(); | |||
| tensor_proto->set_data_type(GetOnnxDataType(type_id)); | |||
| for (const auto &item : tuple_value->value()) { | |||
| SetScalarToProto(item, tensor_proto); | |||
| if (item->isa<ValueTuple>()) { | |||
| SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string); | |||
| } else { | |||
| SetSeqElemToAttributeProto(item, attr_proto, seq_string); | |||
| } | |||
| } | |||
| } else if (value->isa<ValueList>()) { | |||
| *seq_string += "],"; | |||
| } else if (value->isa<ValueList>() && seq_string != nullptr) { | |||
| *seq_string += "List["; | |||
| const ValueListPtr &list_value = value->cast<ValueListPtr>(); | |||
| if (list_value->value().size() == 0) { | |||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; | |||
| MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; | |||
| return; | |||
| } | |||
| auto type_id = list_value->value()[0]->type()->type_id(); | |||
| tensor_proto->set_data_type(GetOnnxDataType(type_id)); | |||
| for (const auto &item : list_value->value()) { | |||
| SetScalarToProto(item, tensor_proto); | |||
| if (item->isa<ValueList>()) { | |||
| SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string); | |||
| } else { | |||
| SetSeqElemToAttributeProto(item, attr_proto, seq_string); | |||
| } | |||
| } | |||
| *seq_string += "],"; | |||
| } | |||
| } | |||