|
|
@@ -37,6 +37,7 @@ |
|
|
#include "src/param_value_lite.h" |
|
|
#include "src/param_value_lite.h" |
|
|
#include "tools/converter/parser/onnx/onnx.pb.h" |
|
|
#include "tools/converter/parser/onnx/onnx.pb.h" |
|
|
#include "utils/log_adapter.h" |
|
|
#include "utils/log_adapter.h" |
|
|
|
|
|
#include "securec/include/securec.h" |
|
|
|
|
|
|
|
|
using string = std::string; |
|
|
using string = std::string; |
|
|
using int32 = int32_t; |
|
|
using int32 = int32_t; |
|
|
@@ -57,24 +58,16 @@ enum ParseForm : int { |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
static std::map<std::string, ParseForm> kParseTypeSwitchMap{ |
|
|
static std::map<std::string, ParseForm> kParseTypeSwitchMap{ |
|
|
{"type", FORM_PARSE_TYPE}, |
|
|
|
|
|
{"scalar", FORM_PARSE_SCALAR}, |
|
|
|
|
|
{"tensor", FORM_PARSE_TENSOR}}; |
|
|
|
|
|
|
|
|
{"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; |
|
|
|
|
|
|
|
|
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ |
|
|
static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ |
|
|
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, |
|
|
|
|
|
{onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, |
|
|
|
|
|
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, |
|
|
|
|
|
{onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, |
|
|
|
|
|
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, |
|
|
|
|
|
{onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, |
|
|
|
|
|
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, |
|
|
|
|
|
{onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, |
|
|
|
|
|
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, |
|
|
|
|
|
{onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, |
|
|
|
|
|
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, |
|
|
|
|
|
{onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, |
|
|
|
|
|
{onnx::TensorProto_DataType_STRING, kObjectTypeString}, |
|
|
|
|
|
|
|
|
{onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, |
|
|
|
|
|
{onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, |
|
|
|
|
|
{onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, |
|
|
|
|
|
{onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, |
|
|
|
|
|
{onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, |
|
|
|
|
|
{onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, |
|
|
|
|
|
{onnx::TensorProto_DataType_STRING, kObjectTypeString}, |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
#if 0 |
|
|
#if 0 |
|
|
@@ -194,16 +187,15 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a |
|
|
return {}; |
|
|
return {}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ |
|
|
|
|
|
ValuePtr ParseAttrInScalar_##type##_##valuetype( \ |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { \ |
|
|
|
|
|
if (attr_tensor.type##_data_size() == 1) { \ |
|
|
|
|
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ |
|
|
|
|
|
return MakeValue<valuetype>(value); \ |
|
|
|
|
|
} else { \ |
|
|
|
|
|
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ |
|
|
|
|
|
} \ |
|
|
|
|
|
return {}; \ |
|
|
|
|
|
|
|
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ |
|
|
|
|
|
ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ |
|
|
|
|
|
if (attr_tensor.type##_data_size() == 1) { \ |
|
|
|
|
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ |
|
|
|
|
|
return MakeValue<valuetype>(value); \ |
|
|
|
|
|
} else { \ |
|
|
|
|
|
MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ |
|
|
|
|
|
} \ |
|
|
|
|
|
return {}; \ |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) |
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) |
|
|
@@ -255,7 +247,11 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &nod |
|
|
std::string initial_data = initialize_proto.raw_data(); |
|
|
std::string initial_data = initialize_proto.raw_data(); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
MS_EXCEPTION_IF_NULL(tensor_data_buf); |
|
|
MS_EXCEPTION_IF_NULL(tensor_data_buf); |
|
|
memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); |
|
|
|
|
|
|
|
|
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); |
|
|
|
|
|
if (EOK != ret) { |
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); |
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); |
|
|
MS_EXCEPTION_IF_NULL(param_value); |
|
|
MS_EXCEPTION_IF_NULL(param_value); |
|
|
@@ -402,7 +398,11 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &val |
|
|
tensor_info->MallocData(); |
|
|
tensor_info->MallocData(); |
|
|
const std::string &tensor_buf = attr_tensor.raw_data(); |
|
|
const std::string &tensor_buf = attr_tensor.raw_data(); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); |
|
|
|
|
|
|
|
|
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); |
|
|
|
|
|
if (EOK != ret) { |
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
auto new_value_node = NewValueNode(MakeValue(tensor_info)); |
|
|
auto new_value_node = NewValueNode(MakeValue(tensor_info)); |
|
|
MS_EXCEPTION_IF_NULL(new_value_node); |
|
|
MS_EXCEPTION_IF_NULL(new_value_node); |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); |
|
|
@@ -641,21 +641,20 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc |
|
|
} |
|
|
} |
|
|
#endif |
|
|
#endif |
|
|
|
|
|
|
|
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ |
|
|
|
|
|
void ParseAttrInScalar_##type##_##valuetype( \ |
|
|
|
|
|
const PrimitivePtr &prim, const std::string &attr_name, \ |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { \ |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim); \ |
|
|
|
|
|
std::vector<ValuePtr> attr_value_vec; \ |
|
|
|
|
|
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ |
|
|
|
|
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ |
|
|
|
|
|
attr_value_vec.push_back(MakeValue<valuetype>(value)); \ |
|
|
|
|
|
} \ |
|
|
|
|
|
if (attr_value_vec.size() == 1) { \ |
|
|
|
|
|
prim->AddAttr(attr_name, attr_value_vec[0]); \ |
|
|
|
|
|
} else { \ |
|
|
|
|
|
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ |
|
|
|
|
|
} \ |
|
|
|
|
|
|
|
|
#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ |
|
|
|
|
|
void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { \ |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim); \ |
|
|
|
|
|
std::vector<ValuePtr> attr_value_vec; \ |
|
|
|
|
|
for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ |
|
|
|
|
|
auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ |
|
|
|
|
|
attr_value_vec.push_back(MakeValue<valuetype>(value)); \ |
|
|
|
|
|
} \ |
|
|
|
|
|
if (attr_value_vec.size() == 1) { \ |
|
|
|
|
|
prim->AddAttr(attr_name, attr_value_vec[0]); \ |
|
|
|
|
|
} else { \ |
|
|
|
|
|
prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ |
|
|
|
|
|
} \ |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) |
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) |
|
|
@@ -666,8 +665,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) |
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) |
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) |
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) |
|
|
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( |
|
|
|
|
|
const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, |
|
|
|
|
|
const onnx::ValueInfoProto &value_proto) { |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
if (!value_proto.has_type() || !value_proto.has_name()) { |
|
|
if (!value_proto.has_type() || !value_proto.has_name()) { |
|
|
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; |
|
|
MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; |
|
|
@@ -690,30 +689,28 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( |
|
|
shape.push_back(tensor_shape.dim(i).dim_value()); |
|
|
shape.push_back(tensor_shape.dim(i).dim_value()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == |
|
|
|
|
|
kDefaultValueSwitchMap.end()) { |
|
|
|
|
|
|
|
|
if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { |
|
|
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; |
|
|
MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto type_ptr = |
|
|
|
|
|
TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); |
|
|
|
|
|
auto abstract_tensor = |
|
|
|
|
|
std::make_shared<abstract::AbstractTensor>(type_ptr, shape); |
|
|
|
|
|
|
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); |
|
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); |
|
|
node->set_abstract(abstract_tensor); |
|
|
node->set_abstract(abstract_tensor); |
|
|
|
|
|
|
|
|
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { |
|
|
if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { |
|
|
tensor::Tensor *tensor_info = new tensor::Tensor( |
|
|
|
|
|
kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); |
|
|
|
|
|
|
|
|
tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); |
|
|
MS_EXCEPTION_IF_NULL(tensor_info); |
|
|
MS_EXCEPTION_IF_NULL(tensor_info); |
|
|
tensor_info->MallocData(); |
|
|
tensor_info->MallocData(); |
|
|
const onnx::TensorProto initialize_proto = |
|
|
|
|
|
default_para_map_[value_proto.name()]; |
|
|
|
|
|
|
|
|
const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; |
|
|
std::string initial_data = initialize_proto.raw_data(); |
|
|
std::string initial_data = initialize_proto.raw_data(); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
MS_EXCEPTION_IF_NULL(tensor_data_buf); |
|
|
MS_EXCEPTION_IF_NULL(tensor_data_buf); |
|
|
memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), |
|
|
|
|
|
initial_data.size()); |
|
|
|
|
|
|
|
|
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); |
|
|
|
|
|
if (EOK != ret) { |
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); |
|
|
ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); |
|
|
MS_EXCEPTION_IF_NULL(param_value); |
|
|
MS_EXCEPTION_IF_NULL(param_value); |
|
|
@@ -725,18 +722,15 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ImportParametersForGraph( |
|
|
|
|
|
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, |
|
|
|
|
|
const onnx::GraphProto &importProto) { |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_LOG(INFO) << "Parameters had default paramerer size is: " |
|
|
|
|
|
<< importProto.initializer_size(); |
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); |
|
|
|
|
|
|
|
|
for (int i = 0; i < importProto.initializer_size(); ++i) { |
|
|
for (int i = 0; i < importProto.initializer_size(); ++i) { |
|
|
const onnx::TensorProto &initializer_proto = importProto.initializer(i); |
|
|
const onnx::TensorProto &initializer_proto = importProto.initializer(i); |
|
|
if (!initializer_proto.has_name()) { |
|
|
if (!initializer_proto.has_name()) { |
|
|
MS_LOG(ERROR) |
|
|
|
|
|
<< "initializer vector of onnx GraphProto has no name at index: " |
|
|
|
|
|
<< i; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
default_para_map_[initializer_proto.name()] = initializer_proto; |
|
|
default_para_map_[initializer_proto.name()] = initializer_proto; |
|
|
@@ -745,8 +739,7 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph( |
|
|
MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); |
|
|
MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); |
|
|
for (int i = 0; i < importProto.input_size(); ++i) { |
|
|
for (int i = 0; i < importProto.input_size(); ++i) { |
|
|
const onnx::ValueInfoProto &input_proto = importProto.input(i); |
|
|
const onnx::ValueInfoProto &input_proto = importProto.input(i); |
|
|
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), |
|
|
|
|
|
input_proto)) { |
|
|
|
|
|
|
|
|
if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { |
|
|
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; |
|
|
MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
@@ -754,25 +747,20 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph( |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm( |
|
|
|
|
|
const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == |
|
|
|
|
|
kDefaultValueSwitchMap.end()) { |
|
|
|
|
|
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" |
|
|
|
|
|
<< attr_tensor_type; |
|
|
|
|
|
|
|
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { |
|
|
|
|
|
MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
prim->AddAttr(attr_name, |
|
|
|
|
|
TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); |
|
|
|
|
|
|
|
|
prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( |
|
|
|
|
|
const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
switch (attr_tensor_type) { |
|
|
switch (attr_tensor_type) { |
|
|
@@ -806,23 +794,20 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
default: |
|
|
default: |
|
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " |
|
|
|
|
|
<< attr_tensor_type; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( |
|
|
|
|
|
const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; |
|
|
MS_LOG(ERROR) << "parse attr type don't support attr type is tensor"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForCNode( |
|
|
|
|
|
const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
const std::string &attr_name = attr_proto.name(); |
|
|
const std::string &attr_name = attr_proto.name(); |
|
|
if (!attr_proto.has_ref_attr_name()) { |
|
|
if (!attr_proto.has_ref_attr_name()) { |
|
|
@@ -846,32 +831,33 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode( |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( |
|
|
|
|
|
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
std::vector<int> shape; |
|
|
std::vector<int> shape; |
|
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) { |
|
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) { |
|
|
shape.push_back(attr_tensor.dims(i)); |
|
|
shape.push_back(attr_tensor.dims(i)); |
|
|
} |
|
|
} |
|
|
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>( |
|
|
|
|
|
kDefaultValueSwitchMap[attr_tensor_type], shape); |
|
|
|
|
|
|
|
|
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); |
|
|
tensor_info->MallocData(); |
|
|
tensor_info->MallocData(); |
|
|
const std::string &tensor_buf = attr_tensor.raw_data(); |
|
|
const std::string &tensor_buf = attr_tensor.raw_data(); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); |
|
|
memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), |
|
|
|
|
|
tensor_buf.size()); |
|
|
|
|
|
|
|
|
auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); |
|
|
|
|
|
if (EOK != ret) { |
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s error"; |
|
|
|
|
|
return false; |
|
|
|
|
|
} |
|
|
auto new_value_node = NewValueNode(MakeValue(tensor_info)); |
|
|
auto new_value_node = NewValueNode(MakeValue(tensor_info)); |
|
|
MS_EXCEPTION_IF_NULL(new_value_node); |
|
|
MS_EXCEPTION_IF_NULL(new_value_node); |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); |
|
|
auto abstract_tensor = |
|
|
|
|
|
std::make_shared<abstract::AbstractTensor>(type_ptr, shape); |
|
|
|
|
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); |
|
|
new_value_node->set_abstract(abstract_tensor); |
|
|
new_value_node->set_abstract(abstract_tensor); |
|
|
anfnode_build_map_[value_node_name] = new_value_node; |
|
|
anfnode_build_map_[value_node_name] = new_value_node; |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( |
|
|
|
|
|
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
ValuePtr value_ptr = nullptr; |
|
|
ValuePtr value_ptr = nullptr; |
|
|
switch (attr_tensor_type) { |
|
|
switch (attr_tensor_type) { |
|
|
@@ -906,8 +892,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
default: |
|
|
default: |
|
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " |
|
|
|
|
|
<< attr_tensor_type; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
auto new_value_node = NewValueNode(value_ptr); |
|
|
auto new_value_node = NewValueNode(value_ptr); |
|
|
@@ -918,28 +903,23 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm( |
|
|
|
|
|
const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
const int attr_tensor_type = attr_tensor.data_type(); |
|
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == |
|
|
|
|
|
kDefaultValueSwitchMap.end()) { |
|
|
|
|
|
MS_LOG(ERROR) |
|
|
|
|
|
<< "Obtain ValueNode attr in type-form has not support input type: " |
|
|
|
|
|
<< attr_tensor_type; |
|
|
|
|
|
|
|
|
if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { |
|
|
|
|
|
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
auto new_value_node = |
|
|
|
|
|
NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); |
|
|
|
|
|
abstract::AbstractTypePtr abs_type = |
|
|
|
|
|
std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); |
|
|
|
|
|
|
|
|
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); |
|
|
|
|
|
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); |
|
|
new_value_node->set_abstract(abs_type); |
|
|
new_value_node->set_abstract(abs_type); |
|
|
anfnode_build_map_[value_node_name] = new_value_node; |
|
|
anfnode_build_map_[value_node_name] = new_value_node; |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForValueNode( |
|
|
|
|
|
const std::string &ref_attr_name, const std::string &value_node_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, |
|
|
|
|
|
const std::string &value_node_name, |
|
|
|
|
|
const onnx::TensorProto &attr_tensor) { |
|
|
switch (kParseTypeSwitchMap[ref_attr_name]) { |
|
|
switch (kParseTypeSwitchMap[ref_attr_name]) { |
|
|
case FORM_PARSE_SCALAR: { |
|
|
case FORM_PARSE_SCALAR: { |
|
|
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); |
|
|
return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); |
|
|
@@ -951,14 +931,12 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode( |
|
|
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); |
|
|
return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); |
|
|
} |
|
|
} |
|
|
default: |
|
|
default: |
|
|
MS_LOG(ERROR) |
|
|
|
|
|
<< "parse ValueNode value don't support input of ref_attr_name"; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; |
|
|
return false; |
|
|
return false; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( |
|
|
|
|
|
const onnx::NodeProto &node_proto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { |
|
|
const std::string &value_node_name = node_proto.output(0); |
|
|
const std::string &value_node_name = node_proto.output(0); |
|
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(0); |
|
|
const onnx::AttributeProto &attr_proto = node_proto.attribute(0); |
|
|
if (!attr_proto.has_ref_attr_name()) { |
|
|
if (!attr_proto.has_ref_attr_name()) { |
|
|
@@ -971,22 +949,20 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( |
|
|
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); |
|
|
return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode( |
|
|
|
|
|
const onnx::AttributeProto &attr_proto) { |
|
|
|
|
|
|
|
|
abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { |
|
|
std::vector<int> shape_vec; |
|
|
std::vector<int> shape_vec; |
|
|
const onnx::TensorProto &attr_tensor = attr_proto.t(); |
|
|
const onnx::TensorProto &attr_tensor = attr_proto.t(); |
|
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) { |
|
|
for (int i = 0; i < attr_tensor.dims_size(); ++i) { |
|
|
shape_vec.push_back(attr_tensor.dims(i)); |
|
|
shape_vec.push_back(attr_tensor.dims(i)); |
|
|
} |
|
|
} |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); |
|
|
auto abstract_tensor = |
|
|
|
|
|
std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); |
|
|
|
|
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); |
|
|
MS_EXCEPTION_IF_NULL(abstract_tensor); |
|
|
MS_EXCEPTION_IF_NULL(abstract_tensor); |
|
|
return abstract_tensor; |
|
|
return abstract_tensor; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( |
|
|
|
|
|
const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto) { |
|
|
|
|
|
|
|
|
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, |
|
|
|
|
|
const onnx::NodeProto &node_proto) { |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
if (!node_proto.has_op_type()) { |
|
|
if (!node_proto.has_op_type()) { |
|
|
MS_LOG(ERROR) << "Get CNode op_type failed!"; |
|
|
MS_LOG(ERROR) << "Get CNode op_type failed!"; |
|
|
@@ -1028,8 +1004,7 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( |
|
|
for (int i = 0; i < node_proto.input_size(); ++i) { |
|
|
for (int i = 0; i < node_proto.input_size(); ++i) { |
|
|
const std::string &input_name = node_proto.input(i); |
|
|
const std::string &input_name = node_proto.input(i); |
|
|
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { |
|
|
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { |
|
|
MS_LOG(ERROR) << node_name << " input " << i << input_name |
|
|
|
|
|
<< "can't find in nodes have parsed"; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
inputs.push_back(anfnode_build_map_[input_name]); |
|
|
inputs.push_back(anfnode_build_map_[input_name]); |
|
|
@@ -1061,9 +1036,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( |
|
|
return cnode_ptr; |
|
|
return cnode_ptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( |
|
|
|
|
|
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, |
|
|
|
|
|
const CNodePtr &cnode_ptr) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, |
|
|
|
|
|
const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr); |
|
|
MS_EXCEPTION_IF_NULL(cnode_ptr); |
|
|
std::vector<AnfNodePtr> inputs; |
|
|
std::vector<AnfNodePtr> inputs; |
|
|
@@ -1078,8 +1052,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( |
|
|
elem.push_back(anfnode_build_map_[out_tuple]->abstract()); |
|
|
elem.push_back(anfnode_build_map_[out_tuple]->abstract()); |
|
|
} |
|
|
} |
|
|
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); |
|
|
auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); |
|
|
maketuple_ptr->set_abstract( |
|
|
|
|
|
std::make_shared<abstract::AbstractTuple>(elem)); |
|
|
|
|
|
|
|
|
maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); |
|
|
inputs.clear(); |
|
|
inputs.clear(); |
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn)); |
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn)); |
|
|
inputs.push_back(maketuple_ptr); |
|
|
inputs.push_back(maketuple_ptr); |
|
|
@@ -1092,14 +1065,11 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( |
|
|
const onnx::TypeProto &output_typeproto = output_node.type(); |
|
|
const onnx::TypeProto &output_typeproto = output_node.type(); |
|
|
int output_type = output_typeproto.tensor_type().elem_type(); |
|
|
int output_type = output_typeproto.tensor_type().elem_type(); |
|
|
std::vector<int> output_shape; |
|
|
std::vector<int> output_shape; |
|
|
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); |
|
|
|
|
|
++i) { |
|
|
|
|
|
output_shape.push_back( |
|
|
|
|
|
output_typeproto.tensor_type().shape().dim(i).dim_value()); |
|
|
|
|
|
|
|
|
for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { |
|
|
|
|
|
output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); |
|
|
} |
|
|
} |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); |
|
|
auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); |
|
|
auto abstract_tensor = |
|
|
|
|
|
std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); |
|
|
|
|
|
|
|
|
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); |
|
|
|
|
|
|
|
|
inputs.clear(); |
|
|
inputs.clear(); |
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn)); |
|
|
inputs.push_back(NewValueNode(prim::kPrimReturn)); |
|
|
@@ -1113,8 +1083,8 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ImportNodesForGraph( |
|
|
|
|
|
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, |
|
|
|
|
|
const onnx::GraphProto &importProto) { |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); |
|
|
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); |
|
|
CNodePtr cnode_ptr = nullptr; |
|
|
CNodePtr cnode_ptr = nullptr; |
|
|
@@ -1139,8 +1109,7 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph( |
|
|
return true; |
|
|
return true; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildFuncGraph( |
|
|
|
|
|
const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
MS_EXCEPTION_IF_NULL(outputFuncGraph); |
|
|
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); |
|
|
GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); |
|
|
MS_EXCEPTION_IF_NULL(debug_info_ptr); |
|
|
MS_EXCEPTION_IF_NULL(debug_info_ptr); |
|
|
@@ -1156,8 +1125,7 @@ bool AnfImporterFromProtobuf::BuildFuncGraph( |
|
|
return ImportNodesForGraph(outputFuncGraph, importProto); |
|
|
return ImportNodesForGraph(outputFuncGraph, importProto); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ParseModelConfigureInfo( |
|
|
|
|
|
const onnx::ModelProto &model_proto) { |
|
|
|
|
|
|
|
|
bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { |
|
|
if (!model_proto.has_producer_name()) { |
|
|
if (!model_proto.has_producer_name()) { |
|
|
MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; |
|
|
MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; |
|
|
return false; |
|
|
return false; |
|
|
@@ -1194,8 +1162,7 @@ int AnfImporterFromProtobuf::Import() { |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary( |
|
|
|
|
|
const std::string &model_path) { |
|
|
|
|
|
|
|
|
onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { |
|
|
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); |
|
|
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); |
|
|
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { |
|
|
if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { |
|
|
MS_LOG(ERROR) << "open file failed."; |
|
|
MS_LOG(ERROR) << "open file failed."; |
|
|
|