|
|
|
@@ -76,45 +76,85 @@ STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, go |
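// SetGraphConstTensor registers every constant of the graph in tensor_cache:
// first the initializer tensors, then the tensors produced by Constant nodes.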
|
|
|
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { |
|
|
|
MS_LOG(DEBUG) << "set onnx constant tensors"; |
|
|
|
for (const auto &onnx_const_value : onnx_graph.initializer()) { |
|
|
|
    // TODO(wangzhe) why use GRAPH_INPUT rather than CONST (GRAPH_INPUT will add the index to graphInputs)
    int index;
    const auto status = AddTensorProto(onnx_const_value, onnx_const_value.name(), GRAPH_INPUT, tensor_cache, &index);
    if (status != RET_OK) {
      return status;
    }
    MS_LOG(DEBUG) << "add const tensor: " << onnx_const_value.name() << ", index " << index;
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "process onnx Constant ops"; |
|
|
|
for (int i = 0; i < onnx_graph.node_size(); i++) { |
|
|
|
const auto &node = onnx_graph.node(i); |
|
|
|
if (node.op_type().compare("Constant") == 0) { |
|
|
|
for (const auto &attr : node.attribute()) { |
|
|
|
if (attr.name() == "sparse_value") { |
|
|
|
MS_LOG(ERROR) << "sparse_value"; |
|
|
|
} |
|
|
|
if (attr.name() == "value") { |
|
|
|
const auto &t = attr.t(); |
|
|
|
int index; |
|
|
|
const auto status = AddTensorProto(t, node.output(0), GRAPH_INPUT, tensor_cache, &index); |
|
|
|
if (status != RET_OK) { |
|
|
|
return status; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "add const tensor: " << t.name() << ", index " << index; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "processing Constant op attr " << attr.name() << " not implemented"; |
|
|
|
return RET_INVALID_OP_ATTR; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
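// AddValueInfo builds a schema::TensorT (data type, dims, NCHW format) from an ONNX ValueInfoProto,
// registers it in tensor_cache, and returns its cache index through *index. No tensor data is copied.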
|
|
|
|
|
|
|
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
                                     TensorCache *tensor_cache, int *index) {
|
|
|
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type())); |
|
|
|
if (data_type == kTypeUnknown) { |
|
|
|
MS_LOG(ERROR) << "not support onnx type " |
|
|
|
MS_LOG(ERROR) << "not support onnx data type " |
|
|
|
<< static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT);
|
|
|
if (tensor == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new tensor failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
tensor->dataType = data_type; |
|
|
|
tensor->dims = GetDimsFromOnnxValue(proto); |
|
|
|
tensor->format = schema::Format_NCHW; |
|
|
|
tensor->nodeType = schema::NodeType_ValueNode; |
|
|
|
// TODO(wangzhe) tensor->data and quantParams not set, should we need tensor_cache->AddTensor? |
|
|
|
*index = tensor_cache->AddTensor(name, tensor.release(), type); |
|
|
|
return RET_OK; |
|
|
|
} |
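// AddTensorProto builds a schema::TensorT from an ONNX TensorProto, copies its data with
// CopyOnnxTensorData (int64 data is narrowed to int32), registers it in tensor_cache,
// and returns its cache index through *index.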
|
|
|
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, |
|
|
|
TensorCache *tensor_cache, int *index) { |
|
|
|
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type())); |
|
|
|
if (data_type == kTypeUnknown) { |
|
|
|
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type()); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT); |
|
|
|
if (tensor == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new tensor failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
tensor->dataType = data_type; |
|
|
|
std::copy(proto.dims().begin(), proto.dims().end(), std::back_inserter(tensor->dims)); |
|
|
|
tensor->format = schema::Format_NCHW; |
|
|
|
tensor->nodeType = schema::NodeType_ValueNode; |
|
|
|
if (CopyOnnxTensorData(proto, tensor.get())) { |
|
|
|
MS_LOG(ERROR) << "copy onnx data failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (data_type == kNumberTypeInt64) { |
|
|
|
MS_LOG(ERROR) << "INT64" << proto.name(); |
|
|
|
tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32 |
|
|
|
} |
|
|
|
*index = tensor_cache->AddTensor(name, tensor.release(), type); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -123,15 +163,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, |
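// Graph inputs that are not already in tensor_cache (i.e. not initializers) are registered
// here and their indices recorded in graph->inputIndex.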
|
|
|
for (const auto &input_value : onnx_graph.input()) { |
|
|
|
auto ret = tensor_cache->FindTensor(input_value.name()); |
|
|
|
if (ret < 0) { |
|
|
|
      int index;
      const auto status = AddValueInfo(input_value, input_value.name(), GRAPH_INPUT, tensor_cache, &index);
      if (status != RET_OK) {
        return status;
      }
      MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << index;
      graph->inputIndex.emplace_back(static_cast<uint32_t>(index));
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
@@ -140,14 +178,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, |
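// Every graph output is registered in tensor_cache and its index recorded in graph->outputIndex.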
|
|
|
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, |
|
|
|
TensorCache *tensor_cache) { |
|
|
|
for (const auto &output_value : onnx_graph.output()) { |
|
|
|
    int index;
    const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
    if (status != RET_OK) {
      return status;
    }
    graph->outputIndex.emplace_back(index);
    MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -332,32 +369,11 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co |
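// SetOpInputIndex resolves each ONNX node input name to its index in tensor_cache and appends
// it to dst_op->inputIndex; an optional "order" attribute can override the input tensors' format.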
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op, |
|
|
|
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { |
|
|
|
schema::Format format = schema::Format_MAX; |
|
|
|
for (const auto &onnx_node_attr : onnx_node.attribute()) { |
|
|
|
if (onnx_node_attr.name() == "order") { // do we need this code? onnx doc don't have order attr |
|
|
|
MS_LOG(EXCEPTION) << "find order attr"; |
|
|
|
if (onnx_node_attr.s() == "NHWC") { |
|
|
|
format = schema::Format_NHWC; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported format: " << onnx_node_attr.s(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &onnx_node_input : node_inputs) { |
|
|
|
auto index = tensor_cache->FindTensor(onnx_node_input); |
|
|
|
    if (index < 0) {
      MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
      return RET_ERROR;
    }
    if (format != schema::Format_MAX) {
      auto inTensor = tensor_cache->GetCachedTensor().at(index);
      inTensor->format = format;
    }
|
|
|
MS_LOG(DEBUG) << "node: " << onnx_node_input << ", input index: " << index; |
|
|
|
dst_op->inputIndex.emplace_back(index); |
|
|
|
@@ -369,19 +385,12 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs |
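// SetOpOutputIndex resolves each node output name to a tensor_cache index, creating a
// Parameter tensor for outputs that are not in the cache yet, and appends it to dst_op->outputIndex.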
|
|
|
TensorCache *tensor_cache) { |
|
|
|
for (const auto &onnx_node_output : node_outputs) { |
|
|
|
auto index = tensor_cache->FindTensor(onnx_node_output); |
|
|
|
    if (index < 0) {  // when index >= 0, the tensor is already in the cache (e.g. a graph output)
      MS_LOG(DEBUG) << "output of node " << dst_op->name << " not in tensor_cache, creating";
      std::unique_ptr<schema::TensorT> tensor(new (std::nothrow) schema::TensorT);
      if (tensor == nullptr) {
        MS_LOG(ERROR) << "new tensor failed";
        return RET_ERROR;
      }
      tensor->nodeType = schema::NodeType_Parameter;
      index = tensor_cache->AddTensor(onnx_node_output, tensor.release(), OP_OUTPUT);
    }
    MS_LOG(DEBUG) << "node: " << onnx_node_output << ", output index: " << index;
|
|
|
dst_op->outputIndex.emplace_back(index); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
@@ -390,8 +399,10 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs |
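// CopyOnnxTensorData copies the initializer's payload (raw_data or the typed data field) into
// tensor->data; int64 payloads are narrowed to int32 with an overflow check.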
|
|
|
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { |
|
|
|
size_t data_count = 1; |
|
|
|
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); |
|
|
|
MS_LOG(ERROR) << "const tensor dims " << tensor->dims.size(); |
|
|
|
size_t data_size = 0; |
|
|
|
const void *tensor_data = nullptr; |
|
|
|
int32_t *buffer = nullptr; |
|
|
|
switch (tensor->dataType) { |
|
|
|
case kNumberTypeFloat32: |
|
|
|
data_size = data_count * sizeof(float); |
|
|
|
@@ -410,12 +421,23 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v |
|
|
|
} |
|
|
|
break; |
|
|
|
case kNumberTypeInt64: |
|
|
|
      data_size = data_count * sizeof(int32_t);  // int64 values are narrowed to int32
|
|
|
buffer = new int32_t[data_count]; |
|
|
|
const int64_t *in_data; |
|
|
|
if (onnx_const_value.int64_data_size() == 0) { |
|
|
|
        in_data = reinterpret_cast<const int64_t *>(onnx_const_value.raw_data().data());
|
|
|
} else { |
|
|
|
        in_data = onnx_const_value.int64_data().data();
|
|
|
} |
|
|
|
for (size_t i = 0; i < data_count; ++i) {
|
|
|
        if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
          MS_LOG(ERROR) << "int64 data " << in_data[i] << " is too big to fit into int32";
          delete[] buffer;
          return RET_ERROR;
|
|
|
} else { |
|
|
|
buffer[i] = static_cast<int>(in_data[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
tensor_data = reinterpret_cast<void *>(buffer); |
|
|
|
break; |
|
|
|
case kNumberTypeUInt8: |
|
|
|
case kNumberTypeInt8: |
|
|
|
@@ -431,6 +453,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v |
|
|
|
MS_LOG(ERROR) << "memcpy_s failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
  if (kNumberTypeInt64 == tensor->dataType) {
    delete[] buffer;  // buffer was allocated with new[], so it must be released with delete[]
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -491,6 +516,9 @@ MetaGraphT *OnnxModelParser::Parse(const std::string &modelFile, const std::stri |
|
|
|
} |
|
|
|
// init op node input/output tensor, and dst_op attr |
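// Constant nodes were already materialized as tensors in SetGraphConstTensor, so they are
// skipped here; Gemm gets a dedicated parser.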
|
|
|
for (const auto &onnx_node : onnx_graph.node()) { |
|
|
|
if (onnx_node.op_type() == "Constant") { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (onnx_node.op_type() == "Gemm") { |
|
|
|
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); |
|
|
|
continue; |
|
|
|
|