| @@ -4,7 +4,6 @@ set(SRC_LIST | |||||
| "onnx_data_parser.cc" | "onnx_data_parser.cc" | ||||
| "onnx_util.cc" | "onnx_util.cc" | ||||
| "onnx_constant_parser.cc" | "onnx_constant_parser.cc" | ||||
| "onnx_file_constant_parser.cc" | |||||
| "subgraph_adapter/if_subgraph_adapter.cc" | "subgraph_adapter/if_subgraph_adapter.cc" | ||||
| "subgraph_adapter/subgraph_adapter_factory.cc" | "subgraph_adapter/subgraph_adapter_factory.cc" | ||||
| ) | ) | ||||
| @@ -17,7 +17,6 @@ PARSER_ONNX_SRC_FILES := \ | |||||
| onnx_data_parser.cc \ | onnx_data_parser.cc \ | ||||
| onnx_util.cc \ | onnx_util.cc \ | ||||
| onnx_constant_parser.cc \ | onnx_constant_parser.cc \ | ||||
| onnx_file_constant_parser.cc \ | |||||
| proto/onnx/ge_onnx.proto \ | proto/onnx/ge_onnx.proto \ | ||||
| proto/om.proto \ | proto/om.proto \ | ||||
| @@ -34,7 +34,11 @@ using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const char *kConstant = "Const"; | |||||
| const char *kConstant = "Const"; | |||||
| const char *const kLocation = "location"; | |||||
| const char *const kOffset = "offset"; | |||||
| const char *const kLength = "length"; | |||||
| const int64_t kOffsetCoefficient = 4096; | |||||
| } | } | ||||
| Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { | Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { | ||||
| int64_t data_type = tensor_proto.data_type(); | int64_t data_type = tensor_proto.data_type(); | ||||
| @@ -153,6 +157,79 @@ void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &t | |||||
| } | } | ||||
| } | } | ||||
| Status OnnxConstantParser::LoadWeightFromFile(ge::Tensor &tensor, std::string &file_path, size_t offset, | |||||
| size_t length) { | |||||
| const GeTensor ge_tensor = TensorAdapter::AsGeTensor(tensor); | |||||
| const auto &tensor_desc = ge_tensor.GetTensorDesc(); | |||||
| int64_t weight_size = 0; | |||||
| GE_CHK_STATUS_RET(TensorUtils::GetTensorSizeInBytes(tensor_desc, weight_size), | |||||
| "Failed to get file constant weight size."); | |||||
| const size_t file_length = (length == 0U ? static_cast<size_t>(weight_size) : length); | |||||
| const std::string real_path = ge::parser::RealPath(file_path.c_str()); | |||||
| GE_CHECK_NOTNULL(real_path.c_str()); | |||||
| std::ifstream ifs(real_path, std::ifstream::binary); | |||||
| if (!ifs.is_open()) { | |||||
| REPORT_INNER_ERROR("E19999", "Read file %s failed.", file_path.c_str()); | |||||
| GELOGE(FAILED, "[Read][File]Failed, file %s.", file_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ifs.clear(); | |||||
| ifs.seekg(offset, ifs.beg); | |||||
| const std::unique_ptr<char_t[]> bin_buff = std::unique_ptr<char_t[]>(new (std::nothrow) char_t[file_length]); | |||||
| (void)ifs.read(static_cast<char_t *>(bin_buff.get()), static_cast<int64_t>(file_length)); | |||||
| ifs.close(); | |||||
| tensor.SetData(reinterpret_cast<uint8_t *>(bin_buff.get()), file_length); | |||||
| GELOGD("Load weight from %s success.", file_path.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxConstantParser::ParseExternalWeight(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) { | |||||
| std::string file_path; | |||||
| int64_t attr_offset = 0; | |||||
| int64_t attr_length = 0; | |||||
| for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | |||||
| const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | |||||
| if (string_proto.key() == kLocation) { | |||||
| file_path = string_proto.value(); | |||||
| continue; | |||||
| } | |||||
| if (string_proto.key() == kOffset) { | |||||
| try { | |||||
| attr_offset = stol(string_proto.value()); | |||||
| } catch (const std::exception &e) { | |||||
| REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| return FAILED; | |||||
| } | |||||
| if (attr_offset > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) { | |||||
| REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, attr_offset); | |||||
| GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, attr_offset); | |||||
| return FAILED; | |||||
| } | |||||
| attr_offset *= kOffsetCoefficient; | |||||
| continue; | |||||
| } | |||||
| if (string_proto.key() == kLength) { | |||||
| try { | |||||
| attr_length = stol(string_proto.value()); | |||||
| } catch (const std::exception &e) { | |||||
| REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (file_path.empty()) { | |||||
| REPORT_INNER_ERROR("E19999", "External tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "External tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHK_STATUS_RET(LoadWeightFromFile(tensor, file_path, static_cast<size_t>(attr_offset), | |||||
| static_cast<size_t>(attr_length)), | |||||
| "Failed to load weight from file:%s", file_path.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) { | Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) { | ||||
| // convert shape and format | // convert shape and format | ||||
| std::vector<int64_t> tmp_shape; | std::vector<int64_t> tmp_shape; | ||||
| @@ -173,9 +250,12 @@ Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tenso | |||||
| tensor.SetTensorDesc(tensor_desc); | tensor.SetTensorDesc(tensor_desc); | ||||
| // set data | // set data | ||||
| if (ParseConvertData(tensor_proto, tensor, count) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Invoke][ParseConvertData]Convert ge tensor data and format failed."); | |||||
| return FAILED; | |||||
| if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
| GE_CHK_STATUS_RET(ParseExternalWeight(tensor_proto, tensor), | |||||
| "[Invoke][ParseExternalWeight]Load external weight file failed."); | |||||
| } else { | |||||
| GE_CHK_STATUS_RET(ParseConvertData(tensor_proto, tensor, count), | |||||
| "[Invoke][ParseConvertData]Convert ge tensor data and format failed."); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -32,6 +32,8 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||||
| static Status ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def); | static Status ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def); | ||||
| static Status ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | static Status ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | ||||
| static Status ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count); | static Status ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count); | ||||
| static Status LoadWeightFromFile(ge::Tensor &tensor, std::string &file_path, size_t offset, size_t file_length); | |||||
| static Status ParseExternalWeight(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | |||||
| static void ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count, | static void ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count, | ||||
| int64_t data_type); | int64_t data_type); | ||||
| static Status ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | static Status ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | ||||
| @@ -1,151 +0,0 @@ | |||||
| /** | |||||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved. | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "onnx_file_constant_parser.h" | |||||
| #include <vector> | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "parser/onnx/onnx_util.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "framework/common/types.h" | |||||
| using ge::onnx::NodeProto; | |||||
| using ge::onnx::TensorProto; | |||||
| using domi::ONNX; | |||||
| using GeShape = ge::GeShape; | |||||
| using GeTensorDesc = ge::GeTensorDesc; | |||||
| using namespace ge::parser; | |||||
| namespace { | |||||
| const char *const kAttrShape = "shape"; | |||||
| const char *const kAttrDataType = "dtype"; | |||||
| const char *const kFileConstantPath = "file_constant_path"; | |||||
| const char *const kLocation = "location"; | |||||
| const char *const kOffset = "offset"; | |||||
| const int64_t kOffsetCoefficient = 4096; | |||||
| const char *const kFileConstant = "FileConstant"; | |||||
| } | |||||
| namespace ge { | |||||
| Status OnnxFileConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| const ge::onnx::NodeProto *node = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
| GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| if (GetTensorProto(*node, tensor_proto) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] get tensor failed", node->name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "[Get][TensorProto] node[%s] get tensor failed", node->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (ParseDataType(tensor_proto, op_def) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] parse data type failed", node->name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse data type failed", node->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (ParsePath(tensor_proto, op_def) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "node[%s] parse file path failed", node->name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "[Parse][Shape] node[%s] parse file path failed", node->name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ParseShape(tensor_proto, op_def); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxFileConstantParser::GetTensorProto(const ge::onnx::NodeProto &node_proto, | |||||
| ge::onnx::TensorProto &tensor_proto) const { | |||||
| for (const auto &it : node_proto.attribute()) { | |||||
| if (it.name() != ge::kAttrNameValue) { | |||||
| continue; | |||||
| } | |||||
| tensor_proto = it.t(); | |||||
| return SUCCESS; | |||||
| } | |||||
| REPORT_INNER_ERROR("E19999", "node_proto[%s] get value failed", node_proto.name().c_str()); | |||||
| GELOGE(ge::PARAM_INVALID, "[Get][TensorProto] node_proto[%s] get value failed", node_proto.name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| void OnnxFileConstantParser::ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||||
| std::vector<int64_t> tmp_shape; | |||||
| for (int i = 0; i < tensor_proto.dims_size(); i++) { | |||||
| tmp_shape.push_back(tensor_proto.dims(i)); | |||||
| } | |||||
| op_def.SetAttr(kAttrShape, tmp_shape); | |||||
| } | |||||
| Status OnnxFileConstantParser::ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||||
| int64_t data_type = tensor_proto.data_type(); | |||||
| ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); | |||||
| if (type >= ge::DataType::DT_UNDEFINED) { | |||||
| REPORT_INNER_ERROR("E19999", "tensor_proto date type %ld is undefined.", data_type); | |||||
| GELOGE(domi::PARAM_INVALID, "[Check][Param] tensor_proto date type %ld is undefined.", data_type); | |||||
| return FAILED; | |||||
| } | |||||
| op_def.SetAttr(kAttrDataType, type); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxFileConstantParser::ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const { | |||||
| ge::NamedAttrs attrs; | |||||
| for (int32_t i = 0; i < tensor_proto.external_data_size(); ++i) { | |||||
| const ge::onnx::StringStringEntryProto &string_proto = tensor_proto.external_data(i); | |||||
| if (SetPathAttr(string_proto, attrs) != SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] parse attrs failed.", tensor_proto.name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| if (!attrs.HasAttr(kLocation)) { | |||||
| REPORT_INNER_ERROR("E19999", "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
| GELOGE(domi::PARAM_INVALID, "external tensor proto[%s] must contain location.", tensor_proto.name().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| op_def.SetAttr(kFileConstantPath, attrs); | |||||
| GELOGD("The weight file of Op[%s] is: [%s].", tensor_proto.name().c_str(), attrs.GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, | |||||
| ge::NamedAttrs &attrs) const { | |||||
| if (string_proto.key() == kLocation) { | |||||
| AttrUtils::SetStr(attrs, kLocation, string_proto.value()); | |||||
| } else { | |||||
| int64_t value; | |||||
| try { | |||||
| value = stol(string_proto.value()); | |||||
| } catch (const std::exception &e) { | |||||
| REPORT_INNER_ERROR("E19999", "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| GELOGE(domi::PARAM_INVALID, "Convert %s to int64_t value failed:%s", string_proto.value().c_str(), e.what()); | |||||
| return FAILED; | |||||
| } | |||||
| if (string_proto.key() == kOffset) { | |||||
| if (value > (std::numeric_limits<int64_t>::max() / kOffsetCoefficient)) { | |||||
| REPORT_INNER_ERROR("E19999", "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||||
| GELOGE(domi::PARAM_INVALID, "overflow, kOffsetCoefficient[%ld], value[%ld]", kOffsetCoefficient, value); | |||||
| return FAILED; | |||||
| } | |||||
| value *= kOffsetCoefficient; | |||||
| } | |||||
| AttrUtils::SetInt(attrs, string_proto.key(), value); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser); | |||||
| } // namespace ge | |||||
| @@ -1,37 +0,0 @@ | |||||
| /** | |||||
| * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved. | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
| #define GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
| #include "parser/onnx/onnx_op_parser.h" | |||||
| #include "proto/onnx/ge_onnx.pb.h" | |||||
| namespace ge { | |||||
| class PARSER_FUNC_VISIBILITY OnnxFileConstantParser : public OnnxOpParser { | |||||
| public: | |||||
| Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | |||||
| private: | |||||
| Status ParsePath(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||||
| Status ParseDataType(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||||
| void ParseShape(const ge::onnx::TensorProto &tensor_proto, ge::Operator &op_def) const; | |||||
| Status GetTensorProto(const ge::onnx::NodeProto &node_proto, ge::onnx::TensorProto &tensor_proto) const; | |||||
| Status SetPathAttr(const ge::onnx::StringStringEntryProto &string_proto, ge::NamedAttrs &attrs) const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_PARSER_ONNX_ONNX_FILE_CONSTANT_PARSER_H_ | |||||
| @@ -166,8 +166,7 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const std::map<std::string, std::string> kOnnxOpMap = { | const std::map<std::string, std::string> kOnnxOpMap = { | ||||
| {ge::kOpTypeInput, ge::parser::DATA}, | {ge::kOpTypeInput, ge::parser::DATA}, | ||||
| {ge::kOpTypeConstant, ge::parser::CONSTANT}, | |||||
| {ge::kFileConstant, ge::parser::FILECONSTANT} | |||||
| {ge::kOpTypeConstant, ge::parser::CONSTANT} | |||||
| }; | }; | ||||
| const int64_t kDimValue = 1; | const int64_t kDimValue = 1; | ||||
| @@ -183,7 +182,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque | |||||
| for (auto &node : parent_graph->GetDirectNode()) { | for (auto &node : parent_graph->GetDirectNode()) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { | |||||
| for (const auto &subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { | |||||
| auto i = subgraph_name_to_index.second; | auto i = subgraph_name_to_index.second; | ||||
| auto subgraph_iname = subgraph_name_to_index.first; | auto subgraph_iname = subgraph_name_to_index.first; | ||||
| if (subgraph_iname.empty()) { | if (subgraph_iname.empty()) { | ||||
| @@ -359,37 +358,18 @@ Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, | |||||
| ge::onnx::NodeProto *const_node = onnx_graph.add_node(); | ge::onnx::NodeProto *const_node = onnx_graph.add_node(); | ||||
| std::string output_name = it.first + "_" + to_string(index++); | std::string output_name = it.first + "_" + to_string(index++); | ||||
| const_node->set_name(output_name); | const_node->set_name(output_name); | ||||
| const_node->set_op_type(ge::kOpTypeConstant); | |||||
| const_node->add_output(it.first); | const_node->add_output(it.first); | ||||
| ge::onnx::AttributeProto *attribute = const_node->add_attribute(); | ge::onnx::AttributeProto *attribute = const_node->add_attribute(); | ||||
| attribute->set_name(ge::kAttrNameValue); | attribute->set_name(ge::kAttrNameValue); | ||||
| ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); | ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); | ||||
| *attribute_t = it.second; | *attribute_t = it.second; | ||||
| if (it.second.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
| const_node->set_op_type(kFileConstant); | |||||
| GELOGD("Initializer const node [%s], the weight was stored in the file.", const_node->name().c_str()); | |||||
| } else { | |||||
| const_node->set_op_type(ge::kOpTypeConstant); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void OnnxModelParser::UpdateConstantOpType(ge::onnx::NodeProto *node) const { | |||||
| // If weight in file, Marker Constant(not Initializer) as file constant | |||||
| for (auto it : node->attribute()) { | |||||
| if (it.name() == ge::kAttrNameValue) { | |||||
| const ::ge::onnx::TensorProto tensor_proto = it.t(); | |||||
| if (tensor_proto.data_location() == ge::onnx::TensorProto_DataLocation_EXTERNAL) { | |||||
| node->set_op_type(kFileConstant); | |||||
| GELOGD("Const node [%s], the weight was stored in the file.", node->name().c_str()); | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const { | |||||
| void OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const { | |||||
| int index = 0; | int index = 0; | ||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
| ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | ||||
| @@ -397,9 +377,6 @@ void OnnxModelParser::UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) | |||||
| std::string node_name = node->op_type() + "_" + to_string(index++); | std::string node_name = node->op_type() + "_" + to_string(index++); | ||||
| node->set_name(node_name); | node->set_name(node_name); | ||||
| } | } | ||||
| if (node->op_type() == kOpTypeConstant) { | |||||
| UpdateConstantOpType(node); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -984,7 +961,7 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
| } | } | ||||
| GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); | ||||
| // 3. Parse Constant(initializer) from graph. | |||||
| // 3. Parse Constant from graph. | |||||
| ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ret = ParseInitializer(onnx_graph, initializer_name_tensor); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[Parse][Initializer] for onnx failed."); | GELOGE(ret, "[Parse][Initializer] for onnx failed."); | ||||
| @@ -998,8 +975,8 @@ Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphP | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // 5. Update node name for node do not has name, update const op type | |||||
| UpdateNodeNameAndOpType(onnx_graph); | |||||
| // 5. Update node name for node do not has name. | |||||
| UpdateAllNodeName(onnx_graph); | |||||
| // 6 Precheck. | // 6 Precheck. | ||||
| ret = Prechecker(onnx_graph); | ret = Prechecker(onnx_graph); | ||||
| @@ -105,9 +105,7 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
| Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | ||||
| std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) const; | ||||
| void UpdateConstantOpType(ge::onnx::NodeProto *node) const; | |||||
| void UpdateNodeNameAndOpType(ge::onnx::GraphProto &onnx_graph) const; | |||||
| void UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) const; | |||||
| Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); | ||||
| @@ -50,7 +50,6 @@ const char *const kAttrNameIndex = "index"; | |||||
| const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | ||||
| const char *const kOpTypeConstant = "Constant"; | const char *const kOpTypeConstant = "Constant"; | ||||
| const char *const kOpTypeInput = "Input"; | const char *const kOpTypeInput = "Input"; | ||||
| const char *const kFileConstant = "FileConstant"; | |||||
| class OnnxUtil { | class OnnxUtil { | ||||
| public: | public: | ||||
| @@ -148,7 +148,66 @@ void *memCpyS(void *dest, const void *src, UINT32 count) { | |||||
| return dest; | return dest; | ||||
| } | } | ||||
| INT32 mmRmdir(const CHAR *lp_path_name) { return rmdir(lp_path_name); } | |||||
| INT32 mmRmdir(const CHAR *lp_path_name) { | |||||
| INT32 ret; | |||||
| DIR *childDir = NULL; | |||||
| if (lp_path_name == NULL) { | |||||
| return EN_INVALID_PARAM; | |||||
| } | |||||
| DIR *dir = opendir(lp_path_name); | |||||
| if (dir == NULL) { | |||||
| return EN_INVALID_PARAM; | |||||
| } | |||||
| const struct dirent *entry = NULL; | |||||
| size_t bufSize = strlen(lp_path_name) + (size_t)(PATH_SIZE + 2); // make sure the length is large enough | |||||
| while ((entry = readdir(dir)) != NULL) { | |||||
| if ((strcmp(".", entry->d_name) == MMPA_ZERO) || (strcmp("..", entry->d_name) == MMPA_ZERO)) { | |||||
| continue; | |||||
| } | |||||
| CHAR *buf = (CHAR *)malloc(bufSize); | |||||
| if (buf == NULL) { | |||||
| break; | |||||
| } | |||||
| ret = memset_s(buf, bufSize, 0, bufSize); | |||||
| if (ret == EN_ERROR) { | |||||
| free(buf); | |||||
| buf = NULL; | |||||
| break; | |||||
| } | |||||
| ret = snprintf_s(buf, bufSize, bufSize - 1U, "%s/%s", lp_path_name, entry->d_name); | |||||
| if (ret == EN_ERROR) { | |||||
| free(buf); | |||||
| buf = NULL; | |||||
| break; | |||||
| } | |||||
| childDir = opendir(buf); | |||||
| if (childDir != NULL) { | |||||
| (VOID)closedir(childDir); | |||||
| (VOID)mmRmdir(buf); | |||||
| free(buf); | |||||
| buf = NULL; | |||||
| continue; | |||||
| } else { | |||||
| ret = unlink(buf); | |||||
| if (ret == EN_OK) { | |||||
| free(buf); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| free(buf); | |||||
| buf = NULL; | |||||
| } | |||||
| (VOID)closedir(dir); | |||||
| ret = rmdir(lp_path_name); | |||||
| if (ret == EN_ERROR) { | |||||
| return EN_ERROR; | |||||
| } | |||||
| return EN_OK; | |||||
| } | |||||
| mmTimespec mmGetTickCount() { | mmTimespec mmGetTickCount() { | ||||
| mmTimespec rts; | mmTimespec rts; | ||||
| @@ -229,7 +288,14 @@ VOID mmScandirFree(mmDirent **entryList, INT32 count) | |||||
| INT32 mmAccess2(const CHAR *pathName, INT32 mode) | INT32 mmAccess2(const CHAR *pathName, INT32 mode) | ||||
| { | { | ||||
| return 0; | |||||
| if (pathName == NULL) { | |||||
| return EN_INVALID_PARAM; | |||||
| } | |||||
| INT32 ret = access(pathName, mode); | |||||
| if (ret != EN_OK) { | |||||
| return EN_ERROR; | |||||
| } | |||||
| return EN_OK; | |||||
| } | } | ||||
| INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone) | INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone) | ||||
| @@ -239,7 +305,15 @@ INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone) | |||||
| INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen) | INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen) | ||||
| { | { | ||||
| return 0; | |||||
| INT32 ret = EN_OK; | |||||
| if ((path == NULL) || (realPath == NULL) || (realPathLen < MMPA_MAX_PATH)) { | |||||
| return EN_INVALID_PARAM; | |||||
| } | |||||
| const CHAR *ptr = realpath(path, realPath); | |||||
| if (ptr == NULL) { | |||||
| ret = EN_ERROR; | |||||
| } | |||||
| return ret; | |||||
| } | } | ||||
| INT32 mmGetErrorCode() | INT32 mmGetErrorCode() | ||||
| @@ -275,7 +275,6 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
| "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
| @@ -106,7 +106,52 @@ void STestOnnxParser::RegisterCustomOp() { | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | domi::OpRegistry::Instance()->registrationDatas.clear(); | ||||
| } | } | ||||
| ge::onnx::GraphProto CreateOnnxGraph() { | |||||
| ge::onnx::GraphProto CreateOnnxGraph1() { | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string external_weight_file = case_dir + "/origin_models/file_constant_weight.bin"; | |||||
| ge::onnx::GraphProto onnx_graph; | |||||
| (void)onnx_graph.add_input(); | |||||
| (void)onnx_graph.add_output(); | |||||
| ::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||||
| node_const1->set_op_type(kOpTypeConstant); | |||||
| node_const2->set_op_type(kOpTypeConstant); | |||||
| node_add->set_op_type("Add"); | |||||
| node_add->set_domain("ai.onnx"); | |||||
| ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_type(OnnxDataType::UINT8); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
| ge::onnx::StringStringEntryProto *string_proto = tensor_proto->add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value(external_weight_file); | |||||
| ge::onnx::StringStringEntryProto *offset_proto = tensor_proto->add_external_data(); | |||||
| offset_proto->set_key("offset"); | |||||
| offset_proto->set_value("0"); | |||||
| ge::onnx::StringStringEntryProto *length_proto = tensor_proto->add_external_data(); | |||||
| length_proto->set_key("length"); | |||||
| length_proto->set_value("3"); | |||||
| tensor_proto->add_dims(3); | |||||
| attr = node_const2->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_type(OnnxDataType::UINT8); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||||
| tensor_proto->add_dims(3); | |||||
| size_t raw_data_size = 3; | |||||
| std::unique_ptr<uint8_t[]> raw_data(new (std::nothrow) uint8_t[raw_data_size / sizeof(uint8_t)]); | |||||
| tensor_proto->set_raw_data(reinterpret_cast<char *>(raw_data.get()), raw_data_size); | |||||
| return onnx_graph; | |||||
| } | |||||
| ge::onnx::GraphProto CreateOnnxGraph2() { | |||||
| ge::onnx::GraphProto onnx_graph; | ge::onnx::GraphProto onnx_graph; | ||||
| (void)onnx_graph.add_input(); | (void)onnx_graph.add_input(); | ||||
| (void)onnx_graph.add_output(); | (void)onnx_graph.add_output(); | ||||
| @@ -116,17 +161,66 @@ ge::onnx::GraphProto CreateOnnxGraph() { | |||||
| node_const1->set_op_type(kOpTypeConstant); | node_const1->set_op_type(kOpTypeConstant); | ||||
| node_const2->set_op_type(kOpTypeConstant); | node_const2->set_op_type(kOpTypeConstant); | ||||
| node_add->set_op_type("Add"); | node_add->set_op_type("Add"); | ||||
| node_add->set_domain("ai.onnx"); | |||||
| ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | ||||
| attr->set_name(ge::kAttrNameValue); | attr->set_name(ge::kAttrNameValue); | ||||
| ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | ||||
| tensor_proto->set_data_type(OnnxDataType::UINT8); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | ||||
| attr = node_const1->add_attribute(); | |||||
| tensor_proto->add_dims(3); | |||||
| attr = node_const2->add_attribute(); | attr = node_const2->add_attribute(); | ||||
| attr->set_name(ge::kAttrNameValue); | attr->set_name(ge::kAttrNameValue); | ||||
| tensor_proto = attr->mutable_t(); | tensor_proto = attr->mutable_t(); | ||||
| tensor_proto->set_data_type(OnnxDataType::UINT8); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | ||||
| tensor_proto->add_dims(3); | |||||
| size_t raw_data_size = 3; | |||||
| std::unique_ptr<uint8_t[]> raw_data(new (std::nothrow) uint8_t[raw_data_size / sizeof(uint8_t)]); | |||||
| tensor_proto->set_raw_data(reinterpret_cast<char *>(raw_data.get()), raw_data_size); | |||||
| return onnx_graph; | |||||
| } | |||||
| ge::onnx::GraphProto CreateOnnxGraph3() { | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string external_weight_file = case_dir + "/origin_models/file_constant_weight.bin"; | |||||
| ge::onnx::GraphProto onnx_graph; | |||||
| (void)onnx_graph.add_input(); | |||||
| (void)onnx_graph.add_output(); | |||||
| ::ge::onnx::NodeProto* node_const1 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_const2 = onnx_graph.add_node(); | |||||
| ::ge::onnx::NodeProto* node_add = onnx_graph.add_node(); | |||||
| node_const1->set_op_type(kOpTypeConstant); | |||||
| node_const2->set_op_type(kOpTypeConstant); | |||||
| node_add->set_op_type("Add"); | |||||
| node_add->set_domain("ai.onnx"); | |||||
| ::ge::onnx::AttributeProto* attr = node_const1->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| ::ge::onnx::TensorProto* tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_type(OnnxDataType::UINT8); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
| ge::onnx::StringStringEntryProto *string_proto = tensor_proto->add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value(external_weight_file); | |||||
| ge::onnx::StringStringEntryProto *offset_proto = tensor_proto->add_external_data(); | |||||
| offset_proto->set_key("offset"); | |||||
| offset_proto->set_value("9999999999999999999999999999999999"); | |||||
| tensor_proto->add_dims(3); | |||||
| attr = node_const2->add_attribute(); | |||||
| attr->set_name(ge::kAttrNameValue); | |||||
| tensor_proto = attr->mutable_t(); | |||||
| tensor_proto->set_data_type(OnnxDataType::UINT8); | |||||
| tensor_proto->set_data_location(ge::onnx::TensorProto_DataLocation_DEFAULT); | |||||
| tensor_proto->add_dims(3); | |||||
| size_t raw_data_size = 3; | |||||
| std::unique_ptr<uint8_t[]> raw_data(new (std::nothrow) uint8_t[raw_data_size / sizeof(uint8_t)]); | |||||
| tensor_proto->set_raw_data(reinterpret_cast<char *>(raw_data.get()), raw_data_size); | |||||
| return onnx_graph; | return onnx_graph; | ||||
| } | } | ||||
| @@ -212,15 +306,42 @@ TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
| } | } | ||||
| TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph) | |||||
| TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph_1) | |||||
| { | { | ||||
| OnnxModelParser modelParser; | |||||
| OnnxModelParser model_parser; | |||||
| model_parser.domain_verseion_["ai.onnx"] = 11; | |||||
| ge::onnx::ModelProto model_proto; | ge::onnx::ModelProto model_proto; | ||||
| auto onnx_graph = model_proto.mutable_graph(); | auto onnx_graph = model_proto.mutable_graph(); | ||||
| *onnx_graph = CreateOnnxGraph(); | |||||
| ge::Graph graph; | |||||
| *onnx_graph = CreateOnnxGraph1(); | |||||
| ge::Graph graph("graph"); | |||||
| Status ret = model_parser.ModelParseToGraph(model_proto, graph); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| } | |||||
| TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph_2) | |||||
| { | |||||
| OnnxModelParser model_parser; | |||||
| model_parser.domain_verseion_["ai.onnx"] = 11; | |||||
| ge::onnx::ModelProto model_proto; | |||||
| auto onnx_graph = model_proto.mutable_graph(); | |||||
| *onnx_graph = CreateOnnxGraph2(); | |||||
| ge::Graph graph("graph"); | |||||
| Status ret = model_parser.ModelParseToGraph(model_proto, graph); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| } | |||||
| TEST_F(STestOnnxParser, onnx_test_ModelParseToGraph_3) | |||||
| { | |||||
| OnnxModelParser model_parser; | |||||
| model_parser.domain_verseion_["ai.onnx"] = 11; | |||||
| ge::onnx::ModelProto model_proto; | |||||
| auto onnx_graph = model_proto.mutable_graph(); | |||||
| *onnx_graph = CreateOnnxGraph3(); | |||||
| ge::Graph graph("graph"); | |||||
| Status ret = modelParser.ModelParseToGraph(model_proto, graph); | |||||
| Status ret = model_parser.ModelParseToGraph(model_proto, graph); | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -274,7 +274,6 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | "${PARSER_DIR}/parser/common/thread_pool.cc" | ||||
| "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_file_constant_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
| @@ -17,10 +17,8 @@ | |||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "graph/operator_reg.h" | |||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "parser/common/op_registration_tbe.h" | |||||
| #include "external/parser/onnx_parser.h" | #include "external/parser/onnx_parser.h" | ||||
| #include "ut/parser/parser_ut_utils.h" | #include "ut/parser/parser_ut_utils.h" | ||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| @@ -30,7 +28,6 @@ | |||||
| #define protected public | #define protected public | ||||
| #define private public | #define private public | ||||
| #include "parser/onnx/onnx_constant_parser.h" | #include "parser/onnx/onnx_constant_parser.h" | ||||
| #include "parser/onnx/onnx_file_constant_parser.h" | |||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| #include "parser/onnx/onnx_parser.h" | #include "parser/onnx/onnx_parser.h" | ||||
| #undef protected | #undef protected | ||||
| @@ -399,188 +396,109 @@ TEST_F(UtestOnnxParser, OnnxConstantParser_ParseConvertDataType_test) | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| } | } | ||||
| TEST_F(UtestOnnxParser, FileConstantGetTensorProto) | |||||
| TEST_F(UtestOnnxParser, OnnxConstantParser_ParseExternalWeight_test) | |||||
| { | { | ||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::NodeProto input_node; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| Status ret = parser.GetTensorProto(input_node, tensor_proto); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||||
| attribute->set_name("attribute"); | |||||
| attribute = input_node.add_attribute(); | |||||
| attribute->set_name("value"); | |||||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||||
| *attribute_tensor = tensor_proto; | |||||
| ret = parser.GetTensorProto(input_node, tensor_proto); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParseShape) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| tensor_proto.add_dims(4); | |||||
| tensor_proto.add_dims(2); | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| parser.ParseShape(tensor_proto, op); | |||||
| std::vector<int64_t> attr_value; | |||||
| op.GetAttr("shape", attr_value); | |||||
| EXPECT_EQ(attr_value.size(), 2U); | |||||
| if (attr_value.size() == 2U) { | |||||
| EXPECT_EQ(attr_value[0], 4); | |||||
| EXPECT_EQ(attr_value[1], 2); | |||||
| } | |||||
| } | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string external_weight_file = case_dir + "/onnx_model/file_constant_weight.bin"; | |||||
| TEST_F(UtestOnnxParser, FileConstantParseDataType) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| OnnxConstantParser constant_parser; | |||||
| ge::onnx::TensorProto tensor_proto; | ge::onnx::TensorProto tensor_proto; | ||||
| tensor_proto.set_data_type(OnnxDataType::UNDEFINED); | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| Status ret = parser.ParseDataType(tensor_proto, op); | |||||
| // without location, error | |||||
| ge::Tensor tensor ; | |||||
| TensorDesc tensor_desc = tensor.GetTensorDesc(); | |||||
| tensor_desc.SetDataType(ge::DataType::DT_UINT8); | |||||
| tensor_desc.SetShape(ge::Shape({3})); | |||||
| tensor.SetTensorDesc(tensor_desc); | |||||
| auto ret = constant_parser.ParseExternalWeight(tensor_proto, tensor); | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| tensor_proto.set_data_type(OnnxDataType::UINT8); | |||||
| ret = parser.ParseDataType(tensor_proto, op); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| ge::DataType attr_value; | |||||
| op.GetAttr("dtype", attr_value); | |||||
| EXPECT_EQ(attr_value, ge::DataType::DT_UINT8); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParseAttr) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::StringStringEntryProto string_proto; | |||||
| ge::NamedAttrs attrs; | |||||
| // test location | |||||
| string_proto.set_key("location"); | |||||
| string_proto.set_value("/usr/local"); | |||||
| Status ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| std::string attr_value; | |||||
| AttrUtils::GetStr(attrs, "location", attr_value); | |||||
| EXPECT_EQ(attr_value, "/usr/local"); | |||||
| // test offset | |||||
| string_proto.set_key("offset"); | |||||
| string_proto.set_value("123"); | |||||
| ret = parser.SetPathAttr(string_proto, attrs); | |||||
| // test location, success | |||||
| ge::onnx::StringStringEntryProto *string_proto = tensor_proto.add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value(external_weight_file); | |||||
| ret = constant_parser.ParseExternalWeight(tensor_proto, tensor); | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| int64_t offset_value; | |||||
| AttrUtils::GetInt(attrs, "offset", offset_value); | |||||
| EXPECT_EQ(offset_value, 123 * 4096); | |||||
| // offset overflow | |||||
| string_proto.set_key("offset"); | |||||
| string_proto.set_value("9223372036854775800"); | |||||
| ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // itol exception | |||||
| string_proto.set_key("offset"); | |||||
| string_proto.set_value("999999999999999999999999999999999999"); | |||||
| ret = parser.SetPathAttr(string_proto, attrs); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, FileConstantParsePath) | |||||
| { | |||||
| OnnxFileConstantParser parser; | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
| // without location, error | |||||
| auto ret = parser.ParsePath(tensor_proto, op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| // SetPathAttr error | |||||
| // test offset, overflow | |||||
| ge::onnx::StringStringEntryProto *offset_proto = tensor_proto.add_external_data(); | ge::onnx::StringStringEntryProto *offset_proto = tensor_proto.add_external_data(); | ||||
| offset_proto->set_key("offset"); | offset_proto->set_key("offset"); | ||||
| offset_proto->set_value("999999999999999999999999999999"); | |||||
| ret = parser.ParsePath(tensor_proto, op); | |||||
| offset_proto->set_value("9999999999999999999999999999999999"); | |||||
| ret = constant_parser.ParseExternalWeight(tensor_proto, tensor); | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| // has location, success | |||||
| ge::onnx::StringStringEntryProto *string_proto = tensor_proto.add_external_data(); | |||||
| string_proto->set_key("location"); | |||||
| string_proto->set_value("/usr/local"); | |||||
| // test tensor data | |||||
| offset_proto->set_key("offset"); | offset_proto->set_key("offset"); | ||||
| offset_proto->set_value("0"); | offset_proto->set_value("0"); | ||||
| ret = parser.ParsePath(tensor_proto, op); | |||||
| ge::onnx::StringStringEntryProto *length_proto = tensor_proto.add_external_data(); | |||||
| length_proto->set_key("length"); | |||||
| length_proto->set_value("3"); | |||||
| ret = constant_parser.ParseExternalWeight(tensor_proto, tensor); | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| // check location | |||||
| std::string attr_value; | |||||
| ge::NamedAttrs attrs; | |||||
| AttrUtils::GetNamedAttrs(op_desc_src, "file_constant_path", attrs); | |||||
| AttrUtils::GetStr(attrs, "location", attr_value); | |||||
| EXPECT_EQ(attr_value, "/usr/local"); | |||||
| auto tensor_size = tensor.GetSize(); | |||||
| EXPECT_EQ(tensor_size, 3); | |||||
| auto tensor_data0 = tensor.GetData()[0]; | |||||
| EXPECT_EQ(tensor_data0, static_cast<uint8_t>(0)); | |||||
| auto tensor_data1 = tensor.GetData()[1]; | |||||
| EXPECT_EQ(tensor_data1, static_cast<uint8_t>(1)); | |||||
| auto tensor_data2 = tensor.GetData()[2]; | |||||
| EXPECT_EQ(tensor_data2, static_cast<uint8_t>(2)); | |||||
| } | } | ||||
| TEST_F(UtestOnnxParser, FileConstantParseParam) | |||||
| TEST_F(UtestOnnxParser, OnnxConstantParser_ParseParams_test) | |||||
| { | { | ||||
| OnnxFileConstantParser parser; | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string external_weight_file = case_dir + "/onnx_model/file_constant_weight.bin"; | |||||
| OnnxConstantParser constant_parser; | |||||
| ge::onnx::NodeProto input_node; | ge::onnx::NodeProto input_node; | ||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("file_constant", "FileConstant"); | |||||
| ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Constant", "const.onnx"); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | ||||
| // get tensor proto failed | |||||
| auto ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, FAILED); | |||||
| ge::onnx::TensorProto tensor_proto; | |||||
| ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | ||||
| attribute->set_name("value"); | attribute->set_name("value"); | ||||
| ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | ||||
| *attribute_tensor = tensor_proto; | |||||
| // parse data type failed | // parse data type failed | ||||
| attribute_tensor->set_data_type(OnnxDataType::UNDEFINED); | attribute_tensor->set_data_type(OnnxDataType::UNDEFINED); | ||||
| ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| auto ret = constant_parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| // parse path failed | |||||
| attribute_tensor->set_data_type(OnnxDataType::UINT16); | |||||
| ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| // without location, error | |||||
| attribute_tensor->set_data_type(OnnxDataType::UINT8); | |||||
| attribute_tensor->set_data_location(ge::onnx::TensorProto_DataLocation_EXTERNAL); | |||||
| ret = constant_parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
| // success | // success | ||||
| ge::onnx::StringStringEntryProto *string_proto = attribute_tensor->add_external_data(); | ge::onnx::StringStringEntryProto *string_proto = attribute_tensor->add_external_data(); | ||||
| string_proto->set_key("location"); | string_proto->set_key("location"); | ||||
| string_proto->set_value("/usr/local"); | |||||
| attribute_tensor->add_dims(4); | |||||
| ret = parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| string_proto->set_value(external_weight_file); | |||||
| attribute_tensor->add_dims(3); | |||||
| ret = constant_parser.ParseParams(reinterpret_cast<Message *>(&input_node), op); | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| // check location, shape, dtype | |||||
| NamedAttrs attrs; | |||||
| AttrUtils::GetNamedAttrs(*op_desc_src, "file_constant_path", attrs); | |||||
| std::string file_path; | |||||
| AttrUtils::GetStr(attrs, "location", file_path); | |||||
| EXPECT_EQ(file_path, "/usr/local"); | |||||
| // check tensor value | |||||
| ge::Tensor tensor; | |||||
| op.GetAttr("value", tensor); | |||||
| auto tensor_size = tensor.GetSize(); | |||||
| EXPECT_EQ(tensor_size, 3); | |||||
| auto tensor_data0 = tensor.GetData()[0]; | |||||
| EXPECT_EQ(tensor_data0, static_cast<uint8_t>(0)); | |||||
| auto tensor_data1 = tensor.GetData()[1]; | |||||
| EXPECT_EQ(tensor_data1, static_cast<uint8_t>(1)); | |||||
| auto tensor_data2 = tensor.GetData()[2]; | |||||
| EXPECT_EQ(tensor_data2, static_cast<uint8_t>(2)); | |||||
| // check shape, dtype | |||||
| std::vector<int64_t> dims; | std::vector<int64_t> dims; | ||||
| op.GetAttr("shape", dims); | |||||
| EXPECT_EQ(dims.size(), 1); | |||||
| if (!dims.empty()) { | |||||
| EXPECT_EQ(dims[0], 4); | |||||
| } | |||||
| DataType dtype; | |||||
| op.GetAttr("dtype", dtype); | |||||
| EXPECT_EQ(dtype, ge::DataType::DT_UINT16); | |||||
| dims = tensor.GetTensorDesc().GetShape().GetDims(); | |||||
| ASSERT_EQ(dims.size(), 1); | |||||
| EXPECT_EQ(dims[0], 3); | |||||
| DataType dtype = tensor.GetTensorDesc().GetDataType(); | |||||
| EXPECT_EQ(dtype, ge::DataType::DT_UINT8); | |||||
| } | } | ||||
| TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | ||||
| @@ -598,16 +516,6 @@ TEST_F(UtestOnnxParser, OnnxModelParser_ParseInput_test) | |||||
| EXPECT_EQ(ret, domi::FAILED); | EXPECT_EQ(ret, domi::FAILED); | ||||
| } | } | ||||
| TEST_F(UtestOnnxParser, OnnxModelParser_ParseConstant_test) | |||||
| { | |||||
| OnnxModelParser model_parser; | |||||
| ge::onnx::GraphProto onnx_graph = CreateOnnxGraph(); | |||||
| model_parser.UpdateNodeNameAndOpType(onnx_graph); | |||||
| std::string type = onnx_graph.mutable_node(0)->op_type(); | |||||
| EXPECT_EQ(type, kFileConstant); | |||||
| } | |||||
| TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | TEST_F(UtestOnnxParser, onnx_test_ConstructOriType) | ||||
| { | { | ||||
| ge::onnx::ModelProto model_proto; | ge::onnx::ModelProto model_proto; | ||||