|
|
|
@@ -27,6 +27,7 @@ |
|
|
|
#include "base/core_ops.h" |
|
|
|
#include "proto/mind_ir.pb.h" |
|
|
|
#include "utils/check_convert_utils.h" |
|
|
|
#include "debug/dump_proto.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
using FloatPtr = std::shared_ptr<Float>; |
|
|
|
@@ -78,7 +79,7 @@ class IrExporter { |
|
|
|
explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {} |
|
|
|
virtual ~IrExporter() = default; |
|
|
|
std::string GetDumpString(const FuncGraphPtr &func_graph); |
|
|
|
mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false); |
|
|
|
ModelProtoPtr GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false); |
|
|
|
|
|
|
|
private: |
|
|
|
IrExportBuilderPtr builder_; |
|
|
|
@@ -86,39 +87,38 @@ class IrExporter { |
|
|
|
|
|
|
|
class IrExportBuilder { |
|
|
|
public: |
|
|
|
IrExportBuilder() = default; |
|
|
|
IrExportBuilder() : model_(std::make_shared<mind_ir::ModelProto>()) {} |
|
|
|
~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); } |
|
|
|
std::string GetProtoString(const FuncGraphPtr &func_graph); |
|
|
|
void BuildModelInfo(); |
|
|
|
void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false); |
|
|
|
mind_ir::ModelProto Model() { return model_; } |
|
|
|
bool BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false); |
|
|
|
ModelProtoPtr Model() { return model_; } |
|
|
|
|
|
|
|
private: |
|
|
|
void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool save_tensor_data = false); |
|
|
|
void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool save_tensor_data = false); |
|
|
|
void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); |
|
|
|
void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); |
|
|
|
void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); |
|
|
|
bool BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto); |
|
|
|
bool BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); |
|
|
|
bool BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto); |
|
|
|
std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto); |
|
|
|
|
|
|
|
void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); |
|
|
|
void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto); |
|
|
|
void SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); |
|
|
|
void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto); |
|
|
|
void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto); |
|
|
|
void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto); |
|
|
|
void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto, |
|
|
|
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto); |
|
|
|
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto); |
|
|
|
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto); |
|
|
|
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto); |
|
|
|
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto); |
|
|
|
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto, |
|
|
|
std::string *const seq_string); |
|
|
|
void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto, |
|
|
|
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); |
|
|
|
bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto, |
|
|
|
std::string *const seq_string); |
|
|
|
void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, |
|
|
|
bool SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, |
|
|
|
std::string *const seq_string); |
|
|
|
|
|
|
|
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id); |
|
|
|
@@ -134,7 +134,7 @@ class IrExportBuilder { |
|
|
|
void ResetTupleIndex() { shape_index_ = 0; } |
|
|
|
|
|
|
|
private: |
|
|
|
mind_ir::ModelProto model_; |
|
|
|
ModelProtoPtr model_; |
|
|
|
mind_ir::NodeProto *last_node_{nullptr}; |
|
|
|
std::list<FuncGraphPtr> todo_; |
|
|
|
std::map<AnfNodePtr, std::string> node_index_map_; |
|
|
|
@@ -147,11 +147,14 @@ class IrExportBuilder { |
|
|
|
using IrExporterPtr = std::shared_ptr<IrExporter>; |
|
|
|
|
|
|
|
std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { |
|
|
|
(void)GetDumpProto(func_graph); |
|
|
|
auto dump_proto = GetDumpProto(func_graph); |
|
|
|
if (dump_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Get dump proto for graph " << func_graph->ToString() << " failed."; |
|
|
|
} |
|
|
|
return builder_->GetProtoString(func_graph); |
|
|
|
} |
|
|
|
|
|
|
|
mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { |
|
|
|
ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { |
|
|
|
if ((builder_ == nullptr) || (func_graph == nullptr)) { |
|
|
|
MS_LOG(EXCEPTION) << "Input params is null."; |
|
|
|
} |
|
|
|
@@ -160,26 +163,28 @@ mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, boo |
|
|
|
builder_->BuildModelInfo(); |
|
|
|
|
|
|
|
// Export model and return string |
|
|
|
builder_->BuildModel(func_graph, save_tensor_data); |
|
|
|
if (!builder_->BuildModel(func_graph, save_tensor_data)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return builder_->Model(); |
|
|
|
} |
|
|
|
|
|
|
|
std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { |
|
|
|
MS_LOG(DEBUG) << "BuildModel complete!"; |
|
|
|
return model_.SerializeAsString(); |
|
|
|
return model_->SerializeAsString(); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildModelInfo() { |
|
|
|
constexpr auto ir_version = "0.1.0"; |
|
|
|
constexpr auto mindspore_name = "MindSpore"; |
|
|
|
model_.set_ir_version(ir_version); |
|
|
|
model_.set_producer_name(mindspore_name); |
|
|
|
model_.set_model_version(VERSION); |
|
|
|
model_->set_ir_version(ir_version); |
|
|
|
model_->set_producer_name(mindspore_name); |
|
|
|
model_->set_model_version(VERSION); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) { |
|
|
|
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
mind_ir::GraphProto *graph_proto = model_.mutable_graph(); |
|
|
|
mind_ir::GraphProto *graph_proto = model_->mutable_graph(); |
|
|
|
graph_proto->set_name(func_graph->ToString()); |
|
|
|
graph_proto->set_bprop_hash(func_graph->bprop_hash()); |
|
|
|
ResetNodeIndex(); |
|
|
|
@@ -188,7 +193,11 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso |
|
|
|
// Build the main funcGraph |
|
|
|
nodeName_.insert(func_graph->ToString()); |
|
|
|
top_graph = true; |
|
|
|
BuildFuncGraph(func_graph, graph_proto, save_tensor_data); |
|
|
|
if (!BuildFuncGraph(func_graph, graph_proto, save_tensor_data)) { |
|
|
|
MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
std::set<FuncGraphPtr> graphVisited; |
|
|
|
graphVisited.insert(func_graph); |
|
|
|
top_graph = false; |
|
|
|
@@ -199,32 +208,40 @@ void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tenso |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (nodeName_.count(fg->ToString()) > 0) { |
|
|
|
MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString(); |
|
|
|
MS_LOG(ERROR) << "There is a duplicate name: " << fg->ToString(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
nodeName_.insert(fg->ToString()); |
|
|
|
graphVisited.insert(fg); |
|
|
|
auto graph = model_.add_functions(); |
|
|
|
BuildFuncGraph(fg, graph, save_tensor_data); |
|
|
|
auto graph = model_->add_functions(); |
|
|
|
if (!BuildFuncGraph(fg, graph, save_tensor_data)) { |
|
|
|
MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
// Release resource |
|
|
|
nodeName_.clear(); |
|
|
|
node_index_map_.clear(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool save_tensor_data) { |
|
|
|
// Export funcGraph name. |
|
|
|
graph_proto->set_name(func_graph->ToString()); |
|
|
|
// Export parameters |
|
|
|
// 1. parameters should be mapped to ValueInfoProto |
|
|
|
// 2. parameters with default value should be mapped to Initializer |
|
|
|
BuildParameters(func_graph, graph_proto, save_tensor_data); |
|
|
|
if (!BuildParameters(func_graph, graph_proto, save_tensor_data)) { |
|
|
|
MS_LOG(ERROR) << "Build parameters failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// Export operator nodes(include output) |
|
|
|
BuildNodes(func_graph, graph_proto); |
|
|
|
return BuildNodes(func_graph, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto, |
|
|
|
bool save_tensor_data) { |
|
|
|
MS_EXCEPTION_IF_NULL(func_graph); |
|
|
|
MS_EXCEPTION_IF_NULL(graph_proto); |
|
|
|
@@ -232,14 +249,18 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
auto param = item->cast<ParameterPtr>(); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; |
|
|
|
MS_LOG(ERROR) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::string param_name = GetUniqueNodeName(param); |
|
|
|
if (top_graph && param->has_default()) { |
|
|
|
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString() << "' has default. address: " << (size_t)param.get(); |
|
|
|
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter(); |
|
|
|
parameter_proto->set_name(param_name); |
|
|
|
SetParamToTensorProto(param, parameter_proto); |
|
|
|
if (!SetParamToTensorProto(param, parameter_proto)) { |
|
|
|
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()); |
|
|
|
if (tensor && save_tensor_data) { |
|
|
|
parameter_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); |
|
|
|
@@ -247,19 +268,25 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G |
|
|
|
} else { |
|
|
|
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input(); |
|
|
|
input_proto->set_name(param_name); |
|
|
|
SetValueInfoProto(param, input_proto); |
|
|
|
if (!SetValueInfoProto(param, input_proto)) { |
|
|
|
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
if (nodeName_.count(param_name) > 0) { |
|
|
|
MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name; |
|
|
|
MS_LOG(ERROR) << "parameter name is duplicate:" << param_name; |
|
|
|
return false; |
|
|
|
} |
|
|
|
nodeName_.insert(param_name); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) { |
|
|
|
auto iter = g_data_type_map.find(type_id); |
|
|
|
if (iter == g_data_type_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; |
|
|
|
MS_LOG(ERROR) << "Convert type error, unsupported type! " << type_id; |
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED; |
|
|
|
} |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
@@ -267,7 +294,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) |
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) { |
|
|
|
auto iter = g_data_bits_int_map.find(bits); |
|
|
|
if (iter == g_data_bits_int_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; |
|
|
|
MS_LOG(ERROR) << "Convert bits int error, unsupported bits! " << bits; |
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED; |
|
|
|
} |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
@@ -275,7 +303,8 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits |
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) { |
|
|
|
auto iter = g_data_bits_uint_map.find(bits); |
|
|
|
if (iter == g_data_bits_uint_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits; |
|
|
|
MS_LOG(ERROR) << "Convert bits uint error, unsupported bits! " << bits; |
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED; |
|
|
|
} |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
@@ -283,20 +312,22 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bit |
|
|
|
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) { |
|
|
|
auto iter = g_data_bits_float_map.find(bits); |
|
|
|
if (iter == g_data_bits_float_map.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; |
|
|
|
MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits; |
|
|
|
return mind_ir::TensorProto_DataType_UNDEFINED; |
|
|
|
} |
|
|
|
return iter->second; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) { |
|
|
|
bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) { |
|
|
|
if (node == nullptr || value_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); |
|
|
|
const TypePtr &type = node->Type(); |
|
|
|
const BaseShapePtr &shape = node->Shape(); |
|
|
|
// For the bprop fg which has not been renormalized. |
|
|
|
if (type == nullptr || shape == nullptr) { |
|
|
|
return; |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) { |
|
|
|
auto tensor = type->cast<TensorTypePtr>(); |
|
|
|
@@ -304,7 +335,11 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn |
|
|
|
auto elem_type = tensor->element(); |
|
|
|
const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); |
|
|
|
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id())); |
|
|
|
auto data_type = GetMindirDataType(elem_type->type_id()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
if (dims.size() == 0) { |
|
|
|
MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0."; |
|
|
|
} else { |
|
|
|
@@ -320,9 +355,10 @@ void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn |
|
|
|
value_proto->set_denotation(type->type_name()); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Value type: " << type->type_name(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
if (value == nullptr || attr_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; |
|
|
|
} |
|
|
|
@@ -335,34 +371,45 @@ void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir:: |
|
|
|
tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); |
|
|
|
auto dtype = data->data_type(); |
|
|
|
auto shape = data->shape_c(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataType(dtype)); |
|
|
|
auto data_type = GetMindirDataType(dtype); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
for (const auto &dim : shape) { |
|
|
|
tensor_proto->add_dims(dim); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
mind_ir::TensorProto *const tensor_proto) { |
|
|
|
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); |
|
|
|
MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto tensor = type->cast<TensorTypePtr>(); |
|
|
|
const auto &dims = shape->cast<abstract::ShapePtr>()->shape(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id())); |
|
|
|
auto data_type = GetMindirDataType(tensor->element()->type_id()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
for (const auto &dim : dims) { |
|
|
|
tensor_proto->add_dims(dim); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) { |
|
|
|
bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) { |
|
|
|
if (param == nullptr || tensor_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString(); |
|
|
|
SetTensorProto(param->Type(), param->Shape(), tensor_proto); |
|
|
|
return SetTensorProto(param->Type(), param->Shape(), tensor_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { |
|
|
|
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) { |
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); |
|
|
|
for (const AnfNodePtr &node : nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -372,24 +419,36 @@ void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphP |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode == func_graph->get_return()) { |
|
|
|
BuildOutput(cnode, graph_proto); |
|
|
|
if (!BuildOutput(cnode, graph_proto)) { |
|
|
|
MS_LOG(ERROR) << "Build output for graph " << func_graph->ToString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
BuildCNode(cnode, graph_proto); |
|
|
|
if (!BuildCNode(cnode, graph_proto)) { |
|
|
|
MS_LOG(ERROR) << "Build proto for cnode " << cnode->DebugString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { |
|
|
|
bool IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
const int OutputSize = 2; |
|
|
|
if (node->size() != OutputSize) { |
|
|
|
MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; |
|
|
|
MS_LOG(ERROR) << "Number of inputs of return node is not equal to 2."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
AnfNodePtr arg = node->input(1); |
|
|
|
std::string node_name = BuildInputNode(arg, graph_proto); |
|
|
|
if (node_name.empty()) { |
|
|
|
MS_LOG(ERROR) << "Build input node failed for arg " << arg->DebugString(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
mind_ir::ValueInfoProto *output_proto = graph_proto->add_output(); |
|
|
|
output_proto->set_name(node_name); |
|
|
|
SetValueInfoProto(arg, output_proto); |
|
|
|
return SetValueInfoProto(arg, output_proto); |
|
|
|
} |
|
|
|
|
|
|
|
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { |
|
|
|
@@ -408,16 +467,18 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { |
|
|
|
auto nodeName = GetUniqueNodeName(node); |
|
|
|
type_name = "REF::" + nodeName; |
|
|
|
if (nodeName_.count(nodeName) == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "There is not the name: " << nodeName; |
|
|
|
MS_LOG(ERROR) << "There is not the name: " << nodeName; |
|
|
|
return ""; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name(); |
|
|
|
MS_LOG(ERROR) << "Need to support op type: " << node->type_name(); |
|
|
|
return ""; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "ExportType: " << type_name; |
|
|
|
return type_name; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, |
|
|
|
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { |
|
|
|
MS_EXCEPTION_IF_NULL(type); |
|
|
|
MS_EXCEPTION_IF_NULL(shape); |
|
|
|
@@ -427,7 +488,9 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt |
|
|
|
auto elements = type->cast<TuplePtr>()->elements(); |
|
|
|
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape(); |
|
|
|
for (size_t i = 0; i < elements.size(); i++) { |
|
|
|
SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); |
|
|
|
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
*seq_string += "],"; |
|
|
|
} else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) { |
|
|
|
@@ -435,7 +498,7 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt |
|
|
|
*seq_string += shape_name + ","; |
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); |
|
|
|
tensor_proto->set_name(shape_name); |
|
|
|
SetTensorProto(type, shape, tensor_proto); |
|
|
|
return SetTensorProto(type, shape, tensor_proto); |
|
|
|
} else if (type->isa<Number>()) { |
|
|
|
if (type->isa<Bool>()) { |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); |
|
|
|
@@ -453,11 +516,13 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt |
|
|
|
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) { |
|
|
|
*seq_string += type->type_name() + ","; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); |
|
|
|
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) { |
|
|
|
bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) { |
|
|
|
// Get shape of cnode |
|
|
|
// 1. need to get shape from tuple element |
|
|
|
// 2. save shape in TensorProto |
|
|
|
@@ -465,21 +530,27 @@ void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto type = node->Type(); |
|
|
|
auto shape = node->Shape(); |
|
|
|
// For the bprop fg which has not been renormalized. |
|
|
|
if (type == nullptr || shape == nullptr) { |
|
|
|
return; |
|
|
|
return true; |
|
|
|
} |
|
|
|
ResetTupleIndex(); |
|
|
|
std::string seq_string = "shape:"; |
|
|
|
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
SetShapeToNodeProto(type, shape, attr_proto, &seq_string); |
|
|
|
if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
attr_proto->set_ref_attr_name(seq_string); |
|
|
|
MS_LOG(DEBUG) << "CNode shape: " << seq_string; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { |
|
|
|
bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) { |
|
|
|
auto inputs_size = node->size(); |
|
|
|
if (inputs_size < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; |
|
|
|
MS_LOG(ERROR) << "Inputs of node " << node->DebugString() << " is empty"; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// Need to build input node before dealing with cnode |
|
|
|
@@ -488,7 +559,12 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons |
|
|
|
for (size_t i = 1; i < inputs_size; i++) { |
|
|
|
auto input = node->input(i); |
|
|
|
op_inputs.push_back(input); |
|
|
|
input_names.push_back(BuildInputNode(input, graph_proto)); |
|
|
|
std::string node_name = BuildInputNode(input, graph_proto); |
|
|
|
if (node_name.empty()) { |
|
|
|
MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
input_names.push_back(node_name); |
|
|
|
} |
|
|
|
|
|
|
|
// Build cnode |
|
|
|
@@ -503,10 +579,16 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons |
|
|
|
node_proto->set_domain(node->fullname_with_scope()); |
|
|
|
AnfNodePtr op = node->input(0); |
|
|
|
std::string type_name = GetOpTypeName(op); |
|
|
|
if (type_name.empty()) { |
|
|
|
MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
node_proto->set_op_type(type_name); |
|
|
|
last_node_ = node_proto; |
|
|
|
// Maybe Tensor or Function or nullptr |
|
|
|
SetShapeToNodeProto(node, node_proto); |
|
|
|
if (!SetShapeToNodeProto(node, node_proto)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
(void)std::for_each(input_names.begin(), input_names.end(), |
|
|
|
[&node_proto](const string &name) { node_proto->add_input(name); }); |
|
|
|
@@ -520,9 +602,13 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons |
|
|
|
attr_proto->set_name(attr.first); |
|
|
|
auto attr_value = attr.second; |
|
|
|
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value); |
|
|
|
SetValueToAttributeProto(attr_value, attr_proto); |
|
|
|
if (!SetValueToAttributeProto(attr_value, attr_proto)) { |
|
|
|
MS_LOG(ERROR) << "Set value to AttributeProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) { |
|
|
|
@@ -538,7 +624,9 @@ std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::Gra |
|
|
|
mind_ir::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_name(node_name); |
|
|
|
node_proto->add_output(node_name); |
|
|
|
SetAttributeProto(node, node_proto); |
|
|
|
if (!SetAttributeProto(node, node_proto)) { |
|
|
|
return ""; |
|
|
|
} |
|
|
|
} |
|
|
|
return node_name; |
|
|
|
} |
|
|
|
@@ -581,7 +669,7 @@ std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { |
|
|
|
return node_name; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) { |
|
|
|
bool IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) { |
|
|
|
if (node == nullptr || node_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; |
|
|
|
} |
|
|
|
@@ -592,10 +680,10 @@ void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodePro |
|
|
|
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("value"); |
|
|
|
MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); |
|
|
|
SetValueToAttributeProto(value, attr_proto); |
|
|
|
return SetValueToAttributeProto(value, attr_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
bool IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
if (value == nullptr || attr_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; |
|
|
|
} |
|
|
|
@@ -605,17 +693,29 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At |
|
|
|
attr_proto->set_ref_attr_name("type:value0"); |
|
|
|
tensor_proto->set_name("value0"); |
|
|
|
auto int_value = value->cast<IntPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsIntType(int_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else if (value->isa<UInt>()) { |
|
|
|
attr_proto->set_ref_attr_name("type:value0"); |
|
|
|
tensor_proto->set_name("value0"); |
|
|
|
auto float_value = value->cast<UIntPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsUIntType(float_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else if (value->isa<Float>()) { |
|
|
|
attr_proto->set_ref_attr_name("type:value0"); |
|
|
|
tensor_proto->set_name("value0"); |
|
|
|
auto float_value = value->cast<FloatPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsFloatType(float_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else if (value->isa<Bool>()) { |
|
|
|
attr_proto->set_ref_attr_name("type:value0"); |
|
|
|
tensor_proto->set_name("value0"); |
|
|
|
@@ -626,35 +726,49 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At |
|
|
|
auto elem_type = value->cast<TensorTypePtr>()->element(); |
|
|
|
if (elem_type->isa<Int>()) { |
|
|
|
auto int_value = elem_type->cast<IntPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsIntType(int_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else if (elem_type->isa<Float>()) { |
|
|
|
auto float_value = elem_type->cast<FloatPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsFloatType(float_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); |
|
|
|
MS_LOG(ERROR) << "Unsupported type " << elem_type->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
if (value == nullptr || attr_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; |
|
|
|
} |
|
|
|
if (value->isa<StringImm>() || value->isa<Scalar>()) { |
|
|
|
SetScalarToAttributeProto_ir(value, attr_proto); |
|
|
|
return SetScalarToAttributeProto_ir(value, attr_proto); |
|
|
|
} else if (value->isa<Number>() || value->isa<TensorType>()) { |
|
|
|
SetTypeToAttributeProto(value, attr_proto); |
|
|
|
return SetTypeToAttributeProto(value, attr_proto); |
|
|
|
} else if (value->isa<ValueSequeue>()) { |
|
|
|
ResetTupleIndex(); |
|
|
|
std::string seq_string = "scalar:"; |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); |
|
|
|
SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string); |
|
|
|
if (!SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set sequence to AttributeProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
attr_proto->set_ref_attr_name(seq_string); |
|
|
|
MS_LOG(DEBUG) << "Attr string: " << seq_string; |
|
|
|
} else if (value->isa<tensor::Tensor>()) { |
|
|
|
SetTensorToAttributeProto(value, attr_proto); |
|
|
|
return SetTensorToAttributeProto(value, attr_proto); |
|
|
|
} else if (value->isa<None>()) { |
|
|
|
attr_proto->set_ref_attr_name("none"); |
|
|
|
MS_LOG(DEBUG) << "Attr string: " << value->type_name(); |
|
|
|
@@ -664,14 +778,17 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A |
|
|
|
} else if (value->isa<IOMonad>()) { |
|
|
|
attr_proto->set_ref_attr_name("Monad:IOMonad"); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name(); |
|
|
|
MS_LOG(ERROR) << "Unsupported Monad type: " << value->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); |
|
|
|
MS_LOG(ERROR) << "Unsupported type: " << value->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
if (value == nullptr || attr_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; |
|
|
|
} |
|
|
|
@@ -714,13 +831,15 @@ void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i |
|
|
|
attr_proto->set_d(GetValue<double>(value)); |
|
|
|
} else if (value->isa<tensor::Tensor>()) { |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR); |
|
|
|
SetTensorToAttributeProto(value, attr_proto); |
|
|
|
return SetTensorToAttributeProto(value, attr_proto); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); |
|
|
|
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) { |
|
|
|
if (value == nullptr || attr_proto == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; |
|
|
|
} |
|
|
|
@@ -728,12 +847,20 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); |
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); |
|
|
|
auto int_value = value->cast<IntPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsIntType(int_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else if (value->isa<Float>()) { |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS); |
|
|
|
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); |
|
|
|
auto float_value = value->cast<FloatPtr>(); |
|
|
|
tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits())); |
|
|
|
auto data_type = GetMindirDataBitsFloatType(float_value->nbits()); |
|
|
|
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
tensor_proto->set_data_type(data_type); |
|
|
|
} else if (value->isa<StringImm>()) { |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING); |
|
|
|
attr_proto->add_strings(GetValue<std::string>(value)); |
|
|
|
@@ -772,22 +899,24 @@ void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ |
|
|
|
attr_proto->add_doubles(GetValue<double>(value)); |
|
|
|
} else if (value->isa<tensor::Tensor>()) { |
|
|
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR); |
|
|
|
SetTensorToAttributeProto(value, attr_proto); |
|
|
|
return SetTensorToAttributeProto(value, attr_proto); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); |
|
|
|
MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, |
|
|
|
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto, |
|
|
|
std::string *const seq_string) { |
|
|
|
string value_name = "value" + std::to_string(GetTupleIndex()); |
|
|
|
if (seq_string != nullptr) { |
|
|
|
*seq_string += value_name + ","; |
|
|
|
} |
|
|
|
SetScalarToAttributeProto_irs(value, attr_proto); |
|
|
|
return SetScalarToAttributeProto_irs(value, attr_proto); |
|
|
|
} |
|
|
|
|
|
|
|
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, |
|
|
|
bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, |
|
|
|
mind_ir::AttributeProto *const attr_proto, |
|
|
|
std::string *const seq_string) { |
|
|
|
if (value == nullptr || attr_proto == nullptr) { |
|
|
|
@@ -799,13 +928,19 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, |
|
|
|
if (tuple_value->value().size() == 0) { |
|
|
|
*seq_string += "],"; |
|
|
|
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; |
|
|
|
return; |
|
|
|
return true; |
|
|
|
} |
|
|
|
for (const auto &item : tuple_value->value()) { |
|
|
|
if (item->isa<ValueTuple>()) { |
|
|
|
SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string); |
|
|
|
if (!SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set sequence to AttributeProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
SetSeqElemToAttributeProto(item, attr_proto, seq_string); |
|
|
|
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
*seq_string += "],"; |
|
|
|
@@ -815,18 +950,25 @@ void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, |
|
|
|
if (list_value->value().size() == 0) { |
|
|
|
*seq_string += "],"; |
|
|
|
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0."; |
|
|
|
return; |
|
|
|
return true; |
|
|
|
} |
|
|
|
for (const auto &item : list_value->value()) { |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
if (item->isa<ValueList>()) { |
|
|
|
SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string); |
|
|
|
if (!SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set sequence to AttributeProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
SetSeqElemToAttributeProto(item, attr_proto, seq_string); |
|
|
|
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) { |
|
|
|
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
*seq_string += "],"; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { |
|
|
|
@@ -842,7 +984,7 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { |
|
|
|
return exporter->GetDumpString(func_graph); |
|
|
|
} |
|
|
|
|
|
|
|
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { |
|
|
|
ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) { |
|
|
|
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>()); |
|
|
|
auto result = exporter->GetDumpProto(func_graph, save_tensor_data); |
|
|
|
return result; |
|
|
|
|