@@ -90,14 +90,17 @@ class IrExportBuilder {
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
std::string suffix = "0");
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto, const std::string &value_name);
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);
void SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string);

onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
@@ -105,8 +108,10 @@ class IrExportBuilder {
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
size_t AllocateIndex() { return ++node_index_; }
void ResetIndex() { node_index_ = 0; }
size_t GetNodeIndex() { return ++node_index_; }
void ResetNodeIndex() { node_index_ = 0; }
size_t GetTupleIndex() { return ++shape_index_; }
void ResetTupleIndex() { shape_index_ = 0; }
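// Note: node_index_ numbers nodes for unique names, while shape_index_ (via GetTupleIndex) numbers the
// per-attribute tensors ("shape1", "value1", ...) that the ref_attr_name strings below refer to.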
|
|
|
|
|
|
|
private:
onnx::ModelProto model_;
@@ -114,6 +119,7 @@ class IrExportBuilder {
std::list<FuncGraphPtr> todo_;
std::map<AnfNodePtr, size_t> node_index_map_;
size_t node_index_{0};
size_t shape_index_{0};
};

using IrExporterPtr = std::shared_ptr<IrExporter>;
@@ -146,7 +152,7 @@ void IrExportBuilder::BuildModelInfo() {
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
onnx::GraphProto *graph_proto = model_.mutable_graph();
graph_proto->set_name(func_graph->ToString());
ResetIndex();
ResetNodeIndex();
todo_.clear();
todo_.push_back(func_graph);
while (!todo_.empty()) {
@@ -177,7 +183,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
input_proto->set_name(param_name);
SetValueInfoProto(param, input_proto);
if (!param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default.";
continue;
}

@@ -232,13 +238,20 @@ void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
if (dims.size() == 0) {
MS_LOG(DEBUG) << "SetValueInfoProto set default dim 1.";
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
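// A rank-0 (scalar) tensor is exported with a single default dimension of 1.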
|
|
|
} else {
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
}
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
type_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
} else if (type->isa<Number>() || type->isa<String>()) {
type_proto->set_denotation(type->type_name());
} else {
MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
}
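// The denotation now also carries the type name, e.g. presumably "Tuple:2" for a two-element tuple output;
// Number and String outputs keep just the type name.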
|
|
|
@@ -248,9 +261,10 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::Att
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("tensor");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
attr_proto->set_ref_attr_name("tensor:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name("value0");
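// The tensor value is now stored as a named entry ("value0") in the repeated tensors field,
// and ref_attr_name ("tensor:value0") identifies which entry holds it.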
|
|
|
auto data = value->cast<tensor::TensorPtr>();
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
auto dtype = data->data_type();
@@ -284,6 +298,7 @@ void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::Ten

void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
bool is_only_return = true;
for (const AnfNodePtr &node : nodes) {
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
@@ -291,9 +306,13 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProt
}
auto cnode = node->cast<CNodePtr>();
if (cnode == func_graph->get_return()) {
if (is_only_return) {
MS_LOG(EXCEPTION) << "Only has return node, can't convert to binary model!";
}
BuildOutput(cnode, graph_proto);
} else {
BuildCNode(cnode, graph_proto);
is_only_return = false;
}
}
}
@@ -303,24 +322,11 @@ void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
}
AnfNodePtr arg = node->input(1);
// Using make_tuple to set multi-output
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
auto tuple_node = arg->cast<CNodePtr>();
for (size_t i = 1; i < tuple_node->size(); i++) {
auto input_node = arg->cast<CNodePtr>()->input(i);
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
auto output_name = GetUniqueNodeName(tuple_node->input(i));
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(tuple_node->input(i), output_proto);
}
} else {
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->add_output(output_name);
SetValueInfoProto(arg, output_proto);
}
onnx::ValueInfoProto *output_proto = graph_proto->add_output();
std::string output_name = GetUniqueNodeName(node);
output_proto->set_name(output_name);
last_node_->set_output(0, output_name);
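// The return value is exported as a single graph output named after the return node and bound to
// output slot 0 of last_node_, instead of being expanded per make_tuple element.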
|
|
|
SetValueInfoProto(arg, output_proto);
}

std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
@@ -343,45 +349,44 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
}

void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
onnx::NodeProto *const node_proto, std::string suffix) {
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_ref_attr_name("shape");
if (suffix.compare("0") != 0) {
attr_proto->set_name("shape" + suffix);
} else {
attr_proto->set_name("shape");
}
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetTensorProto(type, shape, tensor_proto);
}

void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. prim ArgMaxWithValue need to get shape from tuple element
// 2. some cnode doesn't has shape, such as LayerNorm
// 3. other cnodes have shape
if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<Tuple>()) {
MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
}
onnx::AttributeProto *const attr_proto, std::string *const seq_string) {
if (type->isa<Tuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
for (size_t i = 0; i < elements.size(); i++) {
SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
}
*seq_string += "],";
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>() && seq_string != nullptr) {
string shape_name = "shape" + std::to_string(GetTupleIndex());
*seq_string += shape_name + ",";
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
SetTensorProto(type, shape, tensor_proto);
} else if ((type->isa<Number>() || type->isa<String>()) && seq_string != nullptr) {
*seq_string += type->type_name() + ",";
} else {
auto type = node->Type();
auto shape = node->Shape();
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
return;
}
SetShapeToNodeProto(type, shape, node_proto);
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
}
}

void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
// Get shape of cnode
// 1. need to get shape from tuple element
// 2. save shape in TensorProto
// 3. save tuple string in ref_attr_name
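// e.g. a cnode returning a 2-tuple of tensors ends up with ref_attr_name "shape:Tuple[shape1,shape2,],"
// and named tensors "shape1"/"shape2" carrying the element shapes.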
|
|
|
MS_EXCEPTION_IF_NULL(node);
auto type = node->Type();
auto shape = node->Shape();
ResetTupleIndex();
std::string seq_string = "shape:";
onnx::AttributeProto *attr_proto = node_proto->add_attribute();
SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "CNode shape: " << seq_string;
}

void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
auto inputs_size = node->size();
if (inputs_size < 1) {
@@ -443,15 +448,19 @@ std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
std::string node_name = "";
if (node->isa<Parameter>()) {
node_name = GetNodeName(node);
} else if (node->isa<CNode>() || node->isa<ValueNode>()) {
} else if (node->isa<CNode>()) {
auto iter = node_index_map_.find(node);
if (iter != node_index_map_.end()) {
node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
} else {
auto node_idx = AllocateIndex();
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
}
} else if (node->isa<ValueNode>()) {
auto node_idx = GetNodeIndex();
node_index_map_[node] = node_idx;
node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
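// Note: unlike the CNode branch, node_index_map_ is not consulted first here, so a ValueNode that is
// looked up again receives a fresh index.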
|
|
|
} else {
MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
}
@@ -485,17 +494,21 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::Attri
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("type");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
if (value->isa<Int>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<FloatPtr>();
tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
} else if (value->isa<TensorType>()) {
tensor_proto->set_name("tensor");
attr_proto->set_ref_attr_name("type:tensor0");
tensor_proto->set_name("tensor0");
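// Types follow the same scheme: the data type is recorded on a named tensor ("value0" or "tensor0")
// and ref_attr_name ("type:value0" / "type:tensor0") points at it.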
|
|
|
auto elem_type = value->cast<TensorTypePtr>()->element();
if (elem_type->isa<Int>()) {
auto int_value = elem_type->cast<IntPtr>();
@@ -519,10 +532,18 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::Attr
SetScalarToAttributeProto(value, attr_proto);
} else if (value->isa<Number>() || value->isa<TensorType>()) {
SetTypeToAttributeProto(value, attr_proto);
} else if (value->isa<ValueSequeue>()) {
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
} else if (value->isa<ValueSequeue>()) {
ResetTupleIndex();
std::string seq_string = "scalar:";
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
attr_proto->set_ref_attr_name(seq_string);
MS_LOG(DEBUG) << "Attr string: " << seq_string;
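// e.g. the tuple (1, 2, 3) is flattened into tensors "value1".."value3" with
// ref_attr_name "scalar:Tuple[value1,value2,value3,],".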
|
|
|
} else if (value->isa<tensor::Tensor>()) {
SetTensorToAttributeProto(value, attr_proto);
} else if (value->isa<None>()) {
attr_proto->set_ref_attr_name("none");
MS_LOG(DEBUG) << "Attr string: " << value->type_name();
} else {
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
}
@@ -532,16 +553,18 @@ void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::Att
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
SetScalarToProto(value, tensor_proto);
attr_proto->set_ref_attr_name("scalar:value0");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSORS);
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, "value0");
}

void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto,
const std::string &value_name) {
if (value == nullptr || tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
}
tensor_proto->set_name(value_name);
if (value->isa<StringImm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
tensor_proto->add_string_data(GetValue<std::string>(value));
@@ -560,44 +583,74 @@ void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto
} else if (value->isa<Int64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
} else if (value->isa<FloatImm>()) {
} else if (value->isa<UInt8Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT8);
tensor_proto->add_int32_data(value->cast<UInt8ImmPtr>()->value());
} else if (value->isa<UInt16Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT16);
tensor_proto->add_int32_data(value->cast<UInt16ImmPtr>()->value());
} else if (value->isa<UInt32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT32);
tensor_proto->add_uint64_data(value->cast<UInt32ImmPtr>()->value());
} else if (value->isa<UInt64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_UINT64);
tensor_proto->add_uint64_data(value->cast<UInt64ImmPtr>()->value());
} else if (value->isa<FP32Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
tensor_proto->add_float_data(GetValue<float>(value));
} else if (value->isa<FP64Imm>()) {
tensor_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
tensor_proto->add_double_data(GetValue<double>(value));
} else {
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
}
}

void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
onnx::AttributeProto *const attr_proto) {
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string) {
string value_name = "value" + std::to_string(GetTupleIndex());
if (seq_string != nullptr) {
*seq_string += value_name + ",";
}
onnx::TensorProto *tensor_proto = attr_proto->add_tensors();
SetScalarToProto(value, tensor_proto, value_name);
}

void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto,
std::string *const seq_string) {
if (value == nullptr || attr_proto == nullptr) {
MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
}
attr_proto->set_ref_attr_name("scalar");
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
if (value->isa<ValueTuple>()) {
if (value->isa<ValueTuple>() && seq_string != nullptr) {
*seq_string += "Tuple[";
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
if (tuple_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
return;
}
auto type_id = tuple_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : tuple_value->value()) {
SetScalarToProto(item, tensor_proto);
if (item->isa<ValueTuple>()) {
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
}
}
} else if (value->isa<ValueList>()) {
*seq_string += "],";
} else if (value->isa<ValueList>() && seq_string != nullptr) {
*seq_string += "List[";
const ValueListPtr &list_value = value->cast<ValueListPtr>();
if (list_value->value().size() == 0) {
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
return;
}
auto type_id = list_value->value()[0]->type()->type_id();
tensor_proto->set_data_type(GetOnnxDataType(type_id));
for (const auto &item : list_value->value()) {
SetScalarToProto(item, tensor_proto);
if (item->isa<ValueList>()) {
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
} else {
SetSeqElemToAttributeProto(item, attr_proto, seq_string);
}
}
*seq_string += "],";
}
}
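// e.g. a nested value ((1, 2), 3) serializes its scalars into tensors "value1".."value3" with
// ref_attr_name "scalar:Tuple[Tuple[value1,value2,],value3,],"; lists are marked "List[...]" analogously.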