| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | |||||
| #define INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/graph.h" | |||||
| namespace ge { | |||||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | |||||
| #define INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | |||||
| #include <atomic> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/graph.h" | |||||
| namespace ge { | |||||
| graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | |||||
| @@ -0,0 +1,150 @@ | |||||
| set(PROTO_LIST | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/graph_library.proto" | |||||
| ) | |||||
| set(SRC_LIST | |||||
| "tensorflow/tensorflow_arg_parser.cc" | |||||
| "tensorflow/tensorflow_auto_mapping_parser_adapter.cc" | |||||
| "tensorflow/tensorflow_constant_parser.cc" | |||||
| "tensorflow/tensorflow_data_parser.cc" | |||||
| "tensorflow/tensorflow_enter_parser.cc" | |||||
| "tensorflow/tensorflow_fill_parser.cc" | |||||
| "tensorflow/tensorflow_frameworkop_parser.cc" | |||||
| "tensorflow/tensorflow_fusionop_util.cc" | |||||
| "tensorflow/tensorflow_identity_parser.cc" | |||||
| "tensorflow/tensorflow_merge_parser.cc" | |||||
| "tensorflow/tensorflow_no_op_parser.cc" | |||||
| "tensorflow/tensorflow_parser.cc" | |||||
| "tensorflow/tensorflow_ref_switch_parser.cc" | |||||
| "tensorflow/tensorflow_reshape_parser.cc" | |||||
| "tensorflow/tensorflow_shape_n_parser.cc" | |||||
| "tensorflow/tensorflow_squeeze_parser.cc" | |||||
| "tensorflow/tensorflow_var_is_initialized_op_parser.cc" | |||||
| "tensorflow/tensorflow_variable_v2_parser.cc" | |||||
| "caffe/caffe_parser.cc" | |||||
| "caffe/caffe_data_parser.cc" | |||||
| "caffe/caffe_reshape_parser.cc" | |||||
| "caffe/caffe_custom_parser_adapter.cc" | |||||
| "caffe/caffe_op_parser.cc" | |||||
| "tensorflow/scope/scope_pass_manager.cc" | |||||
| "tensorflow/graph_functiondef.cc" | |||||
| "tensorflow/graph_optimizer.cc" | |||||
| "tensorflow/iterator_fusion_pass.cc" | |||||
| "common/op_def/arg_op.cc" | |||||
| "common/op_def/constant_op.cc" | |||||
| "common/op_def/fill_op.cc" | |||||
| "common/op_def/frameworkop_op.cc" | |||||
| "common/op_def/no_op_op.cc" | |||||
| "common/op_def/ref_switch_op.cc" | |||||
| "common/op_def/shape_n_op.cc" | |||||
| "common/op_def/var_is_initialized_op_op.cc" | |||||
| "common/op_def/variable_op.cc" | |||||
| ) | |||||
| protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| ############ libfmk_parser.so ############ | |||||
| add_library(fmk_parser SHARED ${SRC_LIST} ${PROTO_SRCS}) | |||||
| target_compile_options(fmk_parser PRIVATE | |||||
| -Werror | |||||
| ) | |||||
| target_compile_definitions(fmk_parser PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| ) | |||||
| target_include_directories(fmk_parser PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${TOP_DIR}/framework/domi | |||||
| ${TOP_DIR}/framework/domi/common | |||||
| ${TOP_DIR}/framework/domi/parser | |||||
| ${TOP_DIR}/inc | |||||
| ${TOP_DIR}/inc/external | |||||
| ${TOP_DIR}/inc/external/parser | |||||
| ${TOP_DIR}/inc/external/graph | |||||
| ${TOP_DIR}/inc/framework | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ) | |||||
| target_link_libraries(fmk_parser | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| protobuf | |||||
| error_manager | |||||
| parser_common | |||||
| graph | |||||
| register | |||||
| _caffe_parser | |||||
| c_sec | |||||
| slog | |||||
| mmpa | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| ) | |||||
| ################################################################## | |||||
| add_custom_command( | |||||
| OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_tensorflow_parser.cc | |||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_caffe_parser.cc | |||||
| COMMAND echo "Generating stub files." | |||||
| && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} | |||||
| && mv tensorflow_parser.cc stub_tensorflow_parser.cc | |||||
| && mv caffe_parser.cc stub_caffe_parser.cc | |||||
| && echo "Generating stub files end." | |||||
| WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} | |||||
| DEPENDS ../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} | |||||
| ) | |||||
| ################################################################## | |||||
| ############ stub/libfmk_parser.so ############ | |||||
| add_library(fmk_parser_stub SHARED | |||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_tensorflow_parser.cc | |||||
| ${CMAKE_CURRENT_BINARY_DIR}/stub_caffe_parser.cc | |||||
| ) | |||||
| target_compile_options(fmk_parser_stub PRIVATE | |||||
| -O2 | |||||
| ) | |||||
| target_compile_definitions(fmk_parser_stub PRIVATE | |||||
| $<$<STREQUAL:${PRODUCT_SIDE},host>:FMK_SUPPORT_DUMP> | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| REUSE_MEMORY=1 | |||||
| FMK_HOST_INFER | |||||
| ) | |||||
| target_include_directories(fmk_parser_stub PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${TOP_DIR}/inc | |||||
| ${TOP_DIR}/inc/external | |||||
| ${TOP_DIR}/inc/external/parser | |||||
| ${TOP_DIR}/inc/external/graph | |||||
| ${TOP_DIR}/inc/framework | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_CURRENT_BINARY_DIR} | |||||
| ) | |||||
| target_link_libraries(fmk_parser_stub PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ) | |||||
| set_target_properties(fmk_parser_stub PROPERTIES | |||||
| OUTPUT_NAME fmk_parser | |||||
| LIBRARY_OUTPUT_DIRECTORY stub | |||||
| ) | |||||
| ############ install ############ | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| install(TARGETS fmk_parser OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||||
| ) | |||||
| install(TARGETS fmk_parser_stub OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub | |||||
| ) | |||||
| @@ -0,0 +1,144 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/caffe/caffe_custom_parser_adapter.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "register/op_registry.h" | |||||
| using domi::ParseParamByOpFunc; | |||||
| using domi::ParseParamFunc; | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| namespace { | |||||
| const char *const kConvolution = "Convolution"; | |||||
| const char *const kInnerProduct = "InnerProduct"; | |||||
| const int64_t kDimDedaultValue = 1; | |||||
| const int kBlobIndexOne = 1; | |||||
| } // namespace | |||||
| Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| const LayerParameter *layer = reinterpret_cast<const LayerParameter *>(op_src); | |||||
| GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | |||||
| GE_CHECK_NOTNULL(op_dest); | |||||
| ParseParamFunc customOpParser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest->GetType(), layer->type()); | |||||
| GE_CHECK_NOTNULL(customOpParser); | |||||
| op_dest->SetName(layer->name()); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); | |||||
| GE_CHK_BOOL_RET_STATUS(customOpParser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) { | |||||
| GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", op_src.GetName().c_str(), | |||||
| op_src.GetOpType().c_str()); | |||||
| GE_CHECK_NOTNULL(op_dest); | |||||
| ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); | |||||
| GE_CHECK_NOTNULL(custom_op_parser); | |||||
| op_dest->SetName(op_src.GetName()); | |||||
| ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); | |||||
| GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| GE_CHECK_NOTNULL(op); | |||||
| const LayerParameter *layer = reinterpret_cast<const LayerParameter *>(op_src); | |||||
| GE_CHK_BOOL_RET_STATUS(nullptr != layer, FAILED, "Dynamic cast op_src to LayerParameter failed"); | |||||
| GELOGI("layer: %s blobs_size: %d bottom_size: %d", layer->name().c_str(), layer->blobs_size(), layer->bottom_size()); | |||||
| if (layer->blobs_size() == 0) { | |||||
| return SUCCESS; | |||||
| } | |||||
| bool bias_en = false; | |||||
| int start_pos = layer->bottom_size(); | |||||
| for (int i = 0; i < layer->blobs_size(); ++i) { | |||||
| ge::GeTensorPtr weight = ge::MakeShared<ge::GeTensor>(); | |||||
| GE_CHECK_NOTNULL(weight); | |||||
| GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i, | |||||
| layer->name().c_str()); | |||||
| GE_IF_BOOL_EXEC(layer->type() == kConvolution && i == kBlobIndexOne, | |||||
| const ConvolutionParameter &conv_params_src = layer->convolution_param(); | |||||
| bias_en = conv_params_src.bias_term();); | |||||
| GE_IF_BOOL_EXEC(layer->type() == kInnerProduct && i == kBlobIndexOne, | |||||
| const InnerProductParameter &fc_params_src = layer->inner_product_param(); | |||||
| bias_en = fc_params_src.bias_term();); | |||||
| auto bias_shape = weight->MutableTensorDesc().GetShape(); | |||||
| // The num 0, 1, 2, 3 represet the dim index. | |||||
| bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) && | |||||
| bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1; | |||||
| if (matched) { | |||||
| weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)})); | |||||
| } | |||||
| matched = layer->type() == kInnerProduct && i == 0 && | |||||
| bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) && | |||||
| bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1; | |||||
| if (matched) { | |||||
| weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)})); | |||||
| } | |||||
| // construct const node | |||||
| auto const_opdesc = ge::OpDescUtils::CreateConstOp(weight); // use org weight before SetWeights Overwrite | |||||
| GE_CHECK_NOTNULL(const_opdesc); | |||||
| auto owner_graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(owner_graph); | |||||
| // add edge from const to current node | |||||
| auto const_node = owner_graph->AddNodeFront(const_opdesc); | |||||
| GE_CHECK_NOTNULL(const_node); | |||||
| auto index = start_pos + i; | |||||
| auto valid_input_name = op->GetValidInputNameByIndex(static_cast<uint32_t>(index)); | |||||
| if (valid_input_name.empty()) { | |||||
| if (node->AddLinkFrom(static_cast<const uint32_t &>(index), const_node) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "AddEdge failed of from Node %s output to Node %s input %d", const_node->GetName().c_str(), | |||||
| node->GetName().c_str(), index); | |||||
| } | |||||
| } else { | |||||
| if (node->AddLinkFrom(valid_input_name, const_node) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "AddEdge failed of from Node %s output to Node %s input %s", const_node->GetName().c_str(), | |||||
| node->GetName().c_str(), valid_input_name.c_str()); | |||||
| } | |||||
| } | |||||
| std::vector<ge::NodePtr> original_nodes; | |||||
| ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); | |||||
| } | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, "tvm_origin_input_num", layer->bottom_size())), | |||||
| GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(CAFFE, CaffeCustomParserAdapter); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,60 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_ | |||||
| #define PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_ | |||||
| #include "parser/caffe/caffe_op_parser.h" | |||||
| namespace ge { | |||||
| class CaffeCustomParserAdapter : public CaffeOpParser { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parse params of the operation | |||||
| * @param [in] op_src params to be parsed | |||||
| * @param [out] op_dest params after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| * @author | |||||
| */ | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parse params of the operation | |||||
| * @param [in] op_src params to be parsed | |||||
| * @param [out] op_dest params after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| * @author | |||||
| */ | |||||
| Status ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parse weight of the operation | |||||
| * @param [in] op_src params to be parsed | |||||
| * @param [out] node params after parsing | |||||
| * @return SUCCESS parse successfullyparse failed | |||||
| * @return FAILED | |||||
| * @author | |||||
| */ | |||||
| Status ParseWeights(const Message *op_src, ge::NodePtr &node) override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_ | |||||
| @@ -0,0 +1,160 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/caffe/caffe_data_parser.h" | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include "common/debug/log.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| using namespace ge::parser; | |||||
| namespace ge { | |||||
| Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector<int64_t> &input_dims, | |||||
| ge::OpDescPtr &op) { | |||||
| GE_CHECK_NOTNULL(op); | |||||
| GELOGI("The input dim size is %zu in layer %s.", input_dims.size(), name.c_str()); | |||||
| // Caffe default data type is float32 | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, DATA_ATTR_NAME_DATA_TYPE, ge::DT_FLOAT)), | |||||
| GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| // Initialize input and output description of OP according to input_dims information | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseShape(input_dims, op), "data layer %s ParseShape failed", name.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| GE_CHECK_NOTNULL(op); | |||||
| const domi::caffe::LayerParameter *layer = DOMI_DYNAMIC_CAST<const domi::caffe::LayerParameter *>(op_src); | |||||
| GE_CHECK_NOTNULL(layer); | |||||
| GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | |||||
| if (layer->type() == ge::parser::INPUT_TYPE) { | |||||
| GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", | |||||
| layer->name().c_str(), layer->type().c_str()); | |||||
| } else if(layer->type() == ge::parser::DUMMY_DATA) { | |||||
| GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", | |||||
| layer->name().c_str(), layer->type().c_str()); | |||||
| } else { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11030"); | |||||
| GELOGE(PARAM_INVALID, "Caffe prototxt has no optype [Input]"); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) { | |||||
| if (layer->has_input_param()) { | |||||
| const domi::caffe::InputParameter &input_param = layer->input_param(); | |||||
| if (input_param.shape_size() == 0) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11027", {"layername", "layertype"}, {layer->name(), layer->type()}); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "input_param shape size is zero, caffe layer name [%s], layer type [%s].", | |||||
| layer->name().c_str(), layer->type().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| for (int i = 0; i < input_param.shape_size(); i++) { | |||||
| const domi::caffe::BlobShape &blob_shape = input_param.shape(i); | |||||
| vector<int64_t> shape; | |||||
| unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims; | |||||
| std::vector<int64_t> model_dims; | |||||
| for (auto &blob_shape_dim_temp : blob_shape.dim()) { | |||||
| model_dims.push_back(blob_shape_dim_temp); | |||||
| } | |||||
| string name = layer->name(); | |||||
| GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", | |||||
| name.c_str()); | |||||
| } | |||||
| } else { | |||||
| // Get from external input | |||||
| const ge::ParserContext &ctx = GetParserContext(); | |||||
| std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||||
| string name = layer->name(); | |||||
| auto search = input_dims.find(name); | |||||
| if (search == input_dims.end()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11028", {"layername", "layertype"}, {layer->name(), layer->type()}); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Caffe prototxt has no input_param or user should set --input_shape in atc parameter, " | |||||
| "caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::vector<int64_t> dims = search->second; | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), "Get output desc failed in layer %s.", | |||||
| name.c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) { | |||||
| if (layer->has_dummy_data_param()) { | |||||
| const domi::caffe::DummyDataParameter &dummy_data_param = layer->dummy_data_param(); | |||||
| if (dummy_data_param.shape_size() == 0) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11027", {"layername", "layertype"}, {layer->name(), layer->type()}); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "input_param shape size is zero, caffe layer name [%s], layer type [%s].", | |||||
| layer->name().c_str(), layer->type().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| for (int i = 0; i < dummy_data_param.shape_size(); i++) { | |||||
| const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i); | |||||
| vector<int64_t> shape; | |||||
| unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims; | |||||
| std::vector<int64_t> model_dims; | |||||
| for (auto &blob_shape_dim_temp : blob_shape.dim()) { | |||||
| model_dims.push_back(blob_shape_dim_temp); | |||||
| } | |||||
| string name = layer->name(); | |||||
| GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", | |||||
| name.c_str()); | |||||
| } | |||||
| } else { | |||||
| // Get from external input | |||||
| const ge::ParserContext &ctx = GetParserContext(); | |||||
| std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | |||||
| string name = layer->name(); | |||||
| auto search = input_dims.find(name); | |||||
| if (search == input_dims.end()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11028", {"layername", "layertype"}, {layer->name(), layer->type()}); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Caffe prototxt has no input_param or user should set --input_shape in atc parameter, " | |||||
| "caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::vector<int64_t> dims = search->second; | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), "Get output desc failed in layer %s.", | |||||
| name.c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_OP_PARSER_CREATOR(CAFFE, DATA, CaffeDataParser); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_CAFFE_CAFFE_DATA_PARSER_H_ | |||||
| #define PARSER_CAFFE_CAFFE_DATA_PARSER_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "parser/caffe/caffe_op_parser.h" | |||||
| #include "parser/common/data_op_parser.h" | |||||
| namespace ge { | |||||
| class CaffeDataParser : public CaffeOpParser, public DataOpParser { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parse params of the operation | |||||
| * @param [in] op_src params to be parsed | |||||
| * @param [out] graph params after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||||
| private: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Get the output dimension according to the input dimension | |||||
| * @param [in] name the name of the input layer | |||||
| * @param [in] input_dims the dimension of the input layer | |||||
| * @param [out] op_def op after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status GetOutputDesc(const std::string &name, int dim_size, | |||||
| const std::vector<int64_t> &input_dims, ge::OpDescPtr &op); | |||||
| // caffe data layer type could be type of `Input` or `DummyData` | |||||
| Status ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); | |||||
| Status ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_CAFFE_CAFFE_DATA_PARSER_H_ | |||||
| @@ -0,0 +1,187 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/caffe/caffe_op_parser.h" | |||||
| #include <memory> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| using namespace ge::parser; | |||||
| using domi::CAFFE; | |||||
| namespace ge { | |||||
| Status CaffeOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { return SUCCESS; } | |||||
| Status CaffeOpParser::ParseWeights(const Message *op_src, ge::NodePtr &node) { return SUCCESS; } | |||||
| Status CaffeOpParser::AddConstInput(ge::NodePtr &node) { return SUCCESS; } | |||||
| void CaffeOpParser::ConvertShape(const BlobProto &proto, std::vector<int64_t> &shape) { | |||||
| shape.clear(); | |||||
| if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { | |||||
| // Compatible with old formats, shape description: (num, channels, height, width) | |||||
| shape.push_back(proto.num()); | |||||
| shape.push_back(proto.channels()); | |||||
| shape.push_back(proto.height()); | |||||
| shape.push_back(proto.width()); | |||||
| } else { | |||||
| // The shape of the new format is described with "repeated Int64 dim" | |||||
| for (int i = 0; i < proto.shape().dim_size(); ++i) { | |||||
| shape.push_back(proto.shape().dim(i)); | |||||
| } | |||||
| } | |||||
| } | |||||
| Status CaffeOpParser::ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight) { | |||||
| GE_CHECK_NOTNULL(weight); | |||||
| std::vector<int64_t> shape_vec; | |||||
| ConvertShape(proto, shape_vec); | |||||
| ge::GeShape shape(shape_vec); | |||||
| // Calculate the number of data in weight | |||||
| int count = 1; | |||||
| for (size_t i = 0; i < shape.GetDimNum(); ++i) { | |||||
| int dim = shape.GetDim(i); | |||||
| if (dim <= 0) { | |||||
| GELOGE(FAILED, "Convert weight fail, Blob size invalid"); | |||||
| return FAILED; | |||||
| } | |||||
| if (dim >= INT64_MAX / count) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11033", {"opname", "blobsize", "reason"}, | |||||
| {lay_name, std::to_string(dim) + "*" + std::to_string(count), | |||||
| "it exceeds INT64_MAX[" + std::to_string(INT64_MAX) + "]"}); | |||||
| GELOGE(FAILED, "Convert weight fail, Blob size exceeds INT64_MAX, dim:%d, count:%d", dim, count); | |||||
| return FAILED; | |||||
| } | |||||
| count *= dim; | |||||
| } | |||||
| return ParseWeightType(proto, shape, count, lay_name, weight); | |||||
| } | |||||
| Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, int size, | |||||
| const string &lay_name, ge::GeTensorPtr &weight) { | |||||
| // Extract weight data and store it in weightdef by float type | |||||
| GE_CHECK_NOTNULL(weight); | |||||
| ge::DataType dtype = ge::DT_FLOAT; | |||||
| if (proto.double_data_size() > 0) { | |||||
| // Convert by double type | |||||
| if (size != proto.double_data_size()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11033", {"opname", "blobsize", "reason"}, | |||||
| {lay_name, std::to_string(proto.double_data_size()), | |||||
| "it does not match shape size[" + std::to_string(size) + "]"}); | |||||
| GELOGE(FAILED, "Convert weight fail, Blob size does not match shape size, shape size:%d, blob size:%d", size, | |||||
| proto.double_data_size()); | |||||
| return FAILED; | |||||
| } | |||||
| std::unique_ptr<float[]> buf(new (std::nothrow) float[size]()); | |||||
| GE_CHECK_NOTNULL(buf); | |||||
| for (int i = 0; i < size; ++i) { | |||||
| buf[i] = proto.double_data(i); | |||||
| } | |||||
| GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
| GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
| } else if (proto.int8_data().length() > 0) { | |||||
| if (size != static_cast<int>(proto.int8_data().length())) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11033", {"opname", "blobsize", "reason"}, | |||||
| {lay_name, std::to_string(proto.int8_data().length()), | |||||
| "it does not match shape size[" + std::to_string(size) + "]"}); | |||||
| GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%ld", size, | |||||
| proto.int8_data().length()); | |||||
| return FAILED; | |||||
| } | |||||
| const char *data_ptr = proto.int8_data().data(); | |||||
| GE_CHECK_NOTNULL(data_ptr); | |||||
| GE_IF_BOOL_EXEC( | |||||
| weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, | |||||
| GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
| dtype = ge::DT_INT8; | |||||
| } else if (proto.int32_data_size() > 0) { | |||||
| if (size != proto.int32_data_size()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11033", {"opname", "blobsize", "reason"}, | |||||
| {lay_name, std::to_string(proto.int32_data_size()), | |||||
| "it does not match shape size[" + std::to_string(size) + "]"}); | |||||
| GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%d", size, | |||||
| proto.int32_data_size()); | |||||
| return FAILED; | |||||
| } | |||||
| std::unique_ptr<int32_t[]> int32_weight_buf(new (std::nothrow) int32_t[size]()); | |||||
| GE_CHECK_NOTNULL(int32_weight_buf); | |||||
| for (int i = 0; i < size; ++i) { | |||||
| int32_weight_buf[i] = proto.int32_data(i); | |||||
| } | |||||
| GE_IF_BOOL_EXEC( | |||||
| weight->SetData(reinterpret_cast<uint8_t *>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, | |||||
| GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
| dtype = ge::DT_INT32; | |||||
| } else if (proto.uint64_data_size() > 0) { | |||||
| if (size != proto.uint64_data_size()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11033", {"opname", "blobsize", "reason"}, | |||||
| {lay_name, std::to_string(proto.uint64_data_size()), | |||||
| "it does not match shape size[" + std::to_string(size) + "]"}); | |||||
| GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%d", size, | |||||
| proto.uint64_data_size()); | |||||
| return FAILED; | |||||
| } | |||||
| std::unique_ptr<uint64_t[]> uint64_weight_buf(new (std::nothrow) uint64_t[size]()); | |||||
| GE_CHECK_NOTNULL(uint64_weight_buf); | |||||
| for (int i = 0; i < size; ++i) { | |||||
| uint64_weight_buf[i] = proto.uint64_data(i); | |||||
| } | |||||
| GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | |||||
| ge::GRAPH_SUCCESS, | |||||
| GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
| dtype = ge::DT_UINT64; | |||||
| } else { | |||||
| // Convert by float type | |||||
| if (size != proto.data_size()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11033", {"opname", "blobsize", "reason"}, | |||||
| {lay_name, std::to_string(proto.data_size()), | |||||
| "it does not match shape size[" + std::to_string(size) + "]"}); | |||||
| GELOGE(FAILED, "Convert weight fail, Blob size does not match shape size, shape size:%d, blob.data_size:%d", size, | |||||
| proto.data_size()); | |||||
| return FAILED; | |||||
| } | |||||
| const float *data_ptr = proto.data().data(); | |||||
| GE_CHECK_NOTNULL(data_ptr); | |||||
| GE_IF_BOOL_EXEC( | |||||
| weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
| GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
| } | |||||
| ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); | |||||
| weight_desc.Update(shape, ge::FORMAT_NCHW, dtype); | |||||
| weight->SetTensorDesc(weight_desc); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Dropout's corresponding op_parser is registered as caffeopparser, optimized in optimization stage. | |||||
| REGISTER_OP_PARSER_CREATOR(CAFFE, DROPOUT, CaffeOpParser); | |||||
| // A new operator added by framework in OM model is used to | |||||
| // collect and arrange all outputs in the order of the original model's output | |||||
| // Net output operator does not need special processing in the parse stage, | |||||
| // and directly registers in the op_parser file | |||||
| REGISTER_OP_PARSER_CREATOR(CAFFE, NETOUTPUT, CaffeOpParser); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,120 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_CAFFE_CAFFE_OP_PARSER_H_ | |||||
| #define PARSER_CAFFE_CAFFE_OP_PARSER_H_ | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "common/util.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/operator.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "omg/parser/op_parser.h" | |||||
| #include "proto/caffe/caffe.pb.h" | |||||
| using domi::caffe::ArgMaxParameter; | |||||
| using domi::caffe::BatchNormParameter; | |||||
| using domi::caffe::BlobProto; | |||||
| using domi::caffe::BlobShape; | |||||
| using domi::caffe::ConcatParameter; | |||||
| using domi::caffe::ConvolutionParameter; | |||||
| using domi::caffe::DetectionOutputParameter; | |||||
| using domi::caffe::EltwiseParameter; | |||||
| using domi::caffe::FillerParameter; | |||||
| using domi::caffe::InnerProductParameter; | |||||
| using domi::caffe::LayerParameter; | |||||
| using domi::caffe::PoolingParameter; | |||||
| using domi::caffe::PReLUParameter; | |||||
| using domi::caffe::ReshapeParameter; | |||||
| using domi::caffe::ROIAlignParameter; | |||||
| using domi::caffe::TanHParameter; | |||||
| using domi::caffe::UpsampleParameter; | |||||
| namespace ge { | |||||
| /** | |||||
| * @ingroup ge_omg | |||||
| * @brief Used to parse Caffe operator information | |||||
| */ | |||||
| class CaffeOpParser : public OpParser { | |||||
| public: | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | |||||
| Status ParseParams(const Message *op_src, ge::Operator &op_dest) override { | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| /** | |||||
| * @ingroup ge_omg | |||||
| * @brief parse weight information of the operation | |||||
| * @param [in] op_src Weight data to be parsed | |||||
| * @param [out] graph Weight data after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| * @author | |||||
| */ | |||||
| Status ParseWeights(const Message *op_src, ge::NodePtr &node) override; | |||||
| /** | |||||
| * @ingroup ge_omg | |||||
| * @brief add const input node | |||||
| * @param [in] node to add const input | |||||
| * @param [out] node after add const input | |||||
| * @return SUCCESS add const input successfully | |||||
| * @return FAILED add const input failed | |||||
| * @author | |||||
| */ | |||||
| virtual Status AddConstInput(ge::NodePtr &node); | |||||
| protected: | |||||
| /** | |||||
| * @ingroup ge_omg | |||||
| * @brief Convert blob proto to weight definition | |||||
| * @param [in] proto Weight data to be parsed | |||||
| * @param [out] weight Weight data after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| static Status ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight); | |||||
| /** | |||||
| * @ingroup ge_omg | |||||
| * @brief Convert blob proto to shape definition | |||||
| * @param [in] proto Shape information before conversion | |||||
| * @param [out] shape Save converted shape information | |||||
| */ | |||||
| static void ConvertShape(const BlobProto &proto, std::vector<int64_t> &shape); | |||||
| private: | |||||
| /** | |||||
| * @ingroup ge_omg | |||||
| * @brief Convert blob proto to weight definition | |||||
| * @param [in] proto Weight data to be parsed | |||||
| * @param [out] weight Weight data after parsing | |||||
| * @return SUCCESS parse weight type successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| static Status ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, | |||||
| int size, const string &lay_name, ge::GeTensorPtr &weight); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_CAFFE_CAFFE_OP_PARSER_H_ | |||||
| @@ -0,0 +1,433 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_CAFFE_CAFFE_PARSER_H_ | |||||
| #define PARSER_CAFFE_CAFFE_PARSER_H_ | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "external/graph/operator.h" | |||||
| #include "omg/parser/op_parser.h" | |||||
| #include "omg/parser/model_parser.h" | |||||
| #include "omg/parser/weights_parser.h" | |||||
| #include "proto/caffe/caffe.pb.h" | |||||
| #include "proto/om.pb.h" | |||||
| namespace ge { | |||||
| using domi::caffe::NetParameter; | |||||
| using std::map; | |||||
| using std::set; | |||||
| using std::string; | |||||
| using std::unordered_map; | |||||
| using std::vector; | |||||
| static std::map<std::vector<std::string>, std::vector<std::string>> params_share_map; | |||||
| class CaffeModelParser : public domi::ModelParser { | |||||
| public: | |||||
| CaffeModelParser() {} | |||||
| virtual ~CaffeModelParser() {} | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse the relevant data from the model file and save it to graph | |||||
| * @param [in] file Path of model file | |||||
| * @param [in|out] graph graph for saving model information | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status Parse(const char *file, ge::Graph &graph) override; | |||||
| Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert model files to JSON format | |||||
| * @param [in] model_file Path of model file | |||||
| * @param [out] json_file Converted JSON file path | |||||
| * @return SUCCESS parse successfully | |||||
| * @return others parse failed | |||||
| */ | |||||
| Status ToJson(const char *model_file, const char *json_file) override; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse the relevant data from the model file and save it to graph | |||||
| * @param [in] graph_def input tensorflow model | |||||
| * @param [in|out] graph graph for saving model information | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override; | |||||
| Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, domi::GetGraphCallback callback, | |||||
| ge::ComputeGraphPtr &graph) override; | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Mapping CAFFE's datatype to GE's datatype | |||||
| * @param [in] type, datatype types of operators in CAFFE networks | |||||
| * @return ge::DataType | |||||
| */ | |||||
| ge::DataType ConvertToGeDataType(const uint32_t type) override { return ge::DT_FLOAT; } | |||||
| Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override { | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| private: | |||||
| Status Parse(const char *file, ge::ComputeGraphPtr &graph); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add the Layer in the model to the PreChecker | |||||
| * @param [in] net caffe net information | |||||
| * @return SUCCESS build successfully | |||||
| * @return FAILED build failed | |||||
| */ | |||||
| Status PreCheck(const domi::caffe::NetParameter &net); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Parsing input related information from model files | |||||
| * @param [in] proto_message caffe net information | |||||
| * @param [in|out] net_input_name Used to store the acquired input name information | |||||
| * @param [in|out] net_input_data Used to store the acquired input data information | |||||
| * @return SUCCESS build successfully | |||||
| * @return FAILED build failed | |||||
| */ | |||||
| Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse model by custom proto and save info to operators | |||||
| * @param [in] model_path, file path of model(prototxt file) | |||||
| * @param [in] custom_proto, file path of custom proto | |||||
| * @param [in] caffe_proto, file path of caffe proto | |||||
| * @param [out] operators, operators saving custom info | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto, | |||||
| std::vector<ge::Operator> &operators); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse model by custom proto and save info to operators | |||||
| * @param [in] model_path, file path of model(prototxt file) | |||||
| * @param [in] custom_proto_path, file path of custom proto | |||||
| * @param [in] custom_proto_name, custom proto name | |||||
| * @param [out] operators, operators saving custom info | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path, | |||||
| const string &custom_proto_name, std::vector<ge::Operator> &operators); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse caffe proto file | |||||
| * @param [in] proto_file, file path of caffe proto | |||||
| * @param [out] identifier_op_map, identifer and op map | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status ParseProtoFile(const string &proto_file, std::map<int32_t, string> &identifier_op_map); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Save identifier op map info | |||||
| * @param [in] line, line of proto | |||||
| * @param [out] identifier_op_map, identifer and op map | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status SaveIdentifierOpMapInfo(const string &line, std::map<int32_t, string> &identifier_op_map); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Get op identifier | |||||
| * @param [in] line, line of proto | |||||
| * @param [out] identifier, identifer of op | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status GetIdentifier(const std::string &line, int32_t &identifier); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Read caffe model and shield google warning | |||||
| * @param [in] model_path, file path of model(prototxt file) | |||||
| * @param [out] message, message saving custom info | |||||
| * @return SUCCESS read file successfully | |||||
| * @return FAILED read file failed | |||||
| */ | |||||
| Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Read caffe model and save it to message | |||||
| * @param [in] model_path, file path of model(prototxt file) | |||||
| * @param [out] message, message saving custom info | |||||
| * @return SUCCESS read file successfully | |||||
| * @return FAILED read file failed | |||||
| */ | |||||
| Status ReadCaffeModelFromText(const char *model_path, google::protobuf::Message *message); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse layer message and save custom info to operators | |||||
| * @param [in] layer_descriptor, layer description of message | |||||
| * @param [in] message, message of model | |||||
| * @param [out] operators, operators saving custom info | |||||
| * @return SUCCESS parse layer successfully | |||||
| * @return FAILED parse layer failed | |||||
| */ | |||||
| Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||||
| const google::protobuf::Message *message, std::vector<ge::Operator> &operators); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Create custom operator by op_name and op_type | |||||
| * @param [in] op_name, name of operator | |||||
| * @param [in] op_type, type of operator | |||||
| * @param [in] message, message of model | |||||
| * @param [in] index, index of field | |||||
| * @param [out] operators, operators saving custom info | |||||
| * @return SUCCESS create operator successfully | |||||
| * @return FAILED create operator failed | |||||
| */ | |||||
| Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message, | |||||
| int index, std::vector<ge::Operator> &operators); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse message and set operator attrs | |||||
| * @param [in] message, message of model | |||||
| * @param [in/out] depth, depth of recursion | |||||
| * @param [out] ops, operator saving custom info | |||||
| * @return SUCCESS parse message successfully | |||||
| * @return FAILED parse message failed | |||||
| */ | |||||
| Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse field and set operator attrs | |||||
| * @param [in] reflection, reflection of message | |||||
| * @param [in] message, message of model | |||||
| * @param [in] field, field of message | |||||
| * @param [in/out] depth, depth of recursion | |||||
| * @param [out] ops, operator saving custom info | |||||
| * @return SUCCESS parse field successfully | |||||
| * @return FAILED parse field failed | |||||
| */ | |||||
| Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||||
| /* | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse repeated field and set operator attrs | |||||
| * @param [in] reflection, reflection of message | |||||
| * @param [in] message, message of model | |||||
| * @param [in] field, field of message | |||||
| * @param [in/out] depth, depth of recursion | |||||
| * @param [out] ops, operator saving custom info by vector | |||||
| * @return SUCCESS parse field successfully | |||||
| * @return FAILED parse field failed | |||||
| */ | |||||
| Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add blob information to the bottom_blobs_map and top_blobs_map_ | |||||
| * @param [in] layer layer information | |||||
| * @param [in|out] inplace_blob_name_remapping save blob information | |||||
| * @return Status | |||||
| */ | |||||
| Status AddBlobsToMap(const domi::caffe::LayerParameter &layer, | |||||
| std::map<std::string, std::string> &inplace_blob_name_remapping); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add node information to graph | |||||
| * @param [in] layer layer infromation | |||||
| * @param [in|out] graph graph for saving model information | |||||
| * @return SUCCESS add successfully | |||||
| * @return FAILED add failed | |||||
| */ | |||||
| Status AddNode(const domi::caffe::LayerParameter &layer, ge::ComputeGraphPtr &graph); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add edge information to graph | |||||
| * @param [in|out] graph graph for saving model information | |||||
| * @return SUCCESS add successfully | |||||
| * @return FAILED add failed | |||||
| */ | |||||
| Status AddEdges(ge::ComputeGraphPtr &graph); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add edge information to graph | |||||
| * @param [in|out] graph graph for saving model information | |||||
| * @return SUCCESS add successfully | |||||
| * @return FAILED add failed | |||||
| */ | |||||
| Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Check if the current layer is valid | |||||
| * @return true valid | |||||
| * @return false invalid | |||||
| */ | |||||
| bool CheckValidLayer(const domi::caffe::LayerParameter &layer); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Check whether the top of the current layer is 'Inplace' | |||||
| * @return true is 'Inplace' | |||||
| * @return false not is 'Inplace' | |||||
| */ | |||||
| bool IsInplaceTopBlob(const domi::caffe::LayerParameter &layer, const std::string &top_name); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Check whether the top of the current layer is user's specified output top | |||||
| * @return true yes | |||||
| * @return false no | |||||
| */ | |||||
| bool IsOutputTop(const string &op_name, int32_t index); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Find a layer set with the same param | |||||
| * @param [in] Param name set of each layer | |||||
| * @param [in|out] Layer set of the same param | |||||
| * @return Status | |||||
| */ | |||||
| Status FindShareParamLayers(const std::map<std::string, std::vector<std::string>> &); | |||||
| Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer); | |||||
| Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | |||||
| const string &op_type); | |||||
| Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph); | |||||
| std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index); | |||||
| Status GetCustomOp(const domi::caffe::LayerParameter &layer, vector<ge::Operator> &operators); | |||||
| bool IsOpAttrEmpty(const ge::Operator &op, const std::string &type); | |||||
| Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op, | |||||
| std::shared_ptr<ge::OpParser> &op_parser); | |||||
| Status GetLeafNodeTops(ge::ComputeGraphPtr &graph); | |||||
| void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer); | |||||
| Status ReorderInput(domi::caffe::NetParameter &net); | |||||
| void AddOutputInfoToContext(string layer_name, int32_t top_index); | |||||
| Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message); | |||||
| std::map<std::string, ge::NodePtr> node_map; | |||||
| // key: blob name, value: layer name and index | |||||
| std::unordered_map<std::string, std::vector<std::pair<std::string, int32_t>>> bottom_blobs_map_; | |||||
| // key: blob name, value: layer name and index | |||||
| std::unordered_map<std::string, std::vector<std::pair<std::string, int32_t>>> top_blobs_map_; | |||||
| std::vector<ge::Operator> custom_operator_; | |||||
| std::map<std::string, std::vector<std::string>> layer_tops_map_; | |||||
| }; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Caffe weight parser | |||||
| */ | |||||
| class CaffeWeightsParser : public domi::WeightsParser { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Parse weight data from file and save to graph | |||||
| * @param [in] file Path of weight file after training | |||||
| * @param [in|out] graph Save weight information after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return PARAM_INVALID param invalid | |||||
| * @return PARSE_WEIGHTS_FAILED parse failed | |||||
| */ | |||||
| Status Parse(const char *file, ge::Graph &graph) override; | |||||
| Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; | |||||
| private: | |||||
| Status CheckNodes(ge::ComputeGraphPtr &graph); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert netparameter to modedef and save in graph | |||||
| * @param [in] param Caffe network parameters to be converted | |||||
| * @param [in|out] graph Save weight information after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| static Status ConvertNetParameter(const NetParameter ¶m, ge::ComputeGraphPtr &graph); | |||||
| Status Parse(const char *file, ge::ComputeGraphPtr &graph); | |||||
| Status ParseWeightByFusionProto(const char *model_path, const string &custom_proto_path, | |||||
| const string &custom_proto_name, ge::ComputeGraphPtr &graph); | |||||
| Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, | |||||
| const google::protobuf::Message *message, | |||||
| ge::ComputeGraphPtr &graph); | |||||
| Status ConvertLayerParameter(const google::protobuf::Message *layer_message, | |||||
| ge::ComputeGraphPtr &graph); | |||||
| Status CheckLayersSize(const google::protobuf::Message *message); | |||||
| Status ConvertLayerProto(const google::protobuf::Message *message, | |||||
| google::protobuf::Message *layer); | |||||
| Status ParseLayerField(const google::protobuf::Reflection *reflection, | |||||
| const google::protobuf::Message *message, | |||||
| const google::protobuf::FieldDescriptor *field, | |||||
| google::protobuf::Message *layer); | |||||
| Status ConvertBlobsProto(const google::protobuf::Message *message, | |||||
| google::protobuf::Message *blobs); | |||||
| Status ConvertBlobShapeProto(const google::protobuf::Message *message, | |||||
| google::protobuf::Message *dest_message); | |||||
| Status ConvertInnerProdcutProto(const google::protobuf::Message *message, | |||||
| google::protobuf::Message *dest_message); | |||||
| Status ConvertConvParamProto(const google::protobuf::Message *message, | |||||
| google::protobuf::Message *dest_message); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Layer types to be ignored in weight resolution | |||||
| */ | |||||
| static const set<string> skiped_layer_type_; | |||||
| std::map<std::string, int32_t> layer_name_record_map_; | |||||
| }; | |||||
| } // namespace domi | |||||
| #endif // PARSER_CAFFE_CAFFE_PARSER_H_ | |||||
| @@ -0,0 +1,143 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/caffe/caffe_reshape_parser.h" | |||||
| #include <vector> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/op/op_parser_util.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "proto/om.pb.h" | |||||
| using namespace ge::parser; | |||||
| using domi::CAFFE; | |||||
| namespace ge { | |||||
| namespace { | |||||
| const int kAnchorIndexZero = 0; | |||||
| const int kAnchorIndexOne = 1; | |||||
| } // namespace | |||||
| Status CaffeReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| GE_CHECK_NOTNULL(op); | |||||
| const LayerParameter *layer = DOMI_DYNAMIC_CAST<const LayerParameter *>(op_src); | |||||
| if (layer == nullptr) { | |||||
| GELOGE(FAILED, "Reshape Dynamic cast op_src to LayerParameter failed"); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | |||||
| const ReshapeParameter &reshape_parameter = layer->reshape_param(); | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_AXIS, RESHAPE_AXIS_DEFAULT_VALUE)), | |||||
| GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_NUM_AXES, RESHAPE_NUM_AXES_DEFAULT_VALUE)), | |||||
| GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| if (!reshape_parameter.has_shape()) { | |||||
| GELOGE(FAILED, "Reshape has no shape info, ret fail"); | |||||
| return FAILED; | |||||
| } | |||||
| const BlobShape &blob_shape = reshape_parameter.shape(); | |||||
| std::vector<int64_t> dims; | |||||
| for (int i = 0; i < blob_shape.dim_size(); i++) { | |||||
| dims.push_back(blob_shape.dim(i)); | |||||
| } | |||||
| if (reshape_parameter.has_axis()) { | |||||
| GE_LOGW_IF(reshape_parameter.axis() == -1, | |||||
| "axis with -1 may lead to calculation errors when input less than 4 dims."); | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_AXIS, reshape_parameter.axis())), | |||||
| GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| } | |||||
| if (reshape_parameter.has_num_axes()) { | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_NUM_AXES, reshape_parameter.num_axes())), | |||||
| GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| } | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetListInt(op, RESHAPE_ATTR_SHAPE, dims)), | |||||
| GELOGW("SetListInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeReshapeParser::ParseWeights(const Message *op_src, ge::OpDescPtr &op) { | |||||
| (void)op_src; | |||||
| (void)op; | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto owner_graph = node->GetOwnerComputeGraph(); | |||||
| if (owner_graph == nullptr) { | |||||
| GELOGE(FAILED, "node's graph is empty, name: %s", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ge::OpDescPtr op = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op); | |||||
| vector<int64_t> attr_shape; | |||||
| GE_IF_BOOL_EXEC(!(ge::AttrUtils::GetListInt(op, RESHAPE_ATTR_SHAPE, attr_shape)), | |||||
| GELOGW("GetListInt failed for op %s.", op->GetName().c_str());); // no need to return | |||||
| size_t dims_size = attr_shape.size(); | |||||
| // construct GeTensorDesc | |||||
| ge::GeTensorDesc const_desc = ge::GeTensorDesc(); | |||||
| std::vector<int64_t> shape_vec = {static_cast<int64_t>(dims_size)}; | |||||
| ge::GeShape shape(shape_vec); | |||||
| const_desc.Update(shape, ge::FORMAT_NCHW, ge::DT_INT64); | |||||
| ge::graphStatus state = op->UpdateInputDesc(RESHAPE_ATTR_SHAPE, const_desc); | |||||
| if (state != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Updata input_shape desc failed."); | |||||
| return FAILED; | |||||
| } | |||||
| // construct GeTensorPtr | |||||
| ge::GeTensorPtr constTensor = ge::MakeShared<ge::GeTensor>(); | |||||
| GE_CHECK_NOTNULL(constTensor); | |||||
| constTensor->SetTensorDesc(const_desc); | |||||
| std::unique_ptr<int64_t[]> data(new (std::nothrow) int64_t[dims_size]()); | |||||
| GE_CHECK_NOTNULL(data); | |||||
| for (size_t i = 0; i < dims_size; ++i) { | |||||
| data[i] = attr_shape[i]; | |||||
| } | |||||
| GE_IF_BOOL_EXEC( | |||||
| constTensor->SetData(reinterpret_cast<uint8_t *>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS, | |||||
| GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
| // construct const node and add edge | |||||
| auto const_opdesc = ge::OpDescUtils::CreateConstOp(constTensor); | |||||
| GE_CHECK_NOTNULL(const_opdesc); | |||||
| auto const_node = owner_graph->AddNodeFront(const_opdesc); | |||||
| GE_CHECK_NOTNULL(const_node); | |||||
| ge::OutDataAnchorPtr out_archor_ptr = const_node->GetOutDataAnchor(kAnchorIndexZero); | |||||
| GE_CHECK_NOTNULL(out_archor_ptr); | |||||
| ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(kAnchorIndexOne); | |||||
| GE_CHECK_NOTNULL(in_archor_ptr); | |||||
| state = ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr); | |||||
| if (state != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "AddEdge failed of from Node %s to Node %s", const_node->GetName().c_str(), node->GetName().c_str()); | |||||
| return domi::FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_OP_PARSER_CREATOR(CAFFE, RESHAPE, CaffeReshapeParser); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ | |||||
| #define PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ | |||||
| #include "parser/caffe/caffe_op_parser.h" | |||||
| namespace ge { | |||||
| class CaffeReshapeParser : public CaffeOpParser { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parse params of the operation | |||||
| * @param [in] op_src params to be parsed | |||||
| * @param [out] op_dest params after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| */ | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parse weight of the operation | |||||
| * @param [in] op_src params to be parsed | |||||
| * @param [out] op_dest params after parsing | |||||
| * @return SUCCESS parse successfully | |||||
| * @return FAILED parse failed | |||||
| * @author | |||||
| */ | |||||
| Status ParseWeights(const Message *op_src, ge::OpDescPtr &op); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief add const input node | |||||
| * @param [in] node to add const input | |||||
| * @param [out] node after add const input | |||||
| * @return SUCCESS add const input successfully | |||||
| * @return FAILED add const input failed | |||||
| * @author | |||||
| */ | |||||
| Status AddConstInput(ge::NodePtr &node) override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ | |||||
| @@ -0,0 +1,190 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| } | |||||
| message AttrDef | |||||
| { | |||||
| message ListValue | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; | |||||
| repeated TensorDescriptor td = 8; | |||||
| repeated TensorDef t = 9; | |||||
| repeated GraphDef g = 10; | |||||
| repeated NamedAttrs na = 11; | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; | |||||
| repeated TensorDescriptor output_desc = 34; | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto verion | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -0,0 +1,396 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // offline model | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| }; | |||||
| // operator define | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // start at 100 here | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;//optinal,[default = 1e-5] | |||||
| bool use_global_stats = 5; //optinal,by default true,testing mode | |||||
| float moving_average_fraction = 6; //optinal,[default = .999]; | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // In default, we will use NPU. | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load directtiono 0:col load 1:row load | |||||
| } | |||||
| message AttrDef { | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| @@ -0,0 +1,165 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message ModelTaskDef { | |||||
| string version = 1; | |||||
| map<string, string> attr = 9; // Extended field | |||||
| repeated TaskDef task = 10; | |||||
| uint64 memory_size = 11; | |||||
| uint32 stream_num = 12; | |||||
| uint32 event_num = 13; | |||||
| uint64 weight_size = 14; | |||||
| repeated bytes op = 15; // input/output opdef in bytes | |||||
| uint64 base_addr = 16; // base addr | |||||
| uint64 weight_addr = 17; // weight addr | |||||
| uint32 batch_num = 18; | |||||
| } | |||||
| message TaskDef { | |||||
| uint32 id = 1; | |||||
| uint32 type = 2; | |||||
| uint32 stream_id = 10; | |||||
| uint32 event_id = 11; | |||||
| KernelDef kernel = 20; | |||||
| KernelExDef kernel_ex = 21; | |||||
| KernelHcclDef kernel_hccl = 25; | |||||
| EventExDef event_ex = 26; | |||||
| LogTimeStampDef log_timestamp = 28; | |||||
| uint32 label_id = 30; | |||||
| MemcpyAsyncDef memcpy_async = 31; | |||||
| StreamSwitchDef stream_switch = 32; | |||||
| StreamActiveDef stream_active = 33; | |||||
| bytes private_def = 34; | |||||
| uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
| StreamSwitchNDef stream_switch_n = 36; | |||||
| LabelSetDef label_set = 37; | |||||
| LabelGotoExDef label_goto_ex = 38; | |||||
| LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
| } | |||||
| message KernelDef { | |||||
| KernelContext context = 1; | |||||
| string stub_func = 10; | |||||
| uint32 block_dim = 11; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes sm_desc = 14; | |||||
| bytes flowtable = 15; | |||||
| string so_name = 16; | |||||
| string kernel_name = 17; | |||||
| bytes kernel_ext_info = 18; | |||||
| uint32 kernel_ext_info_size = 19; | |||||
| } | |||||
| message KernelContext { | |||||
| uint32 kernel_type = 1; | |||||
| uint32 op_id = 2; // OP type in CCE | |||||
| uint32 kernel_func_id = 3; | |||||
| uint32 op_index = 4; // TE/Custom operator | |||||
| bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
| bytes args_offset = 6; // args offset information | |||||
| uint32 args_count = 7; // args count | |||||
| repeated uint32 origin_op_index = 8; | |||||
| } | |||||
| message KernelExDef { | |||||
| uint32 flags = 1; | |||||
| uint32 op_index = 4; | |||||
| uint32 args_size = 12; | |||||
| bytes args = 13; | |||||
| bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
| uint32 task_info_size = 15; | |||||
| bytes kernel_ext_info = 16; | |||||
| uint32 kernel_ext_info_size = 17; | |||||
| } | |||||
| message KernelHcclDef { | |||||
| uint32 op_index = 8; | |||||
| string hccl_type = 9; | |||||
| } | |||||
| message EventExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 event_type = 2; | |||||
| } | |||||
| message LogTimeStampDef { | |||||
| uint64 logid = 1; | |||||
| bool notify = 2; | |||||
| uint32 flat = 3; | |||||
| } | |||||
| message MemcpyAsyncDef { | |||||
| uint64 dst = 1; | |||||
| uint64 dst_max = 2; | |||||
| uint64 src = 3; | |||||
| uint64 count = 4; | |||||
| uint32 kind = 5; | |||||
| uint32 op_index = 6; | |||||
| } | |||||
| message StreamSwitchDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 true_stream_id = 2; | |||||
| int64 value = 3; | |||||
| uint64 value_ptr = 4; | |||||
| uint32 data_type = 5; | |||||
| } | |||||
| message StreamActiveDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 active_stream_id = 2; | |||||
| } | |||||
| message StreamSwitchNDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 size = 2; | |||||
| repeated int64 target_value = 3; | |||||
| repeated uint32 true_stream_id = 4; | |||||
| uint32 element_size = 5; | |||||
| uint32 data_type = 6; | |||||
| } | |||||
| message LabelSetDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelGotoExDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_id = 2; | |||||
| uint32 model_id = 3; | |||||
| } | |||||
| message LabelSwitchByIndexDef { | |||||
| uint32 op_index = 1; | |||||
| uint32 label_max = 2; | |||||
| } | |||||
| @@ -0,0 +1,76 @@ | |||||
| set(SRC_LIST | |||||
| "parser_factory.cc" | |||||
| "data_op_parser.cc" | |||||
| "op_parser_factory.cc" | |||||
| "pre_checker.cc" | |||||
| "register_tbe.cc" | |||||
| "parser_api.cc" | |||||
| "parser_inner_ctx.cc" | |||||
| "proto_file_parser.cc" | |||||
| "acl_graph_parser_util.cc" | |||||
| "tbe_plugin_loader.cc" | |||||
| "model_saver.cc" | |||||
| "../tensorflow/tensorflow_custom_parser_adapter.cc" | |||||
| "../tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | |||||
| "../tensorflow/tensorflow_fusion_op_parser.cc" | |||||
| "../tensorflow/tensorflow_util.cc" | |||||
| "convert/pb2json.cc" | |||||
| "op_def/ir_pb_converter.cc" | |||||
| "op_def/defs.cc" | |||||
| "op_def/op_schema.cc" | |||||
| "op_def/operator.cc" | |||||
| "op_map.cc" | |||||
| "parser_types.cc" | |||||
| "pass_manager.cc" | |||||
| "parser_fp16_t.cc" | |||||
| "thread_pool.cc" | |||||
| ) | |||||
| ############ libparser_common.so ############ | |||||
| add_library(parser_common SHARED ${SRC_LIST}) | |||||
| target_compile_options(parser_common PRIVATE | |||||
| -Werror | |||||
| ) | |||||
| target_compile_definitions(parser_common PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| ) | |||||
| target_include_directories(parser_common PRIVATE | |||||
| ${CMAKE_CURRENT_LIST_DIR} | |||||
| ${TOP_DIR}/framework/domi | |||||
| ${TOP_DIR}/framework/domi/common | |||||
| ${TOP_DIR}/framework/domi/parser | |||||
| ${TOP_DIR}/inc | |||||
| ${TOP_DIR}/inc/common/util | |||||
| ${TOP_DIR}/inc/external | |||||
| ${TOP_DIR}/inc/external/graph | |||||
| ${TOP_DIR}/inc/framework | |||||
| ${CMAKE_BINARY_DIR} | |||||
| ${CMAKE_BINARY_DIR}/proto/ge | |||||
| ) | |||||
| target_link_libraries(parser_common PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| -Wl,--no-as-needed | |||||
| graph | |||||
| protobuf | |||||
| register | |||||
| c_sec | |||||
| slog | |||||
| mmpa | |||||
| error_manager | |||||
| -Wl,--as-needed | |||||
| json | |||||
| -lrt | |||||
| -ldl | |||||
| ) | |||||
| ############ install ############ | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| install(TARGETS parser_common OPTIONAL | |||||
| LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||||
| ) | |||||
| @@ -0,0 +1,492 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include <dlfcn.h> | |||||
| #include <cstdlib> | |||||
| #include <fstream> | |||||
| #include <regex.h> | |||||
| #include <ctime> | |||||
| #include "common/string_util.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| #include "graph/opsproto_manager.h" | |||||
| #include "omg/parser/parser_inner_ctx.h" | |||||
| #include "tbe_plugin_loader.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "google/protobuf/io/coded_stream.h" | |||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| using google::protobuf::io::CodedInputStream; | |||||
| using google::protobuf::io::FileInputStream; | |||||
| using google::protobuf::io::ZeroCopyInputStream; | |||||
| using namespace ge::parser; | |||||
| namespace { | |||||
| /// The maximum length of the file. | |||||
| /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 | |||||
| const int kMaxFileSizeLimit = INT_MAX; | |||||
| const int kMaxBuffSize = 256; | |||||
| const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. | |||||
| const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | |||||
| static string GetSoPath() { | |||||
| Dl_info dl_info; | |||||
| if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) { | |||||
| GELOGW("Failed to read so_path!"); | |||||
| return string(); | |||||
| } else { | |||||
| std::string so_path = dl_info.dli_fname; | |||||
| char path[PATH_MAX] = {0}; | |||||
| if (so_path.length() >= PATH_MAX) { | |||||
| GELOGW("File path is too long!"); | |||||
| return string(); | |||||
| } | |||||
| if (realpath(so_path.c_str(), path) == nullptr) { | |||||
| GELOGW("Failed to get realpath of %s", so_path.c_str()); | |||||
| return string(); | |||||
| } | |||||
| so_path = path; | |||||
| so_path = so_path.substr(0, so_path.rfind('/') + 1); | |||||
| return so_path; | |||||
| } | |||||
| } | |||||
| static void GetOpsProtoPath(string &opsproto_path) { | |||||
| GELOGD("Start to get ops proto path schedule."); | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| string path = path_env; | |||||
| string file_path = ge::parser::RealPath(path.c_str()); | |||||
| if (file_path.empty()) { | |||||
| GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str()); | |||||
| return; | |||||
| } | |||||
| opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||||
| GELOGI("Get opsproto so path from env : %s", path.c_str()); | |||||
| return; | |||||
| } | |||||
| string path_base = GetSoPath(); | |||||
| GELOGI("path_base is %s", path_base.c_str()); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||||
| } | |||||
| } // namespace | |||||
| namespace ge { | |||||
| domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||||
| std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||||
| ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); | |||||
| if (tmpDescPtr == nullptr) { | |||||
| GELOGE(domi::FAILED, "Get outnode op desc fail."); | |||||
| return domi::FAILED; | |||||
| } | |||||
| size_t size = tmpDescPtr->GetOutputsSize(); | |||||
| if (node->GetType() != NETOUTPUT) { | |||||
| for (size_t index = 0; index < size; ++index) { | |||||
| output_nodes_info.push_back(std::make_pair(node, index)); | |||||
| } | |||||
| } else { | |||||
| const auto in_anchors = node->GetAllInDataAnchors(); | |||||
| for (auto in_anchor : in_anchors) { | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| GELOGE(domi::FAILED, "Get leaf node op desc fail."); | |||||
| return domi::FAILED; | |||||
| } | |||||
| auto out_node = out_anchor->GetOwnerNode(); | |||||
| output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||||
| std::vector<std::string> &output_nodes_name) { | |||||
| output_nodes_name.clear(); | |||||
| if (ge::GetParserContext().out_top_names.empty()) { | |||||
| // tf process, no top name. | |||||
| for (const auto output_node_info : output_nodes_info) { | |||||
| std::string node_name = output_node_info.first->GetName(); | |||||
| int32_t index = output_node_info.second; | |||||
| output_nodes_name.push_back(node_name + ":" + std::to_string(index)); | |||||
| } | |||||
| return; | |||||
| } | |||||
| // caffe process reserved place; | |||||
| } | |||||
| domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) { | |||||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||||
| if (compute_graph == nullptr) { | |||||
| GELOGE(FAILED, "compute_graph is nullptr."); | |||||
| return FAILED; | |||||
| } | |||||
| std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info; | |||||
| std::vector<std::string> output_nodes_name; | |||||
| for (ge::NodePtr node : compute_graph->GetDirectNode()) { | |||||
| if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) { | |||||
| Status ret = AclGrphParseUtil::GetOutputLeaf(node, output_nodes_info); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "find leaf fail."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); | |||||
| compute_graph->SetGraphOutNodesInfo(output_nodes_info); | |||||
| ge::GetParserContext().net_out_nodes = output_nodes_name; | |||||
| GELOGI("Set graph %s default output node success.", graph.GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| domi::Status AclGrphParseUtil::LoadOpsProtoLib() { | |||||
| string opsproto_path; | |||||
| GetOpsProtoPath(opsproto_path); | |||||
| GELOGI("Get opsproto path is %s", opsproto_path.c_str()); | |||||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | |||||
| map<string, string> option_tmp; | |||||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||||
| bool is_proto_init = manager->Initialize(option_tmp); | |||||
| if (!is_proto_init) { | |||||
| GELOGE(FAILED, "Load ops_proto lib failed, ops proto path is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void AclGrphParseUtil::SaveCustomCaffeProtoPath() { | |||||
| GELOGD("Enter save custom caffe proto path."); | |||||
| std::string path_base = GetSoPath(); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; | |||||
| string custom_op_path; | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| std::string path = path_env; | |||||
| custom_op_path = path + "/framework/custom/caffe/"; | |||||
| GELOGI("Get custom proto path from env : %s", path_env); | |||||
| GetParserContext().custom_proto_path = custom_op_path; | |||||
| return; | |||||
| } | |||||
| custom_op_path = path_base + "ops/framework/custom/caffe/"; | |||||
| ge::GetParserContext().custom_proto_path = custom_op_path; | |||||
| return; | |||||
| } | |||||
| // Initialize PARSER, load custom op plugin | |||||
| // options will be used later for parser decoupling | |||||
| domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) { | |||||
| GELOGT(TRACE_INIT, "AclParserInitialize start"); | |||||
| // check init status | |||||
| if (parser_initialized) { | |||||
| GELOGW("AclParserInitialize is called more than once"); | |||||
| return SUCCESS; | |||||
| } | |||||
| // load custom op plugin | |||||
| TBEPluginLoader::Instance().LoadPluginSo(options); | |||||
| // load and save custom op proto for prediction | |||||
| (void)LoadOpsProtoLib(); | |||||
| SaveCustomCaffeProtoPath(); | |||||
| auto op_registry = domi::OpRegistry::Instance(); | |||||
| if (op_registry == nullptr) { | |||||
| GELOGE(FAILED, "Get OpRegistry instance failed"); | |||||
| return FAILED; | |||||
| } | |||||
| std::vector<OpRegistrationData> registrationDatas = op_registry->registrationDatas; | |||||
| GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | |||||
| for (OpRegistrationData ®_data : registrationDatas) { | |||||
| (void)OpRegistrationTbe::Instance()->Finalize(reg_data, false); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | |||||
| // set init status | |||||
| if (!parser_initialized) { | |||||
| // Initialize success, first time calling initialize | |||||
| parser_initialized = true; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "AclParserInitialize finished"); | |||||
| return SUCCESS; | |||||
| } | |||||
| namespace parser { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { | |||||
| if (path == nullptr) { | |||||
| GELOGE(ge::FAILED, "path pointer is NULL."); | |||||
| return ""; | |||||
| } | |||||
| if (strlen(path) >= PATH_MAX) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); | |||||
| GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX); | |||||
| return ""; | |||||
| } | |||||
| // Nullptr is returned when the path does not exist or there is no permission | |||||
| // Return absolute path when path is accessible | |||||
| std::string res; | |||||
| char resolved_path[PATH_MAX] = {0}; | |||||
| if (realpath(path, resolved_path) != nullptr) { | |||||
| res = resolved_path; | |||||
| } | |||||
| return res; | |||||
| } | |||||
| // Get file length | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null."); | |||||
| std::string real_path = RealPath(input_file.c_str()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | |||||
| unsigned long long file_length = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, | |||||
| {input_file, strerror(errno)}); | |||||
| return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); | |||||
| return -1, "File[%s] size is 0, not valid.", input_file.c_str()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19016", {"filepath", "filesize", "maxlen"}, | |||||
| {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||||
| return -1, "File[%s] size %lld is out of limit: %d.", | |||||
| input_file.c_str(), file_length, kMaxFileSizeLimit); | |||||
| return static_cast<long>(file_length); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { | |||||
| struct timeval tv{}; | |||||
| int ret = gettimeofday(&tv, nullptr); | |||||
| GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); | |||||
| auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds | |||||
| return static_cast<uint64_t>(total_use_time); | |||||
| } | |||||
| static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, | |||||
| return false, "incorrect parameter. nullptr == proto"); | |||||
| coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold); | |||||
| return proto->ParseFromCodedStream(&coded_stream); | |||||
| } | |||||
| /** @ingroup domi_common | |||||
| * @brief Read all data from binary file | |||||
| * @param [in] file_name File path | |||||
| * @param [out] buffer The address of the output memory, which needs to be released by the caller | |||||
| * @param [out] length Output memory size | |||||
| * @return false fail | |||||
| * @return true success | |||||
| */ | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, | |||||
| int &length) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr"); | |||||
| std::string real_path = RealPath(file_name); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name); | |||||
| std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); | |||||
| if (!file.is_open()) { | |||||
| GELOGE(ge::FAILED, "Read file %s failed.", file_name); | |||||
| return false; | |||||
| } | |||||
| length = static_cast<int>(file.tellg()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0"); | |||||
| file.seekg(0, std::ios::beg); | |||||
| *buffer = new(std::nothrow) char[length](); | |||||
| GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed."); | |||||
| file.read(*buffer, length); | |||||
| file.close(); | |||||
| return true; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), | |||||
| return false, | |||||
| "Input parameter file or proto is nullptr!"); | |||||
| std::string real_path = RealPath(file); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
| return false, "pb file path '%s' not valid", file); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | |||||
| std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | |||||
| if (!fs.is_open()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); | |||||
| GELOGE(ge::FAILED, "Open real path[%s] failed.", file); | |||||
| return false; | |||||
| } | |||||
| google::protobuf::io::IstreamInputStream istream(&fs); | |||||
| google::protobuf::io::CodedInputStream coded_stream(&istream); | |||||
| bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); | |||||
| fs.close(); | |||||
| if (!ret) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); | |||||
| GELOGE(ge::FAILED, "Parse file[%s] failed.", file); | |||||
| return ret; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false, | |||||
| "incorrect parameter. proto is nullptr || data is nullptr || size is 0"); | |||||
| google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size); | |||||
| return ReadProtoFromCodedInputStream(coded_stream, proto); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, | |||||
| google::protobuf::Message *message) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false, | |||||
| "incorrect parameter. nullptr == file || nullptr == message"); | |||||
| std::string real_path = RealPath(file); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, | |||||
| {file, strerror(errno)}); | |||||
| return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, | |||||
| strerror(errno)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | |||||
| std::ifstream fs(real_path.c_str(), std::ifstream::in); | |||||
| if (!fs.is_open()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); | |||||
| GELOGE(ge::FAILED, | |||||
| "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file); | |||||
| return false; | |||||
| } | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | |||||
| bool ret = google::protobuf::TextFormat::Parse(&input, message); | |||||
| GE_IF_BOOL_EXEC(!ret, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); | |||||
| GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " | |||||
| "please check whether the file is a valid protobuf format file.", file)); | |||||
| fs.close(); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, | |||||
| google::protobuf::Message *message) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false, | |||||
| "incorrect parameter. data is nullptr || message is nullptr"); | |||||
| std::string str(data, static_cast<size_t>(size)); | |||||
| std::istringstream fs(str); | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | |||||
| bool ret = google::protobuf::TextFormat::Parse(&input, message); | |||||
| GE_IF_BOOL_EXEC( | |||||
| !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); | |||||
| return ret; | |||||
| } | |||||
| /// | |||||
| /// @brief get the Original Type of FrameworkOp | |||||
| /// @param [in] node | |||||
| /// @param [out] type | |||||
| /// @return Status | |||||
| /// | |||||
| Status GetOriginalType(const ge::NodePtr &node, string &type) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| type = node->GetType(); | |||||
| GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
| if (!ret) { | |||||
| GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { | |||||
| char ebuff[kMaxBuffSize]; | |||||
| regex_t reg; | |||||
| int cflags = REG_EXTENDED | REG_NOSUB; | |||||
| int ret = regcomp(®, mode.c_str(), cflags); | |||||
| if (ret) { | |||||
| regerror(ret, ®, ebuff, kMaxBuffSize); | |||||
| GELOGW("regcomp failed, reason: %s", ebuff); | |||||
| regfree(®); | |||||
| return true; | |||||
| } | |||||
| ret = regexec(®, str.c_str(), 0, nullptr, 0); | |||||
| if (ret) { | |||||
| regerror(ret, ®, ebuff, kMaxBuffSize); | |||||
| GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff); | |||||
| regfree(®); | |||||
| return false; | |||||
| } | |||||
| regfree(®); | |||||
| return true; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() { | |||||
| std::time_t now = std::time(nullptr); | |||||
| std::tm *ptm = std::localtime(&now); | |||||
| if (ptm == nullptr) { | |||||
| GELOGE(ge::FAILED, "Localtime failed."); | |||||
| return ""; | |||||
| } | |||||
| const int kTimeBufferLen = 32; | |||||
| char buffer[kTimeBufferLen + 1] = {0}; | |||||
| // format: 20171122042550 | |||||
| std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm); | |||||
| return std::string(buffer); | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,161 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef ACL_GRAPH_PARSE_UTIL_ | |||||
| #define ACL_GRAPH_PARSE_UTIL_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include <sstream> | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "register/register_error_codes.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| namespace ge { | |||||
| using google::protobuf::Message; | |||||
| class AclGrphParseUtil { | |||||
| public: | |||||
| AclGrphParseUtil() {} | |||||
| virtual ~AclGrphParseUtil() {} | |||||
| domi::Status LoadOpsProtoLib(); | |||||
| void SaveCustomCaffeProtoPath(); | |||||
| domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | |||||
| domi::Status SetDefaultOutputNode(ge::Graph &graph); | |||||
| private: | |||||
| bool parser_initialized = false; | |||||
| domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||||
| std::vector<std::string> &output_nodes_name); | |||||
| }; | |||||
| namespace parser { | |||||
| /// | |||||
| /// @ingroup: domi_common | |||||
| /// @brief: get length of file | |||||
| /// @param [in] input_file: path of file | |||||
| /// @return long: File length. If the file length fails to be obtained, the value -1 is returned. | |||||
| /// | |||||
| extern long GetFileLength(const std::string &input_file); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Absolute path for obtaining files. | |||||
| /// @param [in] path of input file | |||||
| /// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned | |||||
| /// | |||||
| std::string RealPath(const char *path); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Obtains the absolute time (timestamp) of the current system. | |||||
| /// @return Timestamp, in microseconds (US) | |||||
| /// | |||||
| /// | |||||
| uint64_t GetCurrentTimestamp(); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Reads all data from a binary file. | |||||
| /// @param [in] file_name path of file | |||||
| /// @param [out] buffer Output memory address, which needs to be released by the caller. | |||||
| /// @param [out] length Output memory size | |||||
| /// @return false fail | |||||
| /// @return true success | |||||
| /// | |||||
| bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief proto file in bianary format | |||||
| /// @param [in] file path of proto file | |||||
| /// @param [out] proto memory for storing the proto file | |||||
| /// @return true success | |||||
| /// @return false fail | |||||
| /// | |||||
| bool ReadProtoFromBinaryFile(const char *file, Message *proto); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Reads the proto structure from an array. | |||||
| /// @param [in] data proto data to be read | |||||
| /// @param [in] size proto data size | |||||
| /// @param [out] proto Memory for storing the proto file | |||||
| /// @return true success | |||||
| /// @return false fail | |||||
| /// | |||||
| bool ReadProtoFromArray(const void *data, int size, Message *proto); | |||||
| /// | |||||
| /// @ingroup domi_proto | |||||
| /// @brief Reads the proto file in the text format. | |||||
| /// @param [in] file path of proto file | |||||
| /// @param [out] message Memory for storing the proto file | |||||
| /// @return true success | |||||
| /// @return false fail | |||||
| /// | |||||
| bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | |||||
| bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); | |||||
| /// | |||||
| /// @brief get the Original Type of FrameworkOp | |||||
| /// @param [in] node | |||||
| /// @param [out] type | |||||
| /// @return Status | |||||
| /// | |||||
| domi::Status GetOriginalType(const ge::NodePtr &node, string &type); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Check whether the file path meets the whitelist verification requirements. | |||||
| /// @param [in] filePath file path | |||||
| /// @param [out] result | |||||
| /// | |||||
| bool ValidateStr(const std::string &filePath, const std::string &mode); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Obtains the current time string. | |||||
| /// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 | |||||
| /// | |||||
| std::string CurrentTimeInStr(); | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| /*lint --emacro((773),GE_TIMESTAMP_START)*/ | |||||
| /*lint -esym(773,GE_TIMESTAMP_START)*/ | |||||
| #define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp() | |||||
| #define PARSER_TIMESTAMP_END(stage, stage_name) \ | |||||
| do { \ | |||||
| uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ | |||||
| GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||||
| (endUsec_##stage - startUsec_##stage)); \ | |||||
| } while (0); | |||||
| #define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \ | |||||
| do { \ | |||||
| uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ | |||||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||||
| (endUsec_##stage - startUsec_##stage)); \ | |||||
| } while (0); | |||||
| #endif // ACL_GRAPH_PARSE_UTIL_ | |||||
| @@ -0,0 +1,248 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // File: pb2json.h | |||||
| // Description: This imply file for protobuf message and json interconversion | |||||
| #include "common/convert/pb2json.h" | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "securec.h" | |||||
| #include "framework/common/fmk_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| using std::set; | |||||
| using std::string; | |||||
| namespace ge { | |||||
| namespace { | |||||
| const int kSignificantDigits = 10; | |||||
| } | |||||
| // JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, | |||||
| const set<string> &black_fields, Json &json, | |||||
| bool enum2str) { | |||||
| auto descriptor = message.GetDescriptor(); | |||||
| auto reflection = message.GetReflection(); | |||||
| if (descriptor == nullptr || reflection == nullptr) { | |||||
| return; | |||||
| } | |||||
| auto count = descriptor->field_count(); | |||||
| for (auto i = 0; i < count; ++i) { | |||||
| const auto field = descriptor->field(i); | |||||
| if (field == nullptr) { | |||||
| return; | |||||
| } | |||||
| // Do not display weight data | |||||
| if (black_fields.find(field->name()) != black_fields.end()) { | |||||
| continue; | |||||
| } | |||||
| if (field->is_repeated()) { | |||||
| if (reflection->FieldSize(message, field) > 0) { | |||||
| RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| if (!reflection->HasField(message, field)) { | |||||
| continue; | |||||
| } | |||||
| OneField2Json(message, field, reflection, black_fields, json, enum2str); | |||||
| } | |||||
| } | |||||
| void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||||
| bool enum2str) { | |||||
| switch (field->type()) { | |||||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||||
| const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | |||||
| if (0 != tmp_message.ByteSize()) { | |||||
| Message2Json(tmp_message, black_fields, json[field->name()], enum2str); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case ProtobufFieldDescriptor::TYPE_BOOL: | |||||
| json[field->name()] = reflection->GetBool(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_ENUM: { | |||||
| auto *enum_value_desc = reflection->GetEnum(message, field); | |||||
| Enum2Json(enum_value_desc, field, enum2str, json); | |||||
| break; | |||||
| } | |||||
| case ProtobufFieldDescriptor::TYPE_INT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED32: | |||||
| json[field->name()] = reflection->GetInt32(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED32: | |||||
| json[field->name()] = reflection->GetUInt32(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_INT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED64: | |||||
| json[field->name()] = reflection->GetInt64(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED64: | |||||
| json[field->name()] = reflection->GetUInt64(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||||
| char str[kSignificantDigits]; | |||||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){ | |||||
| json[field->name()] = str; | |||||
| } else { | |||||
| json[field->name()] = reflection->GetFloat(message, field); | |||||
| } | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_STRING: | |||||
| json[field->name()] = reflection->GetString(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_BYTES: { | |||||
| string field_name = field->name(); | |||||
| string type_bytes = reflection->GetString(message, field); | |||||
| json[field_name] = TypeBytes2String(field_name, type_bytes); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||||
| if (field_name != "offset") { | |||||
| return type_bytes; | |||||
| } | |||||
| string result = ""; | |||||
| for (char temp_value : type_bytes) { | |||||
| uint8_t *value = 0; | |||||
| value = reinterpret_cast<uint8_t *>(&temp_value); | |||||
| char str[kSignificantDigits]; | |||||
| if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){ | |||||
| GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); | |||||
| continue; | |||||
| } | |||||
| result += str; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||||
| bool enum2str) { | |||||
| if ((field == nullptr) || (reflection == nullptr)) { | |||||
| Message2Json(message, black_fields, json, enum2str); | |||||
| return; | |||||
| } | |||||
| for (auto i = 0; i < reflection->FieldSize(message, field); ++i) { | |||||
| Json tmp_json; | |||||
| switch (field->type()) { | |||||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||||
| const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | |||||
| if (0 != tmp_message.ByteSize()) { | |||||
| Message2Json(tmp_message, black_fields, tmp_json, enum2str); | |||||
| } | |||||
| } break; | |||||
| case ProtobufFieldDescriptor::TYPE_BOOL: | |||||
| tmp_json = reflection->GetRepeatedBool(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_ENUM: { | |||||
| auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i); | |||||
| RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json); | |||||
| } break; | |||||
| case ProtobufFieldDescriptor::TYPE_INT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED32: | |||||
| tmp_json = reflection->GetRepeatedInt32(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED32: | |||||
| tmp_json = reflection->GetRepeatedUInt32(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_INT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED64: | |||||
| tmp_json = reflection->GetRepeatedInt64(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED64: | |||||
| tmp_json = reflection->GetRepeatedUInt64(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||||
| tmp_json = reflection->GetRepeatedFloat(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_STRING: | |||||
| case ProtobufFieldDescriptor::TYPE_BYTES: | |||||
| tmp_json = reflection->GetRepeatedString(message, field, i); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| json += tmp_json; | |||||
| } | |||||
| } | |||||
| void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | |||||
| bool enum2str, Json &json) { | |||||
| if (enum_value_desc != nullptr) { | |||||
| if (field == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (enum2str) { | |||||
| json[field->name()] = enum_value_desc->name(); | |||||
| } else { | |||||
| json[field->name()] = enum_value_desc->number(); | |||||
| } | |||||
| } | |||||
| } | |||||
| void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) { | |||||
| if (enum_value_desc != nullptr) { | |||||
| if (enum2str) { | |||||
| json = enum_value_desc->name(); | |||||
| } else { | |||||
| json = enum_value_desc->number(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // File: pb2json.h | |||||
| // Description: This header file for protobuf message and json interconversion | |||||
| #ifndef PARSER_COMMON_CONVERT_PB2JSON_H_ | |||||
| #define PARSER_COMMON_CONVERT_PB2JSON_H_ | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "google/protobuf/descriptor.h" | |||||
| #include "google/protobuf/message.h" | |||||
| #include "nlohmann/json.hpp" | |||||
| namespace ge { | |||||
| using Json = nlohmann::json; | |||||
| using ProtobufMsg = ::google::protobuf::Message; | |||||
| using ProtobufReflection = ::google::protobuf::Reflection; | |||||
| using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor; | |||||
| using ProtobufDescriptor = ::google::protobuf::Descriptor; | |||||
| using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; | |||||
| class Pb2Json { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Transfer protobuf object to JSON object | |||||
| * @param [out] json Converted JSON object | |||||
| * @return void success | |||||
| * @author | |||||
| */ | |||||
| static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | |||||
| bool enum2str = false); | |||||
| protected: | |||||
| static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | |||||
| Json &json, bool enum2str); | |||||
| static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | |||||
| bool enum2str, Json &json); | |||||
| static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json); | |||||
| static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json, | |||||
| bool enum2str); | |||||
| static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_CONVERT_PB2JSON_H_ | |||||
| @@ -0,0 +1,212 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/data_op_parser.h" | |||||
| #include <cstdlib> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "common/util.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "omg/omg.h" | |||||
| using namespace cce; | |||||
| namespace { | |||||
| const int kDataMemAlignSize = 32; | |||||
| const int kTwoTimesAlign = 2; | |||||
| const int kDynamicBatchInputSize = -1; | |||||
| const uint32_t kScalarLength = 1; | |||||
| } // namespace | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> &shape, ge::OpDescPtr op) { | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null"); | |||||
| const string &data_op_name = op->GetName(); | |||||
| GetParserContext().input_dims.emplace(data_op_name, shape); | |||||
| int64_t attr_type = 0; | |||||
| ge::DataType data_type; | |||||
| if (ge::AttrUtils::GetInt(op, ge::DATA_ATTR_NAME_DATA_TYPE, attr_type)) { | |||||
| data_type = static_cast<ge::DataType>(attr_type); | |||||
| } else { | |||||
| data_type = ge::DT_FLOAT; | |||||
| } | |||||
| // convert input | |||||
| vector<int64_t> def_format_shape(shape); | |||||
| ge::GeTensorDesc i_tensor_desc; | |||||
| ge::GeTensorDesc o_tensor_desc; | |||||
| const unordered_map<string, domiTensorFormat_t> &input_nodes_format_map = GetParserContext().input_nodes_format_map; | |||||
| auto map_iter = input_nodes_format_map.find(data_op_name); | |||||
| if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) { | |||||
| // Input 5D NC1HWC0 | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Init5DInputTensor(def_format_shape, i_tensor_desc), "InitInputTensor failed"); | |||||
| // Output | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Init5DOutputTensor(def_format_shape, o_tensor_desc), "InitOutputTensor failed"); | |||||
| } else { | |||||
| // No need to consider AIPP here, | |||||
| // The adjustdatanodedesc function of model_builder will process the | |||||
| // input_desc and output_desc of AIPP's data node. | |||||
| // Without AIPP, the data of input float is kept in cctranstensor implementation. | |||||
| // The cast operator can not run in the pvmodel simulation environment, | |||||
| // so the input data conversion processing maintains the original state. | |||||
| // To be modified after AICPU operators support pvmodel. | |||||
| if (data_type == ge::DT_FLOAT) { | |||||
| // Input | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(InitInputTensor(def_format_shape, i_tensor_desc), "InitInputTensor failed"); | |||||
| // Output | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(InitOutputTensor(def_format_shape, o_tensor_desc), "InitOutputTensor failed"); | |||||
| } else { | |||||
| // Input | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(InitNDTensor(def_format_shape, data_type, i_tensor_desc), | |||||
| "Init ND InputTensor failed"); | |||||
| // Output | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(InitNDTensor(def_format_shape, data_type, o_tensor_desc), | |||||
| "Init ND Output Tensor failed"); | |||||
| } | |||||
| } | |||||
| i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); | |||||
| i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); | |||||
| o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); | |||||
| if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (op->AddOutputDesc(o_tensor_desc) != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(domi::INTERNAL_ERROR, "AddOutputDesc failed for op %s.", op->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataOpParser::Init5DInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc) { | |||||
| tensor_desc.SetDataType(ge::DT_FLOAT16); | |||||
| tensor_desc.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0)); | |||||
| ge::TensorUtils::SetReuseInput(tensor_desc, false); | |||||
| ge::TensorUtils::SetRealDimCnt(tensor_desc, shape.size()); | |||||
| tensor_desc.SetShape(ge::GeShape(shape)); | |||||
| int64_t tensor_size = 0; | |||||
| ge::graphStatus graph_status = ge::TensorUtils::GetTensorSizeInBytes(tensor_desc, tensor_size); | |||||
| if (graph_status != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "GetTensorSizeInBytes failed!"); | |||||
| return domi::FAILED; | |||||
| } | |||||
| // Set the actual occupied space size | |||||
| ge::TensorUtils::SetSize(tensor_desc, tensor_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataOpParser::InitNDTensor(const vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc) { | |||||
| // Fixed input ND | |||||
| tensor_desc.SetFormat(static_cast<ge::Format>(DOMI_TENSOR_ND)); | |||||
| tensor_desc.SetDataType(data_type); | |||||
| tensor_desc.SetOriginDataType(data_type); | |||||
| ge::TensorUtils::SetReuseInput(tensor_desc, false); | |||||
| ge::TensorUtils::SetRealDimCnt(tensor_desc, shape.size()); | |||||
| tensor_desc.SetShape(ge::GeShape(shape)); | |||||
| tensor_desc.SetOriginShape(ge::GeShape(shape)); | |||||
| int64_t size = kScalarLength; | |||||
| if (!tensor_desc.GetShape().GetDims().empty()) { | |||||
| size = tensor_desc.GetShape().GetShapeSize(); | |||||
| } | |||||
| uint32_t type_size = 0; | |||||
| if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) { | |||||
| FMK_INT64_UINT32_MULCHECK(size, type_size); | |||||
| size *= type_size; | |||||
| } else { | |||||
| FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float))); | |||||
| size *= sizeof(float); | |||||
| } | |||||
| ge::TensorUtils::SetSize(tensor_desc, size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataOpParser::Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output) { | |||||
| output.SetDataType(ge::DT_FLOAT16); | |||||
| output.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0)); | |||||
| ge::TensorUtils::SetReuseInput(output, false); | |||||
| ge::TensorUtils::SetRealDimCnt(output, shape.size()); | |||||
| output.SetShape(ge::GeShape(shape)); | |||||
| int64_t output_size = 0; | |||||
| ge::graphStatus graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(output, output_size); | |||||
| if (graph_status != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "GetTensorMemorySizeInBytes failed!"); | |||||
| return domi::FAILED; | |||||
| } | |||||
| // Set the actual occupied space size | |||||
| ge::TensorUtils::SetSize(output, output_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataOpParser::InitInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &input) { | |||||
| input.SetFormat(static_cast<ge::Format>(domiTensorFormat_t(DOMI_TENSOR_ND))); | |||||
| input.SetDataType(ge::DT_FLOAT); | |||||
| input.SetOriginDataType(ge::DT_FLOAT); | |||||
| ge::TensorUtils::SetReuseInput(input, false); | |||||
| input.SetShape(ge::GeShape(shape)); | |||||
| input.SetOriginShape(ge::GeShape(shape)); | |||||
| int64_t size = 0; | |||||
| // No need to check dynamic_batch_size since its first dim is -1. | |||||
| if (input.GetShape().GetDim(0) != -1) { | |||||
| size = input.GetShape().GetShapeSize(); | |||||
| } | |||||
| FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float))); | |||||
| ge::TensorUtils::SetSize(input, size * sizeof(float)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataOpParser::InitOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output) { | |||||
| int64_t output_size = 0; | |||||
| ge::GeShape output_shape = ge::GeShape(shape); | |||||
| ge::Format format = ge::FORMAT_ND; | |||||
| ge::DataType data_type = ge::DT_FLOAT; | |||||
| output.SetFormat(format); | |||||
| output.SetDataType(data_type); | |||||
| ge::TensorUtils::SetReuseInput(output, false); | |||||
| ge::TensorUtils::SetRealDimCnt(output, shape.size()); | |||||
| output.SetShape(output_shape); | |||||
| ge::graphStatus graph_status = ge::TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_size); | |||||
| if (graph_status != ge::GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "CalcTensorMemSize failed!"); | |||||
| return FAILED; | |||||
| } | |||||
| if (output_size == kDynamicBatchInputSize) { | |||||
| GELOGI("After calc tensor memory size, output_mem_size = %ld", output_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| int64_t size = output_size; | |||||
| auto valid_max_size = INT64_MAX - kTwoTimesAlign * kDataMemAlignSize; | |||||
| if (size > valid_max_size || size < 0) { | |||||
| GELOGE(FAILED, "The updated mem size is out of data range [0, %ld]", valid_max_size); | |||||
| return FAILED; | |||||
| } else { | |||||
| size = ((size + kTwoTimesAlign * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | |||||
| } | |||||
| // Set the actual occupied space size | |||||
| ge::TensorUtils::SetSize(output, size); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,109 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_DATA_OP_PARSER_H_ | |||||
| #define PARSER_COMMON_DATA_OP_PARSER_H_ | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include <vector> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/op/attr_value_util.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #include "proto/om.pb.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/operator.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| using google::protobuf::Message; | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Provide a public interface for DataOp | |||||
| * | |||||
| */ | |||||
| class DataOpParser { | |||||
| public: | |||||
| virtual ~DataOpParser() {} | |||||
| protected: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief parser the Shape information of DataOp | |||||
| * @param [in] shape 4D shape information (dimensions) | |||||
| * @param [out] op Save converted shap information | |||||
| * @return SUCCESS Parsing success | |||||
| * @return FAILED Parsing failed | |||||
| */ | |||||
| static Status ParseShape(const vector<int64_t> &shape, ge::OpDescPtr op); | |||||
| private: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert Input's Shape Information | |||||
| * @param [in] 4D shape information (dimensions) | |||||
| * @param [out] Save converted shap information | |||||
| */ | |||||
| static Status Init5DInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert Shape of Output | |||||
| * @param [in] shape 4D shape information (dimensions) | |||||
| * @param [out] output Save converted shap information | |||||
| * @return SUCCESS Convert success | |||||
| * @return FAILED Convert failed | |||||
| */ | |||||
| static Status Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief 4D shape information (dimensions)4D shape information (dimensions)4D shape information (dimensions) | |||||
| * @param [in] 4D shape information (dimensions) | |||||
| * @param [out] input Save converted shap information | |||||
| */ | |||||
| static Status InitInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &input); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert Shape of Output | |||||
| * @param [in] shape 4D shape information (dimensions) | |||||
| * @param [out] output Save converted shap information | |||||
| * @return SUCCESS Convert success | |||||
| * @return FAILED Convert failed | |||||
| */ | |||||
| static Status InitOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Convert Shape of Output | |||||
| * @param [in] shape 4D shape information (dimensions) | |||||
| * @param [out] output Save converted shap information | |||||
| * @return SUCCESS Convert success | |||||
| * @return FAILED Convert failed | |||||
| */ | |||||
| static Status InitNDTensor(const vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_DATA_OP_PARSER_H_ | |||||
| @@ -0,0 +1,155 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <sys/stat.h> | |||||
| #include <fcntl.h> | |||||
| #include "parser/common/model_saver.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "mmpa/mmpa_api.h" | |||||
| namespace { | |||||
| const int kFileOpSuccess = 0; | |||||
| } // namespace | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| const uint32_t kInteval = 2; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path, | |||||
| const Json &model) { | |||||
| Status ret = SUCCESS; | |||||
| if (file_path == nullptr || SUCCESS != CheckPath(file_path)) { | |||||
| GELOGE(FAILED, "Check output file failed."); | |||||
| return FAILED; | |||||
| } | |||||
| std::string model_str; | |||||
| try { | |||||
| model_str = model.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||||
| } catch (std::exception &e) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); | |||||
| GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); | |||||
| return FAILED; | |||||
| } catch (...) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19008"); | |||||
| GELOGE(FAILED, "Failed to convert JSON to string."); | |||||
| return FAILED; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0}; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, return FAILED, "file path is too long!"); | |||||
| if (realpath(file_path, real_path) == nullptr) { | |||||
| GELOGI("File %s does not exit, it will be created.", file_path); | |||||
| } | |||||
| // Open file | |||||
| mode_t mode = S_IRUSR | S_IWUSR; | |||||
| int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); | |||||
| if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); | |||||
| GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| const char *model_char = model_str.c_str(); | |||||
| uint32_t len = static_cast<uint32_t>(model_str.length()); | |||||
| // Write data to file | |||||
| mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); | |||||
| if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); | |||||
| // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose | |||||
| GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| // Close file | |||||
| if (mmClose(fd) != EN_OK) { | |||||
| GELOGE(FAILED, "Close file failed."); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::CheckPath(const std::string &file_path) { | |||||
| // Determine file path length | |||||
| if (file_path.size() >= PATH_MAX) { | |||||
| GELOGE(FAILED, "Path is too long:%zu", file_path.size()); | |||||
| return FAILED; | |||||
| } | |||||
| // Find the last separator | |||||
| int path_split_pos = static_cast<int>(file_path.size() - 1); | |||||
| for (; path_split_pos >= 0; path_split_pos--) { | |||||
| if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (path_split_pos == 0) { | |||||
| return SUCCESS; | |||||
| } | |||||
| // If there is a path before the file name, create the path | |||||
| if (path_split_pos != -1) { | |||||
| if (CreateDirectory(std::string(file_path).substr(0, static_cast<size_t>(path_split_pos))) != kFileOpSuccess) { | |||||
| GELOGE(FAILED, "CreateDirectory failed, file path:%s.", file_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory(const std::string &directory_path) { | |||||
| GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | |||||
| auto dir_path_len = directory_path.length(); | |||||
| if (dir_path_len >= PATH_MAX) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)}); | |||||
| GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX); | |||||
| return -1; | |||||
| } | |||||
| char tmp_dir_path[PATH_MAX] = {0}; | |||||
| for (size_t i = 0; i < dir_path_len; i++) { | |||||
| tmp_dir_path[i] = directory_path[i]; | |||||
| if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) { | |||||
| if (access(tmp_dir_path, F_OK) != 0) { | |||||
| int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | |||||
| if (ret != 0) { | |||||
| if (errno != EEXIST) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", | |||||
| directory_path.c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | |||||
| if (ret != 0) { | |||||
| if (errno != EEXIST) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_FILE_SAVER_H_ | |||||
| #define PARSER_COMMON_FILE_SAVER_H_ | |||||
| #include <string> | |||||
| #include "ge/ge_api_error_codes.h" | |||||
| #include "register/register_types.h" | |||||
| #include "nlohmann/json.hpp" | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| using Json = nlohmann::json; | |||||
| using std::string; | |||||
| class ModelSaver { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_common | |||||
| * @brief Save JSON object to file | |||||
| * @param [in] file_path File output path | |||||
| * @param [in] model json object | |||||
| * @return Status result | |||||
| */ | |||||
| static Status SaveJsonToFile(const char *file_path, const Json &model); | |||||
| private: | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Check validity of the file path | |||||
| /// @return Status result | |||||
| /// | |||||
| static Status CheckPath(const string &file_path); | |||||
| static int CreateDirectory(const std::string &directory_path); | |||||
| }; | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| #endif //PARSER_COMMON_FILE_SAVER_H_ | |||||
| @@ -0,0 +1,95 @@ | |||||
| LOCAL_PATH := $(call my-dir) | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libparser_common | |||||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| LOCAL_CFLAGS += -Werror | |||||
| ifeq ($(DEBUG), 1) | |||||
| LOCAL_CFLAGS += -g -O0 | |||||
| endif | |||||
| COMMON_LOCAL_SRC_FILES := \ | |||||
| parser_factory.cc \ | |||||
| data_op_parser.cc \ | |||||
| op_parser_factory.cc \ | |||||
| pre_checker.cc \ | |||||
| register_tbe.cc \ | |||||
| parser_api.cc \ | |||||
| parser_inner_ctx.cc \ | |||||
| proto_file_parser.cc \ | |||||
| acl_graph_parser_util.cc \ | |||||
| tbe_plugin_loader.cc \ | |||||
| model_saver.cc \ | |||||
| ../tensorflow/tensorflow_custom_parser_adapter.cc \ | |||||
| ../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \ | |||||
| ../tensorflow/tensorflow_fusion_op_parser.cc \ | |||||
| ../tensorflow/tensorflow_util.cc \ | |||||
| convert/pb2json.cc \ | |||||
| op_def/ir_pb_converter.cc \ | |||||
| op_def/defs.cc \ | |||||
| op_def/op_schema.cc \ | |||||
| op_def/operator.cc \ | |||||
| op_map.cc \ | |||||
| parser_types.cc \ | |||||
| pass_manager.cc \ | |||||
| parser_fp16_t.cc \ | |||||
| thread_pool.cc \ | |||||
| FMK_COMMON_SRC_FILES := \ | |||||
| # ../../common/fmk_error_codes.cc \ | |||||
| ../../common/auth/cipher.cc \ | |||||
| ../../common/context/ctx.cc \ | |||||
| ../../graph/passes/pass_manager.cc \ | |||||
| ../../graph/common/omg_util.cc \ | |||||
| ../../common/types.cc \ | |||||
| ../../common/auth/file_saver.cc \ | |||||
| ../../common/util.cc \ | |||||
| ../../common/model_saver.cc \ | |||||
| ../../common/fp16_t.cc \ | |||||
| ../../common/thread_pool.cc \ | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) | |||||
| LOCAL_C_INCLUDES := \ | |||||
| proto/om.proto \ | |||||
| proto/insert_op.proto \ | |||||
| proto/ge_ir.proto \ | |||||
| proto/tensorflow/graph.proto \ | |||||
| proto/tensorflow/node_def.proto \ | |||||
| proto/tensorflow/tensor_shape.proto \ | |||||
| proto/tensorflow/attr_value.proto \ | |||||
| proto/tensorflow/function.proto \ | |||||
| proto/tensorflow/op_def.proto \ | |||||
| proto/tensorflow/resource_handle.proto \ | |||||
| proto/tensorflow/tensor.proto \ | |||||
| proto/tensorflow/types.proto \ | |||||
| proto/tensorflow/versions.proto \ | |||||
| $(LOCAL_PATH) \ | |||||
| $(TOPDIR)inc \ | |||||
| $(TOPDIR)inc/external \ | |||||
| $(TOPDIR)inc/external/graph \ | |||||
| $(TOPDIR)inc/framework \ | |||||
| $(TOPDIR)inc/common/util \ | |||||
| $(TOPDIR)framework/domi \ | |||||
| $(TOPDIR)framework/domi/common \ | |||||
| $(TOPDIR)framework/domi/parser \ | |||||
| $(TOPDIR)third_party/json/include \ | |||||
| $(TOPDIR)third_party/protobuf/include \ | |||||
| libc_sec/include \ | |||||
| third_party/openssl/include/x86/include \ | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libprotobuf \ | |||||
| libslog \ | |||||
| libgraph \ | |||||
| libmmpa \ | |||||
| libc_sec \ | |||||
| liberror_manager \ | |||||
| libregister \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/op_def/arg_op.h" | |||||
| #include <string> | |||||
| #include "framework/common/fmk_types.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator::ArgOpOperator() : ParserOperator("Data") {} | |||||
| ArgOpOperator::~ArgOpOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator &ArgOpOperator::Name(const std::string &name) { | |||||
| (void)ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator &ArgOpOperator::Index(int64_t index) { | |||||
| Attr("index", static_cast<int64_t>(index)); | |||||
| return *this; | |||||
| } | |||||
| int64_t ArgOpOperator::GetIndex() const { return GetIntAttr("index"); } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_OP_ARG_OP_H_ | |||||
| #define DOMI_OP_ARG_OP_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| namespace ge { | |||||
| class ArgOpOperator : public ParserOperator { | |||||
| public: | |||||
| ArgOpOperator(); | |||||
| ~ArgOpOperator(); | |||||
| ArgOpOperator &Name(const std::string &name); | |||||
| ArgOpOperator &Index(int64_t index); | |||||
| int64_t GetIndex() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_ARG_OP_H_ | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/op_def/constant_op.h" | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator::ConstantOperator() : ParserOperator("Constant") {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator::~ConstantOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::Name(const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::VectorAttr( | |||||
| std::string key, std::vector<int64_t> &value) { | |||||
| Attr(key, value); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) { | |||||
| Attr(VAR_ATTR_DTYPE, (int64_t)t); | |||||
| return *this; | |||||
| } | |||||
| ge::DataType ConstantOperator::GetDType() const { return (ge::DataType)GetIntAttr(VAR_ATTR_DTYPE); } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #ifndef DOMI_OP_CONSTANT_OP_H_ | |||||
| #define DOMI_OP_CONSTANT_OP_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| class ConstantOperator : public ParserOperator { | |||||
| public: | |||||
| ConstantOperator(); | |||||
| ~ConstantOperator(); | |||||
| ConstantOperator &Name(const std::string &name); | |||||
| ConstantOperator &VectorAttr(std::string key, std::vector<int64_t> &value); | |||||
| ConstantOperator &DType(ge::DataType t); | |||||
| ge::DataType GetDType() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_CONSTANT_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,712 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/op_def/op_schema.h" | |||||
| namespace ge { | |||||
| DOMI_OP_SCHEMA(Data).Output("y"); | |||||
| DOMI_OP_SCHEMA(Const).Output("y"); | |||||
| DOMI_OP_SCHEMA(ConvolutionDepthwise) | |||||
| .Input("x") | |||||
| .Input("w") | |||||
| .Input("b", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("group", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("num_output", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("pad_mode", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("pad", AttributeType::INTLIST, IntTuple{0, 0, 0, 0}) | |||||
| .Attr("stride", AttributeType::INTLIST, IntTuple{1, 1}) | |||||
| .Attr("dilation", AttributeType::INTLIST, IntTuple{1, 1}) | |||||
| .Attr("kernel", AttributeType::INTLIST, IntTuple{0, 0}) | |||||
| .Attr("before_pad", AttributeType::INTLIST, IntTuple{0, 0, 0, 0}); | |||||
| DOMI_OP_SCHEMA(Region) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("casses", AttributeType::INT, static_cast<int64_t>(20)) | |||||
| .Attr("coords", AttributeType::INT, static_cast<int64_t>(4)) | |||||
| .Attr("boxes", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("background", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("softmax", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("softmax_tree", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("yolo_version", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(Gather) | |||||
| .Input("params") | |||||
| .Input("indices") | |||||
| .Input("axis", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("params_type", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("indices_type", AttributeType::INT, static_cast<int64_t>(3)) | |||||
| .Attr("validate_indices", AttributeType::BOOL, static_cast<bool>(true)); | |||||
| DOMI_OP_SCHEMA(ArgMax) | |||||
| .Input("input") | |||||
| .Output("output") | |||||
| .Attr("axis", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(true)) | |||||
| .Attr("axis_type", AttributeType::INT, static_cast<int64_t>(3)) | |||||
| .Attr("outmaxval", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("topk", AttributeType::UINT, static_cast<uint32_t>(1)); | |||||
| DOMI_OP_SCHEMA(Split) | |||||
| .Input("x") | |||||
| .Input("axis", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("T", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("num_split", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(SplitV) | |||||
| .Input("x") | |||||
| .Input("axis", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("T", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("Tlen", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("num_split", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(Fill).Input("x").Input("value").Output("y").Attr("T", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(Rsqrt).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(BiasAdd) | |||||
| .Input("x") | |||||
| .Input("bias") | |||||
| .Output("y") | |||||
| .Attr("format", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(Reverse) | |||||
| .Input("x") | |||||
| .Input("axis") | |||||
| .Output("y") | |||||
| .Attr("T", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("Tidx", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(Unpack) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("T", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("axis", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("num", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(Yolo2Reorg) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("reverse", AttributeType::BOOL, static_cast<bool>(1)) | |||||
| .Attr("stride", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(ReduceSum) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("Tidx", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(1)); | |||||
| DOMI_OP_SCHEMA(Concat) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("Tidx", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("N", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(ResizeBilinear) | |||||
| .Input("x") | |||||
| .Input("sizes") | |||||
| .Output("y") | |||||
| .Attr("output_dim_mode", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("align_corners", AttributeType::BOOL, static_cast<bool>(1)) | |||||
| .Attr("zoom_factor", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("shrink_factor", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("height", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("width", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("pad_begin", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("pad_end", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(LRN) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("lrn_normregion", AttributeType::UINT, static_cast<uint32_t>(0)) | |||||
| .Attr("lrn_k", AttributeType::FLOAT, static_cast<float>(1)) | |||||
| .Attr("lrn_localsize", AttributeType::UINT, static_cast<uint32_t>(5)) | |||||
| .Attr("lrn_alpha", AttributeType::FLOAT, static_cast<float>(1)) | |||||
| .Attr("lrn_beta", AttributeType::FLOAT, static_cast<float>(0.75)); | |||||
| DOMI_OP_SCHEMA(Maximum).Input("x").Input("w").Output("y"); | |||||
| DOMI_OP_SCHEMA(Slice) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("axis", AttributeType::INT, static_cast<int64_t>(2)) | |||||
| .AttrRequired("offsets", AttributeType::INTLIST); | |||||
| DOMI_OP_SCHEMA(Pad) | |||||
| .Input("x") | |||||
| .Input("paddings") | |||||
| .Input("constant_values", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("T", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("t_paddings", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(PadV2) | |||||
| .Input("input") | |||||
| .Output("output") | |||||
| .Attr("constant_values", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .AttrRequired("paddings", AttributeType::INTLIST); | |||||
| DOMI_OP_SCHEMA(MirrorPad) | |||||
| .Input("input") | |||||
| .Output("output") | |||||
| .AttrRequired("paddings", AttributeType::INTLIST) | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(2)); | |||||
| DOMI_OP_SCHEMA(Upsample) | |||||
| .Input("input") | |||||
| .Input("scales") | |||||
| .Output("output") | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(Cast) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("DstT", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("SrcT", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(LogicalNot).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(LogicalAnd).Input("x1").Input("x2").Output("y"); | |||||
| DOMI_OP_SCHEMA(LogicalOr).Input("x1").Input("x2").Output("y"); | |||||
| DOMI_OP_SCHEMA(Equal).Input("x1").Input("x2").Output("y").Attr("T", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(MatMul) | |||||
| .Input("a") | |||||
| .Input("b") | |||||
| .Output("product") | |||||
| .Attr("transposeX", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("transposeW", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(RNN) | |||||
| .Input("x") | |||||
| .Input("cont") | |||||
| .Input("xstatic", OpSchema::Optional) | |||||
| .Input("w") // filter | |||||
| .Input("b") // bias | |||||
| .Input("seqlen") // T | |||||
| .Input("hx") // Hx | |||||
| .Input("cx") // cx | |||||
| .Output("y") | |||||
| .Output("cyfw") | |||||
| .Output("hyfw") | |||||
| .Output("cybw") | |||||
| .Output("hybw") | |||||
| .Attr("hidden_size", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("num_layers", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("support_cont", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("support_xstatic", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("input_mode", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("direction_mode", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("input_data_layout", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("output_data_layout", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(FrameworkOp).Attr("framework_type", AttributeType::INT, static_cast<int64_t>(3)); | |||||
| DOMI_OP_SCHEMA(Multinomial) | |||||
| .Input("logits") | |||||
| .Output("output") | |||||
| .Attr("num_samples", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .AttrRequired("seed", AttributeType::INT) | |||||
| .AttrRequired("seed2", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(ReverseSequence) | |||||
| .Input("input") | |||||
| .Input("seq_lengths") | |||||
| .Output("output") | |||||
| .AttrRequired("seq_dim", AttributeType::INT) | |||||
| .AttrRequired("batch_dim", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(Interp) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("output_dim_mode", AttributeType::INT, static_cast<int64_t>(2)) | |||||
| .Attr("zoom_factor", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("shrink_factor", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("height", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("width", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("pad_begin", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("pad_end", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(ShuffleChannel).Input("x").Output("y").Attr("group", AttributeType::UINT, static_cast<uint32_t>(1)); | |||||
| DOMI_OP_SCHEMA(Conv2DBackpropFilter) | |||||
| .Input("x") | |||||
| .Input("w") | |||||
| .Input("b", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("padding", AttributeType::INT, static_cast<int64_t>(1)) | |||||
| .Attr("pads", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}) | |||||
| .Attr("strides", AttributeType::UINTLIST, UintTuple{1, 1}) | |||||
| .Attr("dilations", AttributeType::UINTLIST, UintTuple{1, 1}); | |||||
| DOMI_OP_SCHEMA(Conv2DBackpropInput) | |||||
| .Input("input_sizes") | |||||
| .Input("filter") | |||||
| .Input("out_backprop") | |||||
| .Output("output") | |||||
| .Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC")) | |||||
| .Attr("group", AttributeType::UINT, static_cast<uint32_t>(1)) | |||||
| .Attr("padding", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("dilations", AttributeType::UINTLIST, UintTuple{1, 1}) | |||||
| .Attr("strides", AttributeType::UINTLIST, UintTuple{1, 1}) | |||||
| .Attr("pad", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}); | |||||
| DOMI_OP_SCHEMA(BiasAddGrad).Input("dy").Output("db").Attr("format", AttributeType::INT, static_cast<int64_t>(1)); | |||||
| DOMI_OP_SCHEMA(ReluGrad).Input("dy").Input("x").Output("dx"); | |||||
| DOMI_OP_SCHEMA(MeanGrad).Input("dy").Output("dx"); | |||||
| DOMI_OP_SCHEMA(NonMaxSuppression) | |||||
| .Input("boxes") | |||||
| .Input("scores") | |||||
| .Output("selected_indices") | |||||
| .Attr("max_output_size", AttributeType::INT, static_cast<int64_t>(-1)) | |||||
| .Attr("iou_threshold", AttributeType::FLOAT, static_cast<float>(0.5)) | |||||
| .Attr("score_threshold", AttributeType::FLOAT, static_cast<float>(-1)); | |||||
| DOMI_OP_SCHEMA(CropAndResize) | |||||
| .Input("image") | |||||
| .Input("boxes") | |||||
| .Input("box_ind") | |||||
| .Output("crops") | |||||
| .Attr("method", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("extrapolation_value", AttributeType::FLOAT, static_cast<float>(0)) | |||||
| .Attr("crop_size_h", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("crop_size_w", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(TopKV2) | |||||
| .Input("input") | |||||
| .Input("k") | |||||
| .Output("value") | |||||
| .Output("indices") | |||||
| .AttrRequired("sorted", AttributeType::BOOL); | |||||
| DOMI_OP_SCHEMA(InvertPermutation).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(GatherV2) | |||||
| .Input("params") | |||||
| .Input("indices") | |||||
| .Input("axis", OpSchema::Optional) | |||||
| .Output("y") | |||||
| .Attr("Tparams", AttributeType::INT, static_cast<int64_t>(0)) // default: DT_FLOAT | |||||
| .Attr("Tindices", AttributeType::INT, static_cast<int64_t>(3)) // default: DT_INT32 | |||||
| .Attr("Taxis", AttributeType::INT, static_cast<int64_t>(3)); // default: DT_INT32 | |||||
| DOMI_OP_SCHEMA(HighWay) | |||||
| .Input("x") | |||||
| .Input("tw") // filter | |||||
| .Input("tb") // bias | |||||
| .Input("uw") // filter | |||||
| .Input("ub") // bias | |||||
| .Output("y"); | |||||
| DOMI_OP_SCHEMA(Reciprocal).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Asinh).Input("input").Output("output"); | |||||
| DOMI_OP_SCHEMA(Acosh).Input("input").Output("output"); | |||||
| DOMI_OP_SCHEMA(Minimum).Input("x").Input("y").Output("output"); | |||||
| DOMI_OP_SCHEMA(Clip).Input("input").Input("min").Input("max").Output("output"); | |||||
| DOMI_OP_SCHEMA(FusedBatchNorm) | |||||
| .Input("x") | |||||
| .Input("scale") | |||||
| .Input("offset") | |||||
| .Input("mean") | |||||
| .Input("variance") | |||||
| .Output("y") | |||||
| .Output("batch_mean") | |||||
| .Output("batch_variance") | |||||
| .Output("reserve_space_1") | |||||
| .Output("reserve_space_2") | |||||
| .Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC")) | |||||
| .Attr("epsilon", AttributeType::FLOAT, static_cast<float>(0.0001)) | |||||
| .Attr("is_training", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(FusedBatchNormGrad) | |||||
| .Input("dy") | |||||
| .Input("x") | |||||
| .Input("bnscale") | |||||
| .Input("save_mean") | |||||
| .Input("save_variance") | |||||
| .Output("dx") | |||||
| .Output("result_bn_scale_diff") | |||||
| .Output("result_bn_bias_diff") | |||||
| .Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC")) | |||||
| .Attr("epsilon", AttributeType::FLOAT, static_cast<float>(0.0)) | |||||
| .Attr("is_training", AttributeType::BOOL, static_cast<bool>(true)); | |||||
| DOMI_OP_SCHEMA(MaxPoolWithArgmax) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Output("argmax") | |||||
| .AttrRequired("window", AttributeType::INTLIST) | |||||
| .AttrRequired("stride", AttributeType::INTLIST) | |||||
| .AttrRequired("pad_mode", AttributeType::INT) | |||||
| .AttrRequired("ceil_mode", AttributeType::BOOL) | |||||
| .AttrRequired("data_mode", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(MaxPoolGradWithArgmax) | |||||
| .Input("input") | |||||
| .Input("grad") | |||||
| .Output("output") | |||||
| .AttrRequired("window", AttributeType::INTLIST) | |||||
| .AttrRequired("stride", AttributeType::INTLIST) | |||||
| .AttrRequired("pad_mode", AttributeType::INT) | |||||
| .AttrRequired("ceil_mode", AttributeType::BOOL) | |||||
| .AttrRequired("data_mode", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(HcomBroadcast) | |||||
| .AttrRequired("root_rank", AttributeType::INT) | |||||
| .AttrRequired("group", AttributeType::STRING); | |||||
| DOMI_OP_SCHEMA(HcomAllReduce) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .AttrRequired("reduction", AttributeType::STRING) | |||||
| .AttrRequired("group", AttributeType::STRING); | |||||
| DOMI_OP_SCHEMA(HcomAllGather) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .AttrRequired("rank_size", AttributeType::INT) | |||||
| .AttrRequired("group", AttributeType::STRING); | |||||
| DOMI_OP_SCHEMA(SparseSoftmaxCrossEntropyWithLogits) | |||||
| .Input("features") | |||||
| .Input("labels") | |||||
| .Output("loss") | |||||
| .Output("backprop") | |||||
| .AttrRequired("T", AttributeType::INT) | |||||
| .Attr("Tlabels", AttributeType::INT, static_cast<int64_t>(9)); | |||||
| DOMI_OP_SCHEMA(Snapshot).Input("input").Output("output").AttrRequired("T", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(ReduceProd) | |||||
| .Input("bottom") | |||||
| .Output("top") | |||||
| .AttrRequired("axes", AttributeType::INTLIST) | |||||
| .Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(ReduceAll) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .AttrRequired("axes", AttributeType::INTLIST) | |||||
| .Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(ReduceMax) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .AttrRequired("axis", AttributeType::INTLIST) | |||||
| .Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(AddN).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(ShapeN) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .AttrRequired("N", AttributeType::INT) | |||||
| .AttrRequired("in_type", AttributeType::INT) | |||||
| .AttrRequired("dtype", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(ReduceMin) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .AttrRequired("axis", AttributeType::INTLIST) | |||||
| .Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(Sqrt).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(L2Loss).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Multiply).Input("x").Input("y").Output("z"); | |||||
| DOMI_OP_SCHEMA(Add).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Constant).Output("y"); | |||||
| DOMI_OP_SCHEMA(ApplyMomentum) | |||||
| .Input("variable") | |||||
| .Input("accumulation") | |||||
| .Input("learningRate") | |||||
| .Input("gradient") | |||||
| .Input("momuntum") | |||||
| .Input("fp16variable") | |||||
| .Attr("algo", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(AvgPoolGrad) | |||||
| .Input("shape") | |||||
| .Input("grad") | |||||
| .Output("output") | |||||
| .Attr("padding", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC")) | |||||
| .Attr("strides", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}) | |||||
| .Attr("ksize", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}); | |||||
| DOMI_OP_SCHEMA(Lars) | |||||
| .Input("w") | |||||
| .Input("g") | |||||
| .Input("weight_decay") | |||||
| .Output("y") | |||||
| .Attr("hyperpara", AttributeType::FLOAT, static_cast<float>(0.001)) | |||||
| .Attr("epsilon", AttributeType::FLOAT, static_cast<float>(0.00001)); | |||||
| DOMI_OP_SCHEMA(AssignSub) | |||||
| .Input("variable") | |||||
| .Input("input") | |||||
| .Input("output") | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(AssignAdd) | |||||
| .Input("variable") | |||||
| .Input("input") | |||||
| .Output("output") | |||||
| .Attr("mode", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(SpaceToBatchND).Input("input").Input("block_shape").Input("paddings").Output("output"); | |||||
| DOMI_OP_SCHEMA(Variable) | |||||
| .Output("variable") | |||||
| .Attr("container", AttributeType::STRING, static_cast<std::string>("")) | |||||
| .Attr("shared_name", AttributeType::STRING, static_cast<std::string>("")) | |||||
| .AttrRequired("dtype", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(Assign).Input("variable").Input("value").Output("variable"); | |||||
| DOMI_OP_SCHEMA(VarIsInitializedOp).Input("variable").Output("value"); | |||||
| DOMI_OP_SCHEMA(NoOp).Attr("algo", AttributeType::INT, static_cast<int64_t>(0)); | |||||
| DOMI_OP_SCHEMA(LogTimeStamp) | |||||
| .Attr("logid", AttributeType::STRING, static_cast<std::string>("")) | |||||
| .Attr("notify", AttributeType::BOOL, static_cast<bool>(false)); | |||||
| DOMI_OP_SCHEMA(ResizeNearestNeighbor) | |||||
| .Input("images") | |||||
| .Output("resized_images") | |||||
| .Attr("align_corners", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .AttrRequired("height", AttributeType::INT) | |||||
| .AttrRequired("width", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(BatchToSpaceND).Input("input").Input("block_shape").Input("crops").Output("output"); | |||||
| DOMI_OP_SCHEMA(Assert).Input("x").Input("w").Output("y"); | |||||
| DOMI_OP_SCHEMA(Pow).Input("x").Input("y").Output("z"); | |||||
| DOMI_OP_SCHEMA(GreaterEqual).Input("x1").Input("x2").Output("y"); | |||||
| DOMI_OP_SCHEMA(SpaceToDepth) | |||||
| .Input("input") | |||||
| .Output("output") | |||||
| .Attr("block_size", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .AttrRequired("T", AttributeType::INT) | |||||
| .Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC")); | |||||
| DOMI_OP_SCHEMA(DepthToSpace) | |||||
| .Input("input") | |||||
| .Output("output") | |||||
| .Attr("block_size", AttributeType::INT, static_cast<int64_t>(0)) | |||||
| .AttrRequired("T", AttributeType::INT) | |||||
| .Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC")); | |||||
| DOMI_OP_SCHEMA(Rint).Input("input").Output("output").AttrRequired("T", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(ExtractImagePatches) | |||||
| .Input("images") | |||||
| .Output("y") | |||||
| .AttrRequired("ksizes", AttributeType::INTLIST) | |||||
| .AttrRequired("strides", AttributeType::INTLIST) | |||||
| .AttrRequired("rates", AttributeType::INTLIST) | |||||
| .AttrRequired("padding", AttributeType::STRING); | |||||
| DOMI_OP_SCHEMA(Atan).Input("x").Output("output"); | |||||
| DOMI_OP_SCHEMA(Atanh).Input("x").Output("output"); | |||||
| DOMI_OP_SCHEMA(Acos).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Asin).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Log) | |||||
| .Input("x") | |||||
| .Output("output") | |||||
| .AttrRequired("scale", AttributeType::INT) | |||||
| .AttrRequired("shift", AttributeType::INT) | |||||
| .AttrRequired("base", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(Neg).Input("input").Output("output"); | |||||
| DOMI_OP_SCHEMA(Tan).Input("x").Output("output"); | |||||
| DOMI_OP_SCHEMA(Round).Input("x").Output("output"); | |||||
| DOMI_OP_SCHEMA(Exp) | |||||
| .Input("x") | |||||
| .Output("y") | |||||
| .Attr("scale", AttributeType::FLOAT, static_cast<float>(1)) | |||||
| .Attr("shift", AttributeType::FLOAT, static_cast<float>(0)) | |||||
| .Attr("base", AttributeType::FLOAT, static_cast<float>(-1)); | |||||
| DOMI_OP_SCHEMA(Less).Input("x").Input("y").Output("output"); | |||||
| DOMI_OP_SCHEMA(LessEqual).Input("x").Input("y").Output("output"); | |||||
| DOMI_OP_SCHEMA(OneHot).Input("indices").Input("depth").Input("on_value").Input("off_value").Output("output"); | |||||
| DOMI_OP_SCHEMA(ZerosLike).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Where).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(RefSwitch).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(FakeQuantWithMinMaxVars) | |||||
| .Input("x") | |||||
| .Input("min") | |||||
| .Input("max") | |||||
| .Output("y") | |||||
| .Attr("narrow_range", AttributeType::BOOL, static_cast<bool>(false)) | |||||
| .Attr("num_bits", AttributeType::INT, static_cast<int64_t>(8)); | |||||
| DOMI_OP_SCHEMA(Sinh).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Cosh).Input("x").Output("y"); | |||||
| DOMI_OP_SCHEMA(Floor).Input("x").Output("output"); | |||||
| DOMI_OP_SCHEMA(RandomUniform).Input("input").Output("output"); | |||||
| DOMI_OP_SCHEMA(BatchMatMul).Input("x").Input("y").Output("output"); | |||||
| DOMI_OP_SCHEMA(FloorMod).Input("x").Input("y").Output("output"); | |||||
| DOMI_OP_SCHEMA(SquaredDifference).Input("x").Input("y").Output("output"); | |||||
| DOMI_OP_SCHEMA(LayerNorm).Input("x").Output("output").AttrRequired("Epsilon", AttributeType::FLOAT); | |||||
| DOMI_OP_SCHEMA(SSDPostProcessor) | |||||
| .Input("trueImgShape") | |||||
| .Input("boxEncoding") | |||||
| .Input("anchors") | |||||
| .Input("clsPred") | |||||
| .Output("detectBoxes") | |||||
| .Output("detectScores") | |||||
| .Output("detectNum") | |||||
| .Output("detectClasses") | |||||
| .AttrRequired("numClasses", AttributeType::INT) | |||||
| .AttrRequired("scoreThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("iouThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("maxDetectionsPerClass", AttributeType::INT) | |||||
| .AttrRequired("maxTotalDetections", AttributeType::INT) | |||||
| .AttrRequired("boxTypeNum", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_0", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_1", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_2", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_3", AttributeType::UINT) | |||||
| .AttrRequired("imgH", AttributeType::INT) | |||||
| .AttrRequired("imgW", AttributeType::INT) | |||||
| .AttrRequired("useStaticShape", AttributeType::BOOL) | |||||
| .AttrRequired("convertScoresMode", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(RetinaPostProcessor) | |||||
| .Input("anchors") | |||||
| .Input("regression") | |||||
| .Input("classification") | |||||
| .Output("detectBoxes") | |||||
| .Output("detectScores") | |||||
| .Output("detectLabels") | |||||
| .Output("detectNum") | |||||
| .AttrRequired("numClasses", AttributeType::INT) | |||||
| .AttrRequired("maxDetections", AttributeType::INT) | |||||
| .AttrRequired("nmsThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("scoreThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("imgH", AttributeType::INT) | |||||
| .AttrRequired("imgW", AttributeType::INT) | |||||
| .AttrRequired("boxTypeNum", AttributeType::UINT) | |||||
| .AttrRequired("means", AttributeType::FLOATLIST) | |||||
| .AttrRequired("stds", AttributeType::FLOATLIST); | |||||
| DOMI_OP_SCHEMA(ROIInterPooling) | |||||
| .Input("input") | |||||
| .Input("input_1") | |||||
| .Output("maxPool") | |||||
| .AttrRequired("hStride", AttributeType::INT) | |||||
| .AttrRequired("wStride", AttributeType::INT) | |||||
| .AttrRequired("hKernel", AttributeType::INT) | |||||
| .AttrRequired("wKernel", AttributeType::INT) | |||||
| .AttrRequired("hResize", AttributeType::INT) | |||||
| .AttrRequired("wResize", AttributeType::INT) | |||||
| .AttrRequired("hFeatureMap", AttributeType::INT) | |||||
| .AttrRequired("wFeatureMap", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(FirstStageProcessor) | |||||
| .Input("anchors") | |||||
| .Input("boxEncoding") | |||||
| .Input("clsPred") | |||||
| .Input("trueImgShape") | |||||
| .Output("detectBoxes") | |||||
| .Output("detectScores") | |||||
| .Output("detectLables") | |||||
| .Output("detectNum") | |||||
| .AttrRequired("scaleFactorsNum", AttributeType::INT) | |||||
| .AttrRequired("iouThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("scoreThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("maxSizePerClass", AttributeType::INT) | |||||
| .AttrRequired("maxTotalSize", AttributeType::INT) | |||||
| .AttrRequired("imgH", AttributeType::INT) | |||||
| .AttrRequired("imgW", AttributeType::INT) | |||||
| .AttrRequired("boxTypeNum", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_0", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_1", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_2", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_3", AttributeType::UINT); | |||||
| DOMI_OP_SCHEMA(SecondStageProcessor) | |||||
| .Input("anchors") | |||||
| .Input("boxEncoding") | |||||
| .Input("clsPred") | |||||
| .Input("validBoxNum") | |||||
| .Input("trueImgShape") | |||||
| .Output("detectBoxes") | |||||
| .Output("detectScores") | |||||
| .Output("detectLables") | |||||
| .Output("detectNum") | |||||
| .AttrRequired("scaleFactorsNum", AttributeType::INT) | |||||
| .AttrRequired("iouThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("scoreThreshold", AttributeType::FLOAT) | |||||
| .AttrRequired("maxSizePerClass", AttributeType::INT) | |||||
| .AttrRequired("maxTotalSize", AttributeType::INT) | |||||
| .AttrRequired("numClasses", AttributeType::INT) | |||||
| .AttrRequired("scaleFactors_0", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_1", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_2", AttributeType::UINT) | |||||
| .AttrRequired("scaleFactors_3", AttributeType::UINT); | |||||
| DOMI_OP_SCHEMA(StreamSwitch) | |||||
| .Input("loopIndex") | |||||
| .Input("itersPerLoop") | |||||
| .AttrRequired("switch_condition", AttributeType::UINT) | |||||
| .AttrRequired("true_branch_stream", AttributeType::INT); | |||||
| DOMI_OP_SCHEMA(StreamActive).AttrRequired("active_stream_list", AttributeType::INTLIST); | |||||
| DOMI_OP_SCHEMA(MemcpyAsync).Input("in").Output("out"); | |||||
| DOMI_OP_SCHEMA(CleanAddr) | |||||
| .AttrRequired("automic_add_addr_start", AttributeType::INT) | |||||
| .AttrRequired("automic_add_mem_size", AttributeType::INT); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/op_def/fill_op.h" | |||||
| #include "framework/common/fmk_types.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator::FillOperator() : ParserOperator("Fill") {} | |||||
| FMK_FUNC_DEV_VISIBILITY FillOperator::~FillOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::DataType(int64_t dataType) { | |||||
| Attr("T", static_cast<int64_t>(dataType)); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::Alpha(float alpha) { | |||||
| Attr("alpha", static_cast<float>(alpha)); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::Beta(float beta) { | |||||
| Attr("beta", static_cast<float>(beta)); | |||||
| return *this; | |||||
| } | |||||
| int64_t FillOperator::GetDataType() const { return GetIntAttr("T"); } | |||||
| float FillOperator::GetAlpha() const { return GetFloatAttr("alpha"); } | |||||
| float FillOperator::GetBeta() const { return GetFloatAttr("beta"); } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,42 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_OP_FILL_OP_H_ | |||||
| #define DOMI_OP_FILL_OP_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| namespace ge { | |||||
| class FillOperator : public ParserOperator { | |||||
| public: | |||||
| FillOperator(); | |||||
| ~FillOperator(); | |||||
| FillOperator &DataType(int64_t dataType); | |||||
| FillOperator &Alpha(float alpha); | |||||
| FillOperator &Beta(float beta); | |||||
| int64_t GetDataType() const; | |||||
| float GetAlpha() const; | |||||
| float GetBeta() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_FILL_OP_H_ | |||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/op_def/frameworkop_op.h" | |||||
| #include <string> | |||||
| #include "framework/common/fmk_types.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator::FrameworkOpOperator() | |||||
| : ParserOperator("FrameworkOp") {} | |||||
| FrameworkOpOperator::~FrameworkOpOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Name( | |||||
| const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Index(int64_t index) { | |||||
| Attr(RETVAL_ATTR_NAME_INDEX, static_cast<int64_t>(index)); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::NodeDefPkg( | |||||
| const std::string &nodedef_pkg) { | |||||
| Attr_bt(ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_pkg); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Frameworktype( | |||||
| int64_t framework_type) { | |||||
| Attr(ATTR_NAME_FRAMEWORK_FWK_TYPE, static_cast<int64_t>(framework_type)); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::TfOpDef( | |||||
| const std::string &opdef_string) { | |||||
| Attr(ATTR_NAME_FRAMEWORK_OP_DEF, opdef_string); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::OriginalType( | |||||
| const std::string &type) { | |||||
| Attr(ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::FuncDefPkg(const std::string &func_string) { | |||||
| Attr_bt(ATTR_NAME_FRAMEWORK_FUNC_DEF, func_string); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY int64_t FrameworkOpOperator::GetFrameworkType() const { | |||||
| return GetIntAttr(ATTR_NAME_FRAMEWORK_FWK_TYPE); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY std::string FrameworkOpOperator::GetNodeDefPkg() const { | |||||
| return GetStringAttr(ATTR_NAME_FRAMEWORK_NODE_DEF); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_ | |||||
| #define DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_ | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "parser/common/op_def/operator.h" | |||||
| namespace ge { | |||||
| class FrameworkOpOperator : public ParserOperator { | |||||
| public: | |||||
| FrameworkOpOperator(); | |||||
| ~FrameworkOpOperator(); | |||||
| FrameworkOpOperator &Name(const std::string &name); | |||||
| FrameworkOpOperator &OriginalType(const std::string &type); | |||||
| FrameworkOpOperator &NodeDefPkg(const std::string &nodedef_pkg); | |||||
| FrameworkOpOperator &Frameworktype(int64_t framework_type); | |||||
| FrameworkOpOperator &TfOpDef(const std::string &opdef_string); | |||||
| FrameworkOpOperator &Index(int64_t index); | |||||
| FrameworkOpOperator &FuncDefPkg(const std::string &func_string); | |||||
| int64_t GetFrameworkType() const; | |||||
| std::string GetNodeDefPkg() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_ | |||||
| @@ -0,0 +1,205 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/op_def/ir_pb_converter.h" | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "google/protobuf/map.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/buffer.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/util.h" | |||||
| namespace ge { | |||||
| static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair, ge::OpDescPtr op_def) { | |||||
| domi::AttrDef_ListValue a_list = op_attr_pair.second.value_.list(); | |||||
| vector<int64_t> v_i; | |||||
| for (int32_t i = 0; i < a_list.i_size(); i++) { | |||||
| v_i.push_back((int64_t)a_list.i(i)); | |||||
| } | |||||
| if (v_i.size() > 0) { | |||||
| (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); | |||||
| return; | |||||
| } | |||||
| vector<float> v_f; | |||||
| for (int32_t i = 0; i < a_list.f_size(); i++) { | |||||
| v_f.push_back(a_list.f(i)); | |||||
| } | |||||
| if (v_f.size() > 0) { | |||||
| (void)ge::AttrUtils::SetListFloat(op_def, op_attr_pair.first, v_f); | |||||
| return; | |||||
| } | |||||
| vector<bool> v_b; | |||||
| for (int32_t i = 0; i < a_list.b_size(); i++) { | |||||
| v_b.push_back(a_list.b(i)); | |||||
| } | |||||
| if (v_b.size() > 0) { | |||||
| (void)ge::AttrUtils::SetListBool(op_def, op_attr_pair.first, v_b); | |||||
| return; | |||||
| } | |||||
| vector<int32_t> v_u; | |||||
| for (int32_t i = 0; i < a_list.u_size(); i++) { | |||||
| v_u.push_back((int32_t)a_list.u(i)); | |||||
| } | |||||
| if (v_u.size() > 0) { | |||||
| (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u); | |||||
| return; | |||||
| } | |||||
| // set for empty list | |||||
| (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); | |||||
| GELOGI("set empty list for node %s attr %s", op_def->GetName().c_str(), op_attr_pair.first.c_str()); | |||||
| } | |||||
| static void UpdateTensorForOpDesc(const ParserOperator &op, ge::OpDescPtr op_def) { | |||||
| if (op_def == nullptr) { | |||||
| return; | |||||
| } | |||||
| uint32_t in_index = 0; | |||||
| for (const ge::GeTensorDesc &input_desc : op.GetInputTensorDesc()) { | |||||
| if (in_index < op_def->GetInputsSize()) { | |||||
| (void)op_def->UpdateInputDesc(in_index++, input_desc); | |||||
| } else { | |||||
| (void)op_def->AddInputDesc(input_desc); | |||||
| in_index++; | |||||
| } | |||||
| } | |||||
| uint32_t out_index = 0; | |||||
| for (const ge::GeTensorDesc &output_desc : op.GetOutputTensorDesc()) { | |||||
| if (out_index < op_def->GetOutputsSize()) { | |||||
| op_def->UpdateOutputDesc(out_index++, output_desc); | |||||
| } else { | |||||
| op_def->AddOutputDesc(output_desc); | |||||
| out_index++; | |||||
| } | |||||
| } | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(const ParserOperator &op, | |||||
| ge::OpDescPtr op_def) { | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(op_def == nullptr, "parameter is null."); | |||||
| GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str()); | |||||
| op_def->SetName(op.GetName()); | |||||
| op_def->SetType(op.GetType()); | |||||
| GE_IF_BOOL_EXEC(op.GetType() == ge::parser::YOLO, op_def->SetType(ge::parser::REGION)); | |||||
| UpdateTensorForOpDesc(op, op_def); | |||||
| GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(), | |||||
| op_def->GetInputsSize(), op_def->GetOutputsSize()); | |||||
| for (const auto &op_attr_pair : op.GetOpAttrs()) { | |||||
| if (op_attr_pair.second.value_.has_list()) { | |||||
| ConvertList(op_attr_pair, op_def); | |||||
| } else { | |||||
| if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) { | |||||
| auto &buffer = op_attr_pair.second.value_.bt(); | |||||
| (void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first, | |||||
| ge::Buffer::CopyFrom(reinterpret_cast<uint8_t *>(const_cast<char *>(buffer.data())), buffer.size())); | |||||
| } | |||||
| if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) { | |||||
| (void)ge::AttrUtils::SetStr(op_def, op_attr_pair.first, op_attr_pair.second.value_.s()); | |||||
| } | |||||
| if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kI) { | |||||
| (void)ge::AttrUtils::SetInt(op_def, op_attr_pair.first, op_attr_pair.second.value_.i()); | |||||
| } | |||||
| if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kF) { | |||||
| (void)ge::AttrUtils::SetFloat(op_def, op_attr_pair.first, op_attr_pair.second.value_.f()); | |||||
| } | |||||
| if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kB) { | |||||
| (void)ge::AttrUtils::SetBool(op_def, op_attr_pair.first, op_attr_pair.second.value_.b()); | |||||
| } | |||||
| if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kU) { | |||||
| (void)ge::AttrUtils::SetInt(op_def, op_attr_pair.first, op_attr_pair.second.value_.u()); | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_CHK_BOOL_RET_STATUS(op.GetSchema()->Verify(op_def), domi::PARAM_INVALID, "Op schema verify failed, op name: %s", | |||||
| op.GetName().c_str()); | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, | |||||
| ParserOperator &op) { | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(op_def == nullptr, "parameter is null."); | |||||
| op.Name(op_def->GetName()); | |||||
| map<string, ge::GeAttrValue> allattrs = op_def->GetAllAttrs(); | |||||
| for (const auto &attr : allattrs) { | |||||
| ge::GeAttrValue::ValueType v_t = attr.second.GetValueType(); | |||||
| switch (v_t) { | |||||
| case ge::GeAttrValue::ValueType::VT_LIST_STRING: { | |||||
| std::vector<string> vec; | |||||
| (void)ge::AttrUtils::GetListStr(op_def, attr.first, vec); | |||||
| op.Attr(attr.first, vec); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_LIST_FLOAT: { | |||||
| std::vector<float> vec; | |||||
| (void)ge::AttrUtils::GetListFloat(op_def, attr.first, vec); | |||||
| op.Attr(attr.first, vec); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_LIST_BOOL: { | |||||
| std::vector<bool> vec; | |||||
| (void)ge::AttrUtils::GetListBool(op_def, attr.first, vec); | |||||
| op.Attr(attr.first, vec); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_LIST_INT: { | |||||
| std::vector<int64_t> vec; | |||||
| (void)ge::AttrUtils::GetListInt(op_def, attr.first, vec); | |||||
| op.Attr(attr.first, vec); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_STRING: { | |||||
| string s = ""; | |||||
| (void)ge::AttrUtils::GetStr(op_def, attr.first, s); | |||||
| op.Attr(attr.first, s); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_FLOAT: { | |||||
| float f = 0.0; | |||||
| (void)ge::AttrUtils::GetFloat(op_def, attr.first, f); | |||||
| op.Attr(attr.first, f); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_BOOL: { | |||||
| bool b = false; | |||||
| (void)ge::AttrUtils::GetBool(op_def, attr.first, b); | |||||
| op.Attr(attr.first, b); | |||||
| break; | |||||
| } | |||||
| case ge::GeAttrValue::ValueType::VT_INT: { | |||||
| int64_t i = 0; | |||||
| (void)ge::AttrUtils::GetInt(op_def, attr.first, i); | |||||
| op.Attr(attr.first, i); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H | |||||
| #define DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H | |||||
| #include "framework/common/fmk_error_codes.h" | |||||
| #include "common/op_def/op_schema.h" | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "proto/om.pb.h" | |||||
| namespace ge { | |||||
| domi::Status ConvertToOpDesc(const ParserOperator &op, ge::OpDescPtr op_def); | |||||
| domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, ParserOperator &op); | |||||
| } // namespace ge | |||||
| #endif // DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H | |||||
| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #include "common/op_def/no_op_op.h" | |||||
| #include <string> | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY NoOpOperator::NoOpOperator() : ParserOperator("NoOp") {} | |||||
| FMK_FUNC_HOST_VISIBILITY NoOpOperator::~NoOpOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY NoOpOperator &NoOpOperator::Name(const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #ifndef DOMI_OP_NO_OP_OP_H_ | |||||
| #define DOMI_OP_NO_OP_OP_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| class NoOpOperator : public ParserOperator { | |||||
| public: | |||||
| NoOpOperator(); | |||||
| ~NoOpOperator(); | |||||
| NoOpOperator &Name(const std::string &name); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_NO_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,215 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/op_def/op_schema.h" | |||||
| #include <iostream> | |||||
| #include <utility> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| namespace ge { | |||||
| OpSchema::FormalParameter::FormalParameter(const std::string &name, FormalParameterOption param_option) | |||||
| : name_(name), param_option_(param_option) {} | |||||
| OpSchema::FormalParameter::~FormalParameter() {} | |||||
| const std::string &OpSchema::FormalParameter::Name() const { return name_; } | |||||
| OpSchema::FormalParameterOption OpSchema::FormalParameter::Option() const { return param_option_; } | |||||
| OpSchema::OpSchema(const std::string &name) : name_(name) {} | |||||
| OpSchema::~OpSchema() {} | |||||
| OpSchema &OpSchema::Input(const std::string &name, FormalParameterOption param_option) { | |||||
| inputs_.emplace_back(FormalParameter(name, param_option)); | |||||
| return *this; | |||||
| } | |||||
| OpSchema &OpSchema::Output(const std::string &name, FormalParameterOption param_option) { | |||||
| outputs_.emplace_back(FormalParameter(name, param_option)); | |||||
| return *this; | |||||
| } | |||||
| OpSchema &OpSchema::Attr(const Attribute &attr) { | |||||
| (void)attributes_.insert(std::make_pair(attr.name_, attr)); | |||||
| return *this; | |||||
| } | |||||
| #if defined(CFG_BUILD_DEBUG) | |||||
| #define ATTR_SETTER_WITH_SINGLE_VALUE(Type, field, attrtype) \ | |||||
| OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Type &default_value) { \ | |||||
| if (attrtype != attr_type) { \ | |||||
| GELOGE(FAILED, "Attribute specification param_type mismatch, input attr type %u, required attr type %u.", \ | |||||
| (uint32_t)attr_type, (uint32_t)attrtype); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| \ | |||||
| domi::AttrDef a; \ | |||||
| a.set_##field(default_value); \ | |||||
| Attr(Attribute(name, attr_type, a)); \ | |||||
| return *this; \ | |||||
| } | |||||
| #else | |||||
| #define ATTR_SETTER_WITH_SINGLE_VALUE(Type, field, attrtype) \ | |||||
| OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Type &default_value) { \ | |||||
| if (attrtype != attr_type) { \ | |||||
| return *this; \ | |||||
| } \ | |||||
| domi::AttrDef a; \ | |||||
| a.set_##field(default_value); \ | |||||
| Attr(Attribute(name, attr_type, a)); \ | |||||
| return *this; \ | |||||
| } | |||||
| #endif | |||||
| #if defined(CFG_BUILD_DEBUG) | |||||
| #define ATTR_SETTER_WITH_LIST_VALUE(Type, field, attrtype) \ | |||||
| OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const std::vector<Type> &default_value) { \ | |||||
| if (attrtype != attr_type) { \ | |||||
| GELOGE(FAILED, "Attribute specification vector param_type mismatch, input attr type %u, required attr type %u.", \ | |||||
| (uint32_t)attr_type, (uint32_t)attrtype); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| domi::AttrDef vec_a; \ | |||||
| for (const auto &v : default_value) { \ | |||||
| vec_a.mutable_list()->add_##field(v); \ | |||||
| } \ | |||||
| Attr(Attribute(name, attr_type, vec_a)); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Tuple<Type> &default_value) { \ | |||||
| if (attrtype != attr_type) { \ | |||||
| GELOGE(FAILED, "Attribute specification vector param_type mismatch, input attr type %u, required attr type %u.", \ | |||||
| (uint32_t)attr_type, (uint32_t)attrtype); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| domi::AttrDef tuple_a; \ | |||||
| for (const auto &v : default_value) { \ | |||||
| tuple_a.mutable_list()->add_##field(v); \ | |||||
| } \ | |||||
| Attr(Attribute(name, attr_type, tuple_a)); \ | |||||
| return *this; \ | |||||
| } | |||||
| #else | |||||
| #define ATTR_SETTER_WITH_LIST_VALUE(Type, field, attrtype) \ | |||||
| OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const std::vector<Type> &default_value) { \ | |||||
| if (attrtype != attr_type) { \ | |||||
| return *this; \ | |||||
| } \ | |||||
| domi::AttrDef vec_a; \ | |||||
| for (const auto &v : default_value) { \ | |||||
| vec_a.mutable_list()->add_##field(v); \ | |||||
| } \ | |||||
| Attr(Attribute(name, attr_type, vec_a)); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Tuple<Type> &default_value) { \ | |||||
| if (attrtype != attr_type) { \ | |||||
| return *this; \ | |||||
| } \ | |||||
| domi::AttrDef tuple_a; \ | |||||
| for (const auto &v : default_value) { \ | |||||
| tuple_a.mutable_list()->add_##field(v); \ | |||||
| } \ | |||||
| Attr(Attribute(name, attr_type, tuple_a)); \ | |||||
| return *this; \ | |||||
| } | |||||
| #endif | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(uint32_t, u, AttributeType::UINT) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeType::INT) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(bool, b, AttributeType::BOOL) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeType::FLOAT) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeType::STRING) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(uint32_t, u, AttributeType::UINTLIST) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(int64_t, i, AttributeType::INTLIST) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(bool, b, AttributeType::BOOLLIST) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(float, f, AttributeType::FLOATLIST) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(std::string, s, AttributeType::STRINGLIST) | |||||
| OpSchema &OpSchema::AttrRequired(const std::string &name, AttributeType attr_type) { | |||||
| Attr(Attribute(name, attr_type, true)); | |||||
| return *this; | |||||
| } | |||||
| bool OpSchema::HasDefaultAttr(const std::string &name) const { | |||||
| auto it = attributes_.find(name); | |||||
| if (it == attributes_.end()) { | |||||
| return false; | |||||
| } | |||||
| // required does not need a default value | |||||
| return !it->second.required_; | |||||
| } | |||||
| const domi::AttrDef &OpSchema::GetDefaultAttr(const std::string &name) const { | |||||
| auto it = attributes_.find(name); | |||||
| if (it == attributes_.end()) { | |||||
| const static domi::AttrDef attr_def; | |||||
| return attr_def; | |||||
| } | |||||
| return it->second.default_value_; | |||||
| } | |||||
| bool OpSchema::Verify(const ge::OpDescPtr op_def) const { | |||||
| if (op_def->GetType() != name_) { | |||||
| GELOGE(FAILED, "Name not math, op schema name: %s, opdef type: %s.", name_.c_str(), op_def->GetType().c_str()); | |||||
| return false; | |||||
| } | |||||
| // Required field verification | |||||
| for (const auto &pair : attributes_) { | |||||
| const auto &attr = pair.second; | |||||
| if (!attr.required_) { | |||||
| continue; | |||||
| } | |||||
| if (!op_def->HasAttr(attr.name_)) { | |||||
| GELOGE(FAILED, "Required attribute: %s of op: %s is missing.", attr.name_.c_str(), op_def->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| OpSchemaFactory &OpSchemaFactory::Instance() { | |||||
| static OpSchemaFactory instance; | |||||
| return instance; | |||||
| } | |||||
| const OpSchema *OpSchemaFactory::Get(const std::string &op) const { | |||||
| auto it = op_schema_map_.find(op); | |||||
| if (it == op_schema_map_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return &it->second; | |||||
| } | |||||
| OpSchemaRegistry::OpSchemaRegistry(OpSchema &op_schema) { | |||||
| OpSchemaFactory &op_factory = OpSchemaFactory::Instance(); | |||||
| // save op_schema to the map | |||||
| if (op_factory.op_schema_map_.count(op_schema.name_)) { | |||||
| GELOGD("Failed to register op schema: %s., reason: already exist!", op_schema.name_.c_str()); | |||||
| return; | |||||
| } | |||||
| (void)op_factory.op_schema_map_.emplace(std::make_pair(op_schema.name_, op_schema)); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,175 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_COMMON_OP_SCHEMA_H | |||||
| #define DOMI_COMMON_OP_SCHEMA_H | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "common/tuple.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "proto/om.pb.h" | |||||
| #include "framework/common/fmk_types.h" | |||||
| namespace ge { | |||||
| enum class AttributeType { | |||||
| UNDEFINED, | |||||
| INT, | |||||
| UINT, | |||||
| BOOL, | |||||
| FLOAT, | |||||
| STRING, | |||||
| BYTES, | |||||
| INTLIST, | |||||
| UINTLIST, | |||||
| BOOLLIST, | |||||
| FLOATLIST, | |||||
| STRINGLIST | |||||
| }; | |||||
| class OpSchema; | |||||
| class OpSchemaRegistry; | |||||
| class FMK_FUNC_HOST_VISIBILITY OpSchema { | |||||
| public: | |||||
| // Formal parameter options. | |||||
| enum FormalParameterOption { | |||||
| // The input formal parameter is single and not optional. | |||||
| // Number of this input is 1. | |||||
| Single = 0, | |||||
| // The input formal parameter is single and optional. | |||||
| // Number of this input is 0 or 1. | |||||
| Optional = 1, | |||||
| // The input formal parameter is variadic. | |||||
| // Number of this input is [1, n]. | |||||
| Variadic = 2, | |||||
| }; | |||||
| // Formal parameter represenation, including input/output name, typeStr, | |||||
| // description, and type constraints. | |||||
| class FormalParameter { | |||||
| public: | |||||
| // Constructor. | |||||
| FormalParameter() = default; | |||||
| explicit FormalParameter(const std::string &name, FormalParameterOption param_option = Single); | |||||
| ~FormalParameter(); | |||||
| // Get formal parameter name. | |||||
| const std::string &Name() const; | |||||
| // Get the parameter option, it could be Single, Optional or Variadic. | |||||
| FormalParameterOption Option() const; | |||||
| private: | |||||
| friend class OpSchema; | |||||
| // Formal parameter name. | |||||
| std::string name_; | |||||
| // Formal parameter option. | |||||
| FormalParameterOption param_option_; | |||||
| }; | |||||
| explicit OpSchema(const std::string &name); | |||||
| ~OpSchema(); | |||||
| OpSchema &Input(const std::string &name, FormalParameterOption param_option = Single); | |||||
| OpSchema &Output(const std::string &name, FormalParameterOption param_option = Single); | |||||
| struct Attribute { | |||||
| Attribute(const std::string &name, AttributeType type, bool required) | |||||
| : name_(name), type_(type), required_(required) {} | |||||
| Attribute(const std::string &name, AttributeType type, domi::AttrDef default_value) | |||||
| : name_(name), type_(type), required_(false), default_value_(default_value) {} | |||||
| const std::string name_; | |||||
| AttributeType type_; | |||||
| bool required_; | |||||
| domi::AttrDef default_value_; | |||||
| }; | |||||
| OpSchema &Attr(const Attribute &attr); | |||||
| // Register "optional" attribute with default value. | |||||
| #define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \ | |||||
| OpSchema &Attr(const std::string &name, AttributeType type, const TypeName &default_value); \ | |||||
| OpSchema &Attr(const std::string &name, AttributeType type, const std::vector<TypeName> &default_value); \ | |||||
| OpSchema &Attr(const std::string &name, AttributeType type, const Tuple<TypeName> &default_value); | |||||
| ATTR_SETTER_WITH_DEFAULT_VALUE(uint32_t) | |||||
| ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t) | |||||
| ATTR_SETTER_WITH_DEFAULT_VALUE(bool) | |||||
| ATTR_SETTER_WITH_DEFAULT_VALUE(float) | |||||
| ATTR_SETTER_WITH_DEFAULT_VALUE(std::string) | |||||
| // Register "required" attribute without default value. | |||||
| OpSchema &AttrRequired(const std::string &name, AttributeType type); | |||||
| bool HasDefaultAttr(const std::string &name) const; | |||||
| const domi::AttrDef &GetDefaultAttr(const std::string &name) const; | |||||
| // verify op_def | |||||
| bool Verify(const ge::OpDescPtr op_def) const; | |||||
| private: | |||||
| friend class OpSchemaRegistry; | |||||
| std::string name_; | |||||
| std::vector<FormalParameter> inputs_; | |||||
| std::vector<FormalParameter> outputs_; | |||||
| std::unordered_map<std::string, Attribute> attributes_; | |||||
| }; | |||||
| class OpSchemaFactory { | |||||
| public: | |||||
| // this is a singleton object | |||||
| static OpSchemaFactory &Instance(); | |||||
| const OpSchema *Get(const std::string &op) const; | |||||
| private: | |||||
| OpSchemaFactory() = default; | |||||
| ~OpSchemaFactory() = default; | |||||
| friend class OpSchemaRegistry; | |||||
| // the op schema map | |||||
| std::unordered_map<std::string, OpSchema> op_schema_map_; | |||||
| }; | |||||
| class FMK_FUNC_HOST_VISIBILITY OpSchemaRegistry { | |||||
| public: | |||||
| OpSchemaRegistry(OpSchema &op_schema); | |||||
| ~OpSchemaRegistry() = default; | |||||
| }; | |||||
| #define DOMI_OP_SCHEMA(name) DOMI_OP_SCHEMA_UNIQ_HELPER(__COUNTER__, name) | |||||
| #define DOMI_OP_SCHEMA_UNIQ_HELPER(ctr, name) DOMI_OP_SCHEMA_UNIQ(ctr, name) | |||||
| #define DOMI_OP_SCHEMA_UNIQ(ctr, name) \ | |||||
| static OpSchemaRegistry op_schema_registry##ctr __attribute__((unused)) = OpSchema(#name) | |||||
| } // namespace ge | |||||
| #endif // DOMI_COMMON_OP_SCHEMA_H | |||||
| @@ -0,0 +1,200 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "operator.h" | |||||
| #include <utility> | |||||
| #include "framework/common/fmk_types.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| using ge::BoolTuple; | |||||
| using ge::FloatTuple; | |||||
| using ge::IntTuple; | |||||
| using ge::StringTuple; | |||||
| using ge::UintTuple; | |||||
| namespace ge { | |||||
| ParserOperator::ParserOperator(const std::string &type) { | |||||
| type_ = type; | |||||
| op_schema_ = ge::OpSchemaFactory::Instance().Get(type); | |||||
| if (op_schema_ == nullptr) { | |||||
| GELOGW("Cannot find op schema of op type: %s", type.c_str()); | |||||
| } | |||||
| } | |||||
| ParserOperator &ParserOperator::Input(const ParserOperator &in_op, uint32_t index) { | |||||
| if (index == 0) { | |||||
| inputs_.push_back(in_op.GetName()); | |||||
| } else { | |||||
| inputs_.push_back(in_op.GetName() + ":" + std::to_string(index)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::Name(const std::string &name) { | |||||
| name_ = name; | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::Type(const std::string &type) { | |||||
| type_ = type; | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::InputTensorDesc( | |||||
| const ge::GeTensorDesc &input_tensordesc) { | |||||
| input_descs_.push_back(input_tensordesc); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::OutputTensorDesc( | |||||
| const ge::GeTensorDesc &output_tensordesc) { | |||||
| output_descs_.push_back(output_tensordesc); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector( | |||||
| std::string key, | |||||
| std::vector<int32_t> &value) { | |||||
| domi::AttrDef out; | |||||
| auto it = op_attrs_.find(key); | |||||
| if (it != op_attrs_.end()) { | |||||
| out = it->second.value_; | |||||
| } | |||||
| for (auto &v : value) { | |||||
| out.mutable_list()->add_i(v); | |||||
| } | |||||
| (void)op_attrs_.erase(key); | |||||
| (void)op_attrs_.insert(std::make_pair(key, OpAttribute(key, out))); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_DEV_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector( | |||||
| std::string key, | |||||
| std::vector<int64_t> &value) { | |||||
| domi::AttrDef out; | |||||
| auto it = op_attrs_.find(key); | |||||
| if (it != op_attrs_.end()) { | |||||
| out = it->second.value_; | |||||
| } | |||||
| for (auto &v : value) { | |||||
| out.mutable_list()->add_i(v); | |||||
| } | |||||
| (void)op_attrs_.erase(key); | |||||
| (void)op_attrs_.insert(std::make_pair(key, OpAttribute(key, out))); | |||||
| return *this; | |||||
| } | |||||
| ParserOperator &ParserOperator::Attr(const OpAttribute &attr) { | |||||
| auto it = op_attrs_.find(attr.name_); | |||||
| if (it != op_attrs_.end()) { | |||||
| (void)op_attrs_.erase(it); | |||||
| } | |||||
| (void)op_attrs_.insert(std::make_pair(attr.name_, attr)); | |||||
| return *this; | |||||
| } | |||||
| ParserOperator &ParserOperator::Attr_bt(const std::string &name, const std::string &value) { | |||||
| domi::AttrDef a; | |||||
| a.set_bt(value); | |||||
| Attr(OpAttribute(name, a)); | |||||
| return *this; | |||||
| } | |||||
| #define ATTR_SETTER_WITH_SINGLE_VALUE(type, field) \ | |||||
| ParserOperator &ParserOperator::Attr(const std::string &name, const type &value) { \ | |||||
| domi::AttrDef a; \ | |||||
| a.set_##field(value); \ | |||||
| Attr(OpAttribute(name, a)); \ | |||||
| return *this; \ | |||||
| } | |||||
| #define ATTR_SETTER_WITH_LIST_VALUE(type, field) \ | |||||
| ParserOperator &ParserOperator::Attr(const std::string &name, const std::vector<type> &value) { \ | |||||
| domi::AttrDef a; \ | |||||
| auto attr_list = a.mutable_list(); \ | |||||
| for (size_t i = 0; i < value.size(); ++i) { \ | |||||
| attr_list->add_##field(value[i]); \ | |||||
| } \ | |||||
| Attr(OpAttribute(name, a)); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| ParserOperator &ParserOperator::Attr(const std::string &name, const ge::Tuple<type> &value) { \ | |||||
| domi::AttrDef a; \ | |||||
| auto attr_list = a.mutable_list(); \ | |||||
| for (uint32_t i = 0; i < value.ndim(); ++i) { \ | |||||
| attr_list->add_##field(value[i]); \ | |||||
| } \ | |||||
| Attr(OpAttribute(name, a)); \ | |||||
| return *this; \ | |||||
| } | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(bool, b) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(float, f) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s) | |||||
| ATTR_SETTER_WITH_SINGLE_VALUE(uint32_t, i) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(int64_t, i) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(bool, b) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(float, f) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(std::string, s) | |||||
| ATTR_SETTER_WITH_LIST_VALUE(uint32_t, i) | |||||
| #define ATTR_GET_SINGLE_VALUE(type, field, type_name) \ | |||||
| type ParserOperator::Get##type_name##Attr(const std::string &name) const { \ | |||||
| domi::AttrDef single_val; \ | |||||
| auto it = op_attrs_.find(name); \ | |||||
| if (it != op_attrs_.end()) { \ | |||||
| single_val = it->second.value_; \ | |||||
| } else { \ | |||||
| if (op_schema_ && op_schema_->HasDefaultAttr(name)) { \ | |||||
| single_val = op_schema_->GetDefaultAttr(name); \ | |||||
| } \ | |||||
| } \ | |||||
| return single_val.field(); \ | |||||
| } | |||||
| ATTR_GET_SINGLE_VALUE(uint32_t, i, Uint) | |||||
| ATTR_GET_SINGLE_VALUE(int64_t, i, Int) | |||||
| ATTR_GET_SINGLE_VALUE(float, f, Float) | |||||
| ATTR_GET_SINGLE_VALUE(bool, b, Bool) | |||||
| ATTR_GET_SINGLE_VALUE(std::string, s, String) | |||||
| #define ATTR_GET_TUPLE_VALUE(type, field, tuple_type_name) \ | |||||
| tuple_type_name ParserOperator::Get##tuple_type_name##Attr(const std::string &name) const { \ | |||||
| domi::AttrDef value; \ | |||||
| auto it = op_attrs_.find(name); \ | |||||
| if (it != op_attrs_.end()) { \ | |||||
| value = it->second.value_; \ | |||||
| } else { \ | |||||
| if (op_schema_ && op_schema_->HasDefaultAttr(name)) { \ | |||||
| value = op_schema_->GetDefaultAttr(name); \ | |||||
| } \ | |||||
| } \ | |||||
| const auto attr_def = value.list(); \ | |||||
| std::size_t n = attr_def.field##_size(); \ | |||||
| std::vector<type> vec(n); \ | |||||
| for (std::size_t i = 0; i < n; i++) { \ | |||||
| vec[i] = attr_def.field(i); \ | |||||
| } \ | |||||
| return tuple_type_name(vec); \ | |||||
| } | |||||
| ATTR_GET_TUPLE_VALUE(uint32_t, i, UintTuple) | |||||
| ATTR_GET_TUPLE_VALUE(int64_t, i, IntTuple) | |||||
| ATTR_GET_TUPLE_VALUE(float, f, FloatTuple) | |||||
| ATTR_GET_TUPLE_VALUE(bool, b, BoolTuple) | |||||
| ATTR_GET_TUPLE_VALUE(std::string, s, StringTuple) | |||||
| } // namespace domi | |||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_COMMON_OP_OPERATOR_H | |||||
| #define DOMI_COMMON_OP_OPERATOR_H | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "framework/common/fmk_types.h" | |||||
| #include "common/op_def/op_schema.h" | |||||
| #include "common/tuple.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "proto/om.pb.h" | |||||
| namespace ge { | |||||
| struct OpAttribute { | |||||
| OpAttribute(const std::string &name, const domi::AttrDef &value) : name_(name), value_(value) {} | |||||
| const std::string name_; | |||||
| domi::AttrDef value_; | |||||
| }; | |||||
| class FMK_FUNC_HOST_VISIBILITY ParserOperator { | |||||
| public: | |||||
| explicit ParserOperator(const std::string &type); | |||||
| ParserOperator() { op_schema_ = nullptr; } | |||||
| virtual ~ParserOperator() { op_schema_ = nullptr; } | |||||
| ParserOperator &Input(const ParserOperator &in_op, uint32_t index = 0); | |||||
| ParserOperator &Attr(const OpAttribute &op_attr); | |||||
| ParserOperator &AttrVector(std::string key, std::vector<int32_t> &value); | |||||
| ParserOperator &AttrVector(std::string key, std::vector<int64_t> &value); | |||||
| ParserOperator &Name(const std::string &name); | |||||
| ParserOperator &Type(const std::string &type); | |||||
| ParserOperator &InputTensorDesc(const ge::GeTensorDesc &input_tensordesc); | |||||
| ParserOperator &OutputTensorDesc(const ge::GeTensorDesc &output_tensordesc); | |||||
| ParserOperator &Attr_bt(const std::string &name, const std::string &value); | |||||
| // Register "optional" attribute with default value. | |||||
| #define ATTR_SETTER_WITH_VALUE(TypeName) \ | |||||
| ParserOperator &Attr(const std::string &name, const TypeName &value); \ | |||||
| ParserOperator &Attr(const std::string &name, const std::vector<TypeName> &value); \ | |||||
| ParserOperator &Attr(const std::string &name, const ge::Tuple<TypeName> &value) | |||||
| ATTR_SETTER_WITH_VALUE(uint32_t); | |||||
| ATTR_SETTER_WITH_VALUE(int64_t); | |||||
| ATTR_SETTER_WITH_VALUE(bool); | |||||
| ATTR_SETTER_WITH_VALUE(float); | |||||
| ATTR_SETTER_WITH_VALUE(std::string); | |||||
| const std::string &GetName() const { return name_; } | |||||
| const std::string &GetType() const { return type_; } | |||||
| const std::vector<std::string> &GetInputs() const { return inputs_; } | |||||
| const std::vector<ge::GeTensorDesc> &GetInputTensorDesc() const { return input_descs_; } | |||||
| const std::vector<ge::GeTensorDesc> &GetOutputTensorDesc() const { return output_descs_; } | |||||
| const std::unordered_map<std::string, OpAttribute> GetOpAttrs() const { return op_attrs_; } | |||||
| bool HasAttr(const std::string &name) const { return op_attrs_.find(name) != op_attrs_.end(); } | |||||
| const ge::OpSchema *GetSchema() const { return op_schema_; } | |||||
| int64_t GetIntAttr(const std::string &name) const; | |||||
| uint32_t GetUintAttr(const std::string &name) const; | |||||
| float GetFloatAttr(const std::string &name) const; | |||||
| bool GetBoolAttr(const std::string &name) const; | |||||
| std::string GetStringAttr(const std::string &name) const; | |||||
| ge::IntTuple GetIntTupleAttr(const std::string &name) const; | |||||
| ge::UintTuple GetUintTupleAttr(const std::string &name) const; | |||||
| ge::FloatTuple GetFloatTupleAttr(const std::string &name) const; | |||||
| ge::BoolTuple GetBoolTupleAttr(const std::string &name) const; | |||||
| ge::StringTuple GetStringTupleAttr(const std::string &name) const; | |||||
| private: | |||||
| const ge::OpSchema *op_schema_; | |||||
| std::string name_; | |||||
| std::string type_; | |||||
| std::vector<std::string> inputs_; | |||||
| std::unordered_map<std::string, OpAttribute> op_attrs_; | |||||
| std::vector<ge::GeTensorDesc> input_descs_; | |||||
| std::vector<ge::GeTensorDesc> output_descs_; | |||||
| }; | |||||
| } // namespace domi | |||||
| #endif // DOMI_COMMON_OP_OPERATOR_H | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #include "common/op_def/ref_switch_op.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::Name(const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) { | |||||
| Attr("T", (int64_t)t); | |||||
| return *this; | |||||
| } | |||||
| } // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #ifndef DOMI_OP_REF_SWITCH_H_ | |||||
| #define DOMI_OP_REF_SWITCH_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| class RefSwitchOperator : public ParserOperator { | |||||
| public: | |||||
| RefSwitchOperator(); | |||||
| ~RefSwitchOperator(); | |||||
| RefSwitchOperator &Name(const std::string &name); | |||||
| RefSwitchOperator &T(ge::DataType t); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_REF_SWITCH_H_ AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #include "common/op_def/shape_n_op.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {} | |||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator::~ShapeNOperator() {} | |||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::Name(const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::N(int64_t n) { | |||||
| Attr(SHAPEN_ATTR_N, n); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); } | |||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) { | |||||
| Attr(SHAPEN_ATTR_IN_TYPE, (int64_t)t); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const { | |||||
| return (ge::DataType)GetIntAttr(SHAPEN_ATTR_IN_TYPE); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) { | |||||
| Attr(SHAPEN_ATTR_OUT_TYPE, (int64_t)t); | |||||
| return *this; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const { | |||||
| return (ge::DataType)GetIntAttr(SHAPEN_ATTR_OUT_TYPE); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #ifndef DOMI_OP_SHAPE_N_OP_H_ | |||||
| #define DOMI_OP_SHAPE_N_OP_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| class ShapeNOperator : public ParserOperator { | |||||
| public: | |||||
| ShapeNOperator(); | |||||
| ~ShapeNOperator(); | |||||
| ShapeNOperator &Name(const std::string &name); | |||||
| ShapeNOperator &N(int64_t n); | |||||
| int64_t GetN() const; | |||||
| ShapeNOperator &InType(ge::DataType t); | |||||
| ge::DataType GetInType() const; | |||||
| ShapeNOperator &OutType(ge::DataType t); | |||||
| ge::DataType GetOutType() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_SHAPE_N_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #include "common/op_def/var_is_initialized_op_op.h" | |||||
| #include <string> | |||||
| #include <vector> | |||||
| namespace ge { | |||||
| VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::parser::VARISINITIALIZEDOP) {} | |||||
| VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {} | |||||
| VarIsInitializedOpOperator &VarIsInitializedOpOperator::Name(const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| VarIsInitializedOpOperator &VarIsInitializedOpOperator::VectorAttr(const std::string &key, | |||||
| std::vector<int64_t> &value) { | |||||
| Attr(key, value); | |||||
| return *this; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #ifndef DOMI_OP_VARISINITIALIZEDOP_H_ | |||||
| #define DOMI_OP_VARISINITIALIZEDOP_H_ | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| class VarIsInitializedOpOperator : public ParserOperator { | |||||
| public: | |||||
| VarIsInitializedOpOperator(); | |||||
| ~VarIsInitializedOpOperator(); | |||||
| VarIsInitializedOpOperator &Name(const std::string &name); | |||||
| VarIsInitializedOpOperator &VectorAttr(const std::string &key, std::vector<int64_t> &value); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_VARISINITIALIZEDOP_H_ AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/op_def/variable_op.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| namespace ge { | |||||
| VariableOperator::VariableOperator() : ParserOperator(ge::parser::VARIABLE) {} | |||||
| VariableOperator::~VariableOperator() {} | |||||
| VariableOperator &VariableOperator::Name(const std::string &name) { | |||||
| ParserOperator::Name(name); | |||||
| return *this; | |||||
| } | |||||
| VariableOperator &VariableOperator::Container(const std::string &container) { | |||||
| Attr(VAR_ATTR_CONTAINER, container); | |||||
| return *this; | |||||
| } | |||||
| VariableOperator &VariableOperator::SharedName(const std::string &sharedname) { | |||||
| Attr(VAR_ATTR_SHARED_NAME, sharedname); | |||||
| return *this; | |||||
| } | |||||
| VariableOperator &VariableOperator::Placement(const std::string &placement) { | |||||
| Attr(ATTR_VARIABLE_PLACEMENT, placement); | |||||
| return *this; | |||||
| } | |||||
| VariableOperator &VariableOperator::SrcType(const int64_t &dtype) { | |||||
| Attr(VAR_ATTR_DTYPE, dtype); | |||||
| return *this; | |||||
| } | |||||
| VariableOperator &VariableOperator::VarShape(const std::vector<int64_t> &shape_value) { | |||||
| Attr(VAR_ATTR_SHAPE, shape_value); | |||||
| return *this; | |||||
| } | |||||
| int64_t VariableOperator::GetVarSrcType() const { return GetIntAttr(VAR_ATTR_DTYPE); } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| #ifndef DOMI_OP_VARIABLE_H_ | |||||
| #define DOMI_OP_VARIABLE_H_ | |||||
| #include <vector> | |||||
| #include "parser/common/op_def/operator.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge { | |||||
| class VariableOperator : public ParserOperator { | |||||
| public: | |||||
| VariableOperator(); | |||||
| ~VariableOperator(); | |||||
| VariableOperator &Name(const std::string &name); | |||||
| VariableOperator &Container(const std::string &container); | |||||
| VariableOperator &SharedName(const std::string &sharedname); | |||||
| VariableOperator &Placement(const std::string &placement); | |||||
| VariableOperator &SrcType(const int64_t &dtype); | |||||
| VariableOperator &VarShape(const std::vector<int64_t> &shape_value); | |||||
| int64_t GetVarSrcType() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_OP_VAR_H_ AUTO GEN PLEASE DO NOT MODIFY IT | |||||
| @@ -0,0 +1,159 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/op_map.h" | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "register/op_registry.h" | |||||
| using std::map; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| using namespace ge::parser; | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> caffe_op_map = { | |||||
| {"Input", DATA}, | |||||
| {"DummyData", DATA}, | |||||
| {"Reshape", RESHAPE}, | |||||
| {"Dropout", DROPOUT}, | |||||
| {"NetOutput", NETOUTPUT}, | |||||
| }; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> tensorflow_op_map = { | |||||
| {"BroadcastGradientArgs", BROADCASTGRADIENTARGS}, | |||||
| {"StopGradient", STOPGRADIENT}, | |||||
| {"ExpandDims", EXPANDDIMS}, | |||||
| {"DestroyTemporaryVariable", DESTROYTEMPORARYVARIABLE}, | |||||
| {"GuaranteeConst", GUARANTEECONST}, | |||||
| {"BroadcastArgs", BROADCASTARGS}, | |||||
| {"PreventGradient", PREVENTGRADIENT}, | |||||
| {"Empty", EMPTY}, | |||||
| {"Placeholder", DATA}, | |||||
| {"ControlTrigger", CONTROLTRIGGER}, | |||||
| {"_ParallelConcatStart", PARALLELCONCATSTART}, | |||||
| {"Const", CONSTANT}, | |||||
| {"FrameworkOp", FRAMEWORKOP}, | |||||
| {"Reshape", RESHAPE}, | |||||
| {"Squeeze", SQUEEZE}, | |||||
| {"Enter", ENTER}, | |||||
| {"RefEnter", REFENTER}, | |||||
| {"Exit", EXIT}, | |||||
| {"RefExit", REFEXIT}, | |||||
| {"LoopCond", LOOPCOND}, | |||||
| {"NextIteration", NEXTITERATION}, | |||||
| {"RefNextIteration", REFNEXTITERATION}, | |||||
| {"Identity", IDENTITY}, | |||||
| {"IdentityN", IDENTITYN}, | |||||
| {"PlaceholderWithDefault", PLACEHOLDERWITHDEFAULT}, | |||||
| {"Size", SIZE}, | |||||
| {"Shape", SHAPE}, | |||||
| {"ShapeN", SHAPEN}, | |||||
| {"Fill", FILL}, | |||||
| {"Rank", RANK}, | |||||
| {"Merge", MERGE}, | |||||
| {"RefMerge", REFMERGE}, | |||||
| {"Switch", SWITCH}, | |||||
| {"RefSwitch", REFSWITCH}, | |||||
| {"LayerNorm", LAYERNORM}, | |||||
| {"RNN", RNN}, | |||||
| {"_Arg", ARG}, | |||||
| {"_Retval", FRAMEWORKOP}, | |||||
| {"Bitcast", BITCAST}, | |||||
| {"Snapshot", SNAPSHOT}, | |||||
| }; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map<string, string> tensorflow_train_op_map = { | |||||
| {"BroadcastGradientArgs", BROADCASTGRADIENTARGS}, | |||||
| {"StopGradient", STOPGRADIENT}, | |||||
| {"ExpandDims", EXPANDDIMS}, | |||||
| {"DestroyTemporaryVariable", DESTROYTEMPORARYVARIABLE}, | |||||
| {"TemporaryVariable", TEMPORARYVARIABLE}, | |||||
| {"GuaranteeConst", GUARANTEECONST}, | |||||
| {"BroadcastArgs", BROADCASTARGS}, | |||||
| {"PreventGradient", PREVENTGRADIENT}, | |||||
| {"Empty", EMPTY}, | |||||
| {"ControlTrigger", CONTROLTRIGGER}, | |||||
| {"_Arg", ARG}, | |||||
| {"_ParallelConcatStart", PARALLELCONCATSTART}, | |||||
| {"Const", CONSTANTOP}, | |||||
| {"VariableV2", VARIABLE}, | |||||
| {"VarHandleOp", VARHANDLEOP}, | |||||
| {"VarIsInitializedOp", VARISINITIALIZEDOP}, | |||||
| {"IsVariableInitialized", ISVARIABLEINITIALIZED}, | |||||
| {"ReadVariableOp", READVARIABLEOP}, | |||||
| {"Reshape", RESHAPE}, | |||||
| {"Squeeze", SQUEEZE}, | |||||
| {"NoOp", NOOP}, | |||||
| {"Enter", ENTER}, | |||||
| {"RefEnter", REFENTER}, | |||||
| {"Exit", EXIT}, | |||||
| {"RefExit", REFEXIT}, | |||||
| {"LoopCond", LOOPCOND}, | |||||
| {"NextIteration", NEXTITERATION}, | |||||
| {"RefNextIteration", REFNEXTITERATION}, | |||||
| {"Identity", IDENTITY}, | |||||
| {"IdentityN", IDENTITYN}, | |||||
| {"PlaceholderWithDefault", PLACEHOLDERWITHDEFAULT}, | |||||
| {"Size", SIZE}, | |||||
| {"Shape", SHAPE}, | |||||
| {"ShapeN", SHAPEN}, | |||||
| {"Rank", RANK}, | |||||
| {"Merge", MERGE}, | |||||
| {"Switch", SWITCH}, | |||||
| {"LayerNorm", LAYERNORM}, | |||||
| {"LayerNormGrad", LAYERNORMGRAD}, | |||||
| {"Dropout", DROPOUT}, | |||||
| {"Bitcast", BITCAST}, | |||||
| {"Snapshot", SNAPSHOT}, | |||||
| }; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map<string, int32_t> op_output_tensor_num = { | |||||
| {SSDDETECTIONOUTPUT, 3}, | |||||
| {REFINEDETDETECTIONOUTPUT, 3}, | |||||
| {FSRDETECTIONOUTPUT, 2}, | |||||
| {FASTERRCNNFIRSTSTAGEPOSTPROCESSOR, 4}, | |||||
| {FASTERRCNNSECONDSTAGEPOSTPROCESSOR, 4}, | |||||
| {YOLODETECTIONOUTPUT, 2}, | |||||
| {FASTRCNNPREDICTIONS, 4}, | |||||
| {RPNPROPOSALS, 3}, | |||||
| {MAXPOOLWITHARGMAX, 2}, | |||||
| {REGION, 3}, | |||||
| {TOPKV2, 2}, | |||||
| {LogTimeStamp, 0}, | |||||
| /* training op */ | |||||
| {MAXPOOLWITHARGMAX, 2}, | |||||
| {FUSEDBATCHNORM, 5}, | |||||
| {FUSEDBATCHNORMGRAD, 3}, | |||||
| {SHAPEN, 0}, | |||||
| {SSDPOSTPROCESSOR, 4}, | |||||
| {LAYERNORM, 3}, | |||||
| {LAYERNORMGRAD, 3}, | |||||
| {SPARSESOFTMAXCROSSENTROPYWITHLOGITS, 2}, | |||||
| }; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<string> local_framework_op_vec = { | |||||
| "TensorDataset", "QueueDataset", "DeviceQueueDataset", "ParallelMapDataset", "BatchDatasetV2", | |||||
| "IteratorV2", "MakeIterator", "IteratorGetNext", "FilterDataset", "MapAndBatchDatasetV2"}; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<string> is_dataset_op_vec = { | |||||
| "TensorDataset", "QueueDataset", "DeviceQueueDataset", "ParallelMapDataset", "BatchDatasetV2", | |||||
| "IteratorV2", "MakeIterator", "IteratorGetNext", "FilterDataset", "MapAndBatchDatasetV2"}; | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_OP_MAP_H_ | |||||
| #define GE_COMMON_OP_MAP_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| /*lint -e1073*/ | |||||
| namespace ge { | |||||
| // the operator type mapping table of caffe and mindspore | |||||
| extern std::map<std::string, std::string> caffe_op_map; | |||||
| // the operator type mapping table of TensorFlow and mindspore | |||||
| extern std::map<std::string, std::string> tensorflow_op_map; | |||||
| // the network training operator type mapping table of TensorFlow and mindspore | |||||
| extern std::map<std::string, std::string> tensorflow_train_op_map; | |||||
| // local framework op vec | |||||
| extern std::vector<std::string> local_framework_op_vec; | |||||
| // dataset op vec | |||||
| extern std::vector<std::string> is_dataset_op_vec; | |||||
| // output tensor num | |||||
| extern std::map<std::string, int32_t> op_output_tensor_num; | |||||
| } // namespace ge | |||||
| /*lint +e1073*/ | |||||
| #endif // GE_COMMON_OP_MAP_H_ | |||||
| @@ -0,0 +1,117 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY CustomParserAdapterRegistry *CustomParserAdapterRegistry::Instance() { | |||||
| static CustomParserAdapterRegistry instance; | |||||
| return &instance; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY void CustomParserAdapterRegistry::Register(const domi::FrameworkType framework, | |||||
| CustomParserAdapterRegistry::CREATOR_FUN fun) { | |||||
| if (funcs_.find(framework) != funcs_.end()) { | |||||
| GELOGW("Framework type %s has already registed.", TypeUtils::FmkTypeToSerialString(framework).c_str()); | |||||
| return; | |||||
| } | |||||
| funcs_[framework] = fun; | |||||
| GELOGI("Register %s custom parser adapter success.", TypeUtils::FmkTypeToSerialString(framework).c_str()); | |||||
| return; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY CustomParserAdapterRegistry::CREATOR_FUN | |||||
| CustomParserAdapterRegistry::GetCreateFunc(const domi::FrameworkType framework) { | |||||
| if (funcs_.find(framework) == funcs_.end()) { | |||||
| GELOGW("Framework type %s has not registed.", TypeUtils::FmkTypeToSerialString(framework).c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| return funcs_[framework]; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Instance( | |||||
| const domi::FrameworkType framework) { | |||||
| // Each framework corresponds to one op parser factory, | |||||
| // If instances are static data members of opparserfactory, the order of their construction is uncertain. | |||||
| // Instances cannot be a member of a class because they may be used before initialization, resulting in a run error. | |||||
| static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances; | |||||
| auto iter = instances.find(framework); | |||||
| if (iter == instances.end()) { | |||||
| std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory()); | |||||
| if (instance == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Create op parser factory failed."); | |||||
| return nullptr; | |||||
| } | |||||
| instances[framework] = instance; | |||||
| return instance; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) { | |||||
| // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | |||||
| auto iter = op_parser_creator_map_.find(op_type); | |||||
| if (iter != op_parser_creator_map_.end()) { | |||||
| return iter->second(); | |||||
| } | |||||
| GELOGE(FAILED, "OpParserFactory::CreateOpParser: Not supported type: %s", op_type.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) { | |||||
| // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. | |||||
| auto iter = fusion_op_parser_creator_map_.find(op_type); | |||||
| if (iter != fusion_op_parser_creator_map_.end()) { | |||||
| return iter->second(); | |||||
| } | |||||
| GELOGE(FAILED, "OpParserFactory::CreateOpParser: Not supported fusion op type: %s", op_type.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| // This function is only called within the constructor of the global opparserregisterar object, | |||||
| // and does not involve concurrency, so there is no need to lock it | |||||
| FMK_FUNC_HOST_VISIBILITY void OpParserFactory::RegisterCreator(const std::string &type, CREATOR_FUN fun, | |||||
| bool is_fusion_op) { | |||||
| std::map<std::string, CREATOR_FUN> *op_parser_creator_map = &op_parser_creator_map_; | |||||
| if (is_fusion_op) { | |||||
| op_parser_creator_map = &fusion_op_parser_creator_map_; | |||||
| } | |||||
| GELOGD("OpParserFactory::RegisterCreator: op type:%s, is_fusion_op:%d.", type.c_str(), is_fusion_op); | |||||
| (*op_parser_creator_map)[type] = fun; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) { | |||||
| if (is_fusion_op) { | |||||
| auto iter = fusion_op_parser_creator_map_.find(op_type); | |||||
| if (iter != fusion_op_parser_creator_map_.end()) { | |||||
| return true; | |||||
| } | |||||
| } else { | |||||
| auto iter = op_parser_creator_map_.find(op_type); | |||||
| if (iter != op_parser_creator_map_.end()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,198 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_OP_PARSER_FACTORY_H_ | |||||
| #define PARSER_COMMON_OP_PARSER_FACTORY_H_ | |||||
| #include <functional> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #include "external/register/register.h" | |||||
| using domi::CAFFE; | |||||
| namespace ge { | |||||
| class OpParser; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Used to create OpParser | |||||
| * | |||||
| */ | |||||
| class OpParserFactory { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Returns the OpParserFactory instance corresponding to the Framework | |||||
| * @return OpParserFactory object | |||||
| */ | |||||
| static std::shared_ptr<OpParserFactory> Instance(const domi::FrameworkType framework); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Create OpParser based on input type | |||||
| * @param [in] op_type Op type | |||||
| * @return Created OpParser | |||||
| */ | |||||
| std::shared_ptr<OpParser> CreateOpParser(const std::string &op_type); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Create fusion OpParser based on input type | |||||
| * @param [in] op_type Op type | |||||
| * @return Created OpParser | |||||
| */ | |||||
| std::shared_ptr<OpParser> CreateFusionOpParser(const std::string &op_type); | |||||
| // The Factory instance is automatically released by shared_ptr. | |||||
| // The shared_ptr internally calls the destructor indirectly. | |||||
| // If the destructor is not public, it will generate a compilation error. | |||||
| // Another solution is to specify the deleter for shared_ptr, and set the deleter as a friend of the current class. | |||||
| // But this method is more complicated to implement. | |||||
| ~OpParserFactory() {} | |||||
| bool OpParserIsRegistered(const std::string &op_type, bool is_fusion_op = false); | |||||
| protected: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief OpParser creation function | |||||
| * @return Created OpParser | |||||
| */ | |||||
| // typedef shared_ptr<OpParser> (*CREATOR_FUN)(void); | |||||
| using CREATOR_FUN = std::function<std::shared_ptr<OpParser>(void)>; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Factory instances can only be created automatically, not new methods, so the constructor is not public. | |||||
| */ | |||||
| OpParserFactory() {} | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Register creation function | |||||
| * @param [in] type Op type | |||||
| * @param [in] fun OpParser creation function | |||||
| */ | |||||
| void RegisterCreator(const std::string &type, CREATOR_FUN fun, bool is_fusion_op = false); | |||||
| private: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Each Op corresponds to a Creator function | |||||
| */ | |||||
| std::map<std::string, CREATOR_FUN> op_parser_creator_map_; // lint !e1073 | |||||
| std::map<std::string, CREATOR_FUN> fusion_op_parser_creator_map_; | |||||
| friend class OpParserRegisterar; | |||||
| friend class domi::OpRegistrationData; | |||||
| friend class OpRegistrationTbe; | |||||
| }; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief For registering Creator functions for different types of Op | |||||
| * | |||||
| */ | |||||
| class OpParserRegisterar { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Constructor | |||||
| * @param [in] framework Framework type | |||||
| * @param [in] op_type Op type | |||||
| * @param [in] fun Creator function corresponding to Op | |||||
| */ | |||||
| OpParserRegisterar(const domi::FrameworkType framework, const std::string &op_type, OpParserFactory::CREATOR_FUN fun, | |||||
| bool is_fusion_op = false) { | |||||
| OpParserFactory::Instance(framework)->RegisterCreator(op_type, fun, is_fusion_op); | |||||
| } | |||||
| ~OpParserRegisterar() {} | |||||
| }; | |||||
| // Used to save the functions created by the xxxCustomParserAdapter class | |||||
| class CustomParserAdapterRegistry { | |||||
| public: | |||||
| static CustomParserAdapterRegistry *Instance(); | |||||
| using CREATOR_FUN = std::function<std::shared_ptr<OpParser>(void)>; | |||||
| void Register(const domi::FrameworkType framework, CREATOR_FUN fun); | |||||
| CREATOR_FUN GetCreateFunc(const domi::FrameworkType framework); | |||||
| private: | |||||
| map<domi::FrameworkType, CREATOR_FUN> funcs_; | |||||
| friend class CustomParserAdapterRegistrar; | |||||
| }; | |||||
| // Register Creator function for the custom custom operator ParserAdapter | |||||
| class CustomParserAdapterRegistrar { | |||||
| public: | |||||
| CustomParserAdapterRegistrar(const domi::FrameworkType framework, CustomParserAdapterRegistry::CREATOR_FUN fun) { | |||||
| CustomParserAdapterRegistry::Instance()->Register(framework, fun); | |||||
| } | |||||
| ~CustomParserAdapterRegistrar() {} | |||||
| }; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief OpParser Registration Macro | |||||
| * @param [in] framework Framework type | |||||
| * @param [in] op_type Op type | |||||
| * @param [in] clazz OpParser implementation class | |||||
| */ | |||||
| #define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \ | |||||
| std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Op_Parser() { \ | |||||
| std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \ | |||||
| if (ptr == nullptr) { \ | |||||
| GELOGW("MakeShared failed, result is nullptr."); \ | |||||
| } \ | |||||
| return std::shared_ptr<OpParser>(ptr); \ | |||||
| } \ | |||||
| ge::OpParserRegisterar g_##framework##_##op_type##_Op_Parser_Creator(framework, op_type, \ | |||||
| Creator_##framework##_##op_type##_Op_Parser) | |||||
| #define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \ | |||||
| std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Fusion_Op_Parser() { \ | |||||
| std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \ | |||||
| if (ptr == nullptr) { \ | |||||
| GELOGW("MakeShared failed, result is nullptr."); \ | |||||
| } \ | |||||
| return std::shared_ptr<OpParser>(ptr); \ | |||||
| } \ | |||||
| OpParserRegisterar g_##framework##_##op_type##_Fusion_Op_Parser_Creator( \ | |||||
| framework, op_type, Creator_##framework##_##op_type##_Fusion_Op_Parser, true) | |||||
| /// @brief xxxCustomParserAdapter Registration Macro | |||||
| /// @param [in] framework Framework type | |||||
| /// @param [in] clazz CaffeCustomParserAdapter adaptation class | |||||
| #define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \ | |||||
| std::shared_ptr<OpParser> Creator_##framework##_Op_Parser_Adapter() { \ | |||||
| std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \ | |||||
| if (ptr == nullptr) { \ | |||||
| GELOGW("MakeShared failed, result is nullptr."); \ | |||||
| } \ | |||||
| return std::shared_ptr<OpParser>(ptr); \ | |||||
| } \ | |||||
| CustomParserAdapterRegistrar g_##framework##_Op_Parser_Creator(framework, Creator_##framework##_Op_Parser_Adapter) | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_OP_PARSER_FACTORY_H_ | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "framework/omg/parser/parser_api.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "tbe_plugin_loader.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "external/ge/ge_api_types.h" | |||||
| namespace ge { | |||||
| static bool parser_initialized = false; | |||||
| // Initialize PARSER, load custom op plugin | |||||
| // options will be used later for parser decoupling | |||||
| Status ParserInitialize(const std::map<std::string, std::string> &options) { | |||||
| GELOGT(TRACE_INIT, "ParserInitialize start"); | |||||
| // check init status | |||||
| if (parser_initialized) { | |||||
| GELOGW("ParserInitialize is called more than once"); | |||||
| return SUCCESS; | |||||
| } | |||||
| // load custom op plugin | |||||
| TBEPluginLoader::Instance().LoadPluginSo(options); | |||||
| std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas; | |||||
| GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | |||||
| for (OpRegistrationData ®_data : registrationDatas) { | |||||
| (void)OpRegistrationTbe::Instance()->Finalize(reg_data, true); | |||||
| } | |||||
| auto iter = options.find(ge::OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES); | |||||
| if (iter != options.end()) { | |||||
| ge::GetParserContext().enable_scope_fusion_passes = iter->second; | |||||
| } | |||||
| // set init status | |||||
| if (!parser_initialized) { | |||||
| // Initialize success, first time calling initialize | |||||
| parser_initialized = true; | |||||
| } | |||||
| GELOGT(TRACE_STOP, "ParserInitialize finished"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParserFinalize() { | |||||
| GELOGT(TRACE_INIT, "ParserFinalize start"); | |||||
| // check init status | |||||
| if (!parser_initialized) { | |||||
| GELOGW("ParserFinalize is called before ParserInitialize"); | |||||
| return SUCCESS; | |||||
| } | |||||
| GE_CHK_STATUS(TBEPluginLoader::Instance().Finalize()); | |||||
| if (parser_initialized) { | |||||
| parser_initialized = false; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "omg/parser/parser_factory.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace domi { | |||||
| FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { | |||||
| static WeightsParserFactory instance; | |||||
| return &instance; | |||||
| } | |||||
| std::shared_ptr<WeightsParser> WeightsParserFactory::CreateWeightsParser(const domi::FrameworkType type) { | |||||
| std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | |||||
| return iter->second(); | |||||
| } | |||||
| GELOGE(FAILED, "WeightsParserFactory::CreateWeightsParser: Not supported Type: %d", type); | |||||
| return nullptr; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY void WeightsParserFactory::RegisterCreator(const domi::FrameworkType type, | |||||
| WEIGHTS_PARSER_CREATOR_FUN fun) { | |||||
| std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | |||||
| GELOGW("WeightsParserFactory::RegisterCreator: %d creator already exist", type); | |||||
| return; | |||||
| } | |||||
| creator_map_[type] = fun; | |||||
| } | |||||
| WeightsParserFactory::~WeightsParserFactory() { | |||||
| creator_map_.clear(); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY ModelParserFactory *ModelParserFactory::Instance() { | |||||
| static ModelParserFactory instance; | |||||
| return &instance; | |||||
| } | |||||
| std::shared_ptr<ModelParser> ModelParserFactory::CreateModelParser(const domi::FrameworkType type) { | |||||
| std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | |||||
| return iter->second(); | |||||
| } | |||||
| GELOGE(FAILED, "ModelParserFactory::CreateModelParser: Not supported Type: %d", type); | |||||
| return nullptr; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::FrameworkType type, | |||||
| MODEL_PARSER_CREATOR_FUN fun) { | |||||
| std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | |||||
| GELOGW("ModelParserFactory::RegisterCreator: %d creator already exist", type); | |||||
| return; | |||||
| } | |||||
| creator_map_[type] = fun; | |||||
| } | |||||
| ModelParserFactory::~ModelParserFactory() { | |||||
| creator_map_.clear(); | |||||
| } | |||||
| } // namespace domi | |||||
| @@ -0,0 +1,653 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_FP16_T_H_ | |||||
| #define PARSER_COMMON_FP16_T_H_ | |||||
| #include <algorithm> | |||||
| #include <cmath> | |||||
| #include <cstdint> | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| using DimIndex = enum { | |||||
| kDim0 = 0, | |||||
| kDim1, | |||||
| kDim2, | |||||
| kDim3, | |||||
| kDim4, | |||||
| kDim5, | |||||
| kDim6, | |||||
| kDim7, | |||||
| kDim8, | |||||
| kDim9, | |||||
| kDim10, | |||||
| kDim11, | |||||
| kDim12, | |||||
| kDim13, | |||||
| kDim14, | |||||
| kDim15, | |||||
| kDim16, | |||||
| }; | |||||
| using BitShift = enum { | |||||
| kBitShift2 = 2, | |||||
| kBitShift3 = 3, | |||||
| kBitShift4 = 4, | |||||
| kBitShift5 = 5, | |||||
| kBitShift6 = 6, | |||||
| kBitShift7 = 7, | |||||
| kBitShift8 = 8, | |||||
| kBitShift9 = 9, | |||||
| kBitShift10 = 10, | |||||
| kBitShift11 = 11, | |||||
| kBitShift12 = 12, | |||||
| kBitShift13 = 13, | |||||
| kBitShift14 = 14, | |||||
| kBitShift15 = 15, | |||||
| kBitShift16 = 16, | |||||
| kBitShift20 = 20, | |||||
| kBitShift24 = 24, | |||||
| kBitShift27 = 27, | |||||
| kBitShift28 = 28, | |||||
| kBitShift31 = 31, | |||||
| kBitShift32 = 32, | |||||
| kBitShift36 = 36, | |||||
| kBitShift40 = 40, | |||||
| kBitShift44 = 44, | |||||
| kBitShift48 = 48, | |||||
| kBitShift52 = 52, | |||||
| kBitShift56 = 56, | |||||
| kBitShift59 = 59, | |||||
| kBitShift60 = 60, | |||||
| kBitShift63 = 63, | |||||
| kBitShift64 = 64, | |||||
| kBitShift128 = 128, | |||||
| kBitShift255 = 255, | |||||
| kBitShift256 = 256, | |||||
| kBitShift512 = 512, | |||||
| kBitShift768 = 768, | |||||
| kBitShift784 = 784, | |||||
| kBitShift1020 = 1020, | |||||
| kBitShift1024 = 1024, | |||||
| kBitShift3136 = 3136, | |||||
| kBitShift4096 = 4096, | |||||
| kBitShift6144 = 6144, | |||||
| kBitShift10240 = 10240, | |||||
| kBitShift65536 = 65536 | |||||
| }; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief fp16 exponent bias | |||||
| constexpr uint16_t kFp16ExpBias = 15; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief the exponent bit length of fp16 is 5 | |||||
| constexpr uint16_t kFp16ExpLen = 5; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief the mantissa bit length of fp16 is 10 | |||||
| constexpr uint16_t kFp16ManLen = 10; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief bit index of sign in fp16 | |||||
| constexpr uint16_t kFp16SignIndex = 15; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief sign mask of fp16 (1 00000 00000 00000) | |||||
| constexpr uint16_t kFp16SignMask = 0x8000; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief exponent mask of fp16 ( 11111 00000 00000) | |||||
| constexpr uint16_t kFp16ExpMask = 0x7C00; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief mantissa mask of fp16 ( 11111 11111) | |||||
| constexpr uint16_t kFp16ManMask = 0x03FF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief hide bit of mantissa of fp16( 1 00000 00000) | |||||
| constexpr uint16_t kFp16ManHideBit = 0x0400; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum value (0111 1011 1111 1111) | |||||
| constexpr uint16_t kFp16Max = 0x7BFF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief minimum value (1111 1011 1111 1111) | |||||
| constexpr uint16_t kFp16Min = 0xFBFF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief absolute maximum value (0111 1111 1111 1111) | |||||
| constexpr uint16_t kFp16AbsMax = 0x7FFF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum exponent value of fp16 is 15(11111) | |||||
| constexpr uint16_t kFp16MaxExp = 0x001F; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum valid exponent value of fp16 is 14(11110) | |||||
| constexpr uint16_t kFp16MaxValidExp = 0x001E; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum mantissa value of fp16(11111 11111) | |||||
| constexpr uint16_t kFp16MaxMan = 0x03FF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief absolute minimum normal value of fp16 | |||||
| /// (E=1,M=0 D=2^(-14)=0.00006103515625) | |||||
| constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief get sign of fp16 | |||||
| #define FP16_EXTRAC_SIGN(x) (((x) >> 15) & 1) | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief get exponent of fp16 | |||||
| #define FP16_EXTRAC_EXP(x) (((x) >> 10) & kFp16MaxExp) | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief get mantissa of fp16 | |||||
| #define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400)) | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief constructor of fp16 from sign exponent and mantissa | |||||
| #define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan)) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is zero | |||||
| #define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is a denormalized value | |||||
| #define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0)) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is infinite | |||||
| #define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is NaN | |||||
| #define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask)) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is invalid | |||||
| #define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask) | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief fp32 exponent bias | |||||
| constexpr uint16_t kFp32ExpBias = 127; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief the exponent bit length of float/fp32 is 8 | |||||
| constexpr uint16_t kFp32ExpLen = 8; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief the mantissa bit length of float/fp32 is 23 | |||||
| constexpr uint16_t kFp32ManLen = 23; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief bit index of sign in float/fp32 | |||||
| constexpr uint16_t kFp32SignIndex = 31; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) | |||||
| constexpr uint32_t kFp32SignMask = 0x80000000u; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief exponent mask of fp32 ( 1111 1111 0000 0000 0000 0000 000) | |||||
| constexpr uint32_t kFp32ExpMask = 0x7F800000u; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief mantissa mask of fp32 ( 1111 1111 1111 1111 111) | |||||
| constexpr uint32_t kFp32ManMask = 0x007FFFFFu; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief hide bit of mantissa of fp32 ( 1 0000 0000 0000 0000 000) | |||||
| constexpr uint32_t kFp32ManHideBit = 0x00800000u; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief absolute maximum value (0 1111 1111 1111 1111 1111 1111 111) | |||||
| constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief maximum exponent value of fp32 is 255(1111 1111) | |||||
| constexpr uint32_t kFp32MaxExp = 0xFF; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) | |||||
| constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||||
| /// @ingroup fp32 special value judgment | |||||
| /// @brief whether a fp32 is NaN | |||||
| #define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask)) | |||||
| /// @ingroup fp32 special value judgment | |||||
| /// @brief whether a fp32 is infinite | |||||
| #define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask))) | |||||
| /// @ingroup fp32 special value judgment | |||||
| /// @brief whether a fp32 is a denormalized value | |||||
| #define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0)) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief get sign of fp32 | |||||
| #define FP32_EXTRAC_SIGN(x) (((x) >> kFp32SignIndex) & 1) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief get exponent of fp16 | |||||
| #define FP32_EXTRAC_EXP(x) (((x)&kFp32ExpMask) >> kFp32ManLen) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief get mantissa of fp16 | |||||
| #define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit)) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief constructor of fp32 from sign exponent and mantissa | |||||
| #define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan)) | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief fp64 exponent bias | |||||
| constexpr uint16_t kFp64ExpBias = 1023; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief the exponent bit length of double/fp64 is 11 | |||||
| constexpr uint16_t kFp64ExpLen = 11; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief the mantissa bit length of double/fp64 is 52 | |||||
| constexpr uint16_t kFp64ManLen = 52; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief bit index of sign in double/fp64 is 63 | |||||
| constexpr uint16_t kFp64SignIndex = 63; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief sign mask of fp64 (1 000 (total 63bits 0)) | |||||
| constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief exponent mask of fp64 (0 1 11111 11111 0000?-?-(total 52bits 0)) | |||||
| constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000LLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief mantissa mask of fp64 ( 1111?-?-(total 52bits 1)) | |||||
| constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFLLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief hide bit of mantissa of fp64 ( 1 0000?-?-(total 52bits 0)) | |||||
| constexpr uint64_t kFp64ManHideBit = 0x0010000000000000LLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief absolute maximum value (0 111?-?-(total 63bits 1)) | |||||
| constexpr uint64_t kFp64AbsMax = 0x7FFFFFFFFFFFFFFFLLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief maximum exponent value of fp64 is 2047(1 11111 11111) | |||||
| constexpr uint64_t kFp64MaxExp = 0x07FF; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief maximum mantissa value of fp64 (111?-?-(total 52bits 1)) | |||||
| constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu; | |||||
| /// @ingroup fp64 special value judgment | |||||
| /// @brief whether a fp64 is NaN | |||||
| #define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask)) | |||||
| /// @ingroup fp64 special value judgment | |||||
| /// @brief whether a fp64 is infinite | |||||
| #define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask))) | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int8_t (0111 1111) | |||||
| constexpr int8_t kInt8Max = 0x7F; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 8 bits length (1111 111) | |||||
| constexpr uint8_t kBitLen8Max = 0xFF; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int16_t (0111 1111 1111 1111) | |||||
| constexpr int16_t kInt16Max = 0x7FFF; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111) | |||||
| constexpr uint16_t kBitLen16Max = 0xFFFF; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr int32_t kInt32Max = 0x7FFFFFFFu; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int64_t | |||||
| /// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 64 bits length | |||||
| /// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu; | |||||
| /// @ingroup fp16_t enum | |||||
| /// @brief round mode of last valid digital | |||||
| enum TagFp16RoundMode { | |||||
| kRoundToNearest = 0, // < round to nearest even | |||||
| kRoundByTruncated, // < round by truncated | |||||
| kRoundModeReserved, | |||||
| }; | |||||
| /// @ingroup fp16_t | |||||
| /// @brief Half precision float | |||||
| /// bit15: 1 bit SIGN +---+-----+------------+ | |||||
| /// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM| | |||||
| /// bit0-9: 10bit MAN +---+-----+------------+ | |||||
| using fp16_t = struct TagFp16 { | |||||
| uint16_t val; | |||||
| public: | |||||
| /// @ingroup fp16_t constructor | |||||
| /// @brief Constructor without any param(default constructor) | |||||
| TagFp16(void) { val = 0x0u; } | |||||
| /// @ingroup fp16_t constructor | |||||
| /// @brief Constructor with an uint16_t value | |||||
| TagFp16(const uint16_t &ui_val) : val(ui_val) {} | |||||
| /// @ingroup fp16_t constructor | |||||
| /// @brief Constructor with a fp16_t object(copy constructor) | |||||
| TagFp16(const TagFp16 &fp) : val(fp.val) {} | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be added | |||||
| /// @brief Override addition operator to performing fp16_t addition | |||||
| /// @return Return fp16_t result of adding this and fp | |||||
| TagFp16 operator+(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be subtracted | |||||
| /// @brief Override addition operator to performing fp16_t subtraction | |||||
| /// @return Return fp16_t result of subtraction fp from this | |||||
| TagFp16 operator-(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be multiplied | |||||
| /// @brief Override multiplication operator to performing fp16_t multiplication | |||||
| /// @return Return fp16_t result of multiplying this and fp | |||||
| TagFp16 operator*(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator divided | |||||
| /// @param [in] fp fp16_t object to be divided | |||||
| /// @brief Override division operator to performing fp16_t division | |||||
| /// @return Return fp16_t result of division this by fp | |||||
| TagFp16 operator/(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be added | |||||
| /// @brief Override addition operator to performing fp16_t addition | |||||
| /// @return Return fp16_t result of adding this and fp | |||||
| TagFp16 operator+=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be subtracted | |||||
| /// @brief Override addition operator to performing fp16_t subtraction | |||||
| /// @return Return fp16_t result of subtraction fp from this | |||||
| TagFp16 operator-=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be multiplied | |||||
| /// @brief Override multiplication operator to performing fp16_t multiplication | |||||
| /// @return Return fp16_t result of multiplying this and fp | |||||
| TagFp16 operator*=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator divided | |||||
| /// @param [in] fp fp16_t object to be divided | |||||
| /// @brief Override division operator to performing fp16_t division | |||||
| /// @return Return fp16_t result of division this by fp | |||||
| TagFp16 operator/=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t if-equal comparison | |||||
| /// @return Return boolean result of if-equal comparison of this and fp. | |||||
| bool operator==(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t not-equal comparison | |||||
| /// @return Return boolean result of not-equal comparison of this and fp. | |||||
| bool operator!=(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t greater-than comparison | |||||
| /// @return Return boolean result of greater-than comparison of this and fp. | |||||
| bool operator>(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t greater-equal comparison | |||||
| /// @return Return boolean result of greater-equal comparison of this and fp. | |||||
| bool operator>=(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t less-than comparison | |||||
| /// @return Return boolean result of less-than comparison of this and fp. | |||||
| bool operator<(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t less-equal comparison | |||||
| /// @return Return boolean result of less-equal comparison of this and fp. | |||||
| bool operator<=(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] fp fp16_t object to be copy to fp16_t | |||||
| /// @brief Override basic evaluation operator to copy fp16_t to a new fp16_t | |||||
| /// @return Return fp16_t result from fp | |||||
| TagFp16 &operator=(const TagFp16 &fp); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] f_val float object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert float to fp16_t | |||||
| /// @return Return fp16_t result from f_val | |||||
| TagFp16 &operator=(const float &f_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] d_val double object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert double to fp16_t | |||||
| /// @return Return fp16_t result from d_val | |||||
| TagFp16 &operator=(const double &d_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] i_val float object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert float to fp16_t | |||||
| /// @return Return fp16_t result from i_val | |||||
| TagFp16 &operator=(const int8_t &i_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] ui_val uint8_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert uint8_t to fp16_t | |||||
| /// @return Return fp16_t result from ui_val | |||||
| TagFp16 &operator=(const uint8_t &ui_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] i_val int16_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert int16_t to fp16_t | |||||
| /// @return Return fp16_t result from i_val | |||||
| TagFp16 &operator=(const int16_t &i_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] ui_val uint16_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert uint16_t to fp16_t | |||||
| /// @return Return fp16_t result from ui_val | |||||
| TagFp16 &operator=(const uint16_t &ui_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] i_val int32_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert int32_t to fp16_t | |||||
| /// @return Return fp16_t result from i_val | |||||
| TagFp16 &operator=(const int32_t &i_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] ui_val uint32_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert uint32_t to fp16_t | |||||
| /// @return Return fp16_t result from ui_val | |||||
| TagFp16 &operator=(const uint32_t &ui_val); | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to float/fp32 | |||||
| /// @return Return float/fp32 value of fp16_t | |||||
| operator float() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to double/fp64 | |||||
| /// @return Return double/fp64 value of fp16_t | |||||
| operator double() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int8_t | |||||
| /// @return Return int8_t value of fp16_t | |||||
| operator int8_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint8_t | |||||
| /// @return Return uint8_t value of fp16_t | |||||
| operator uint8_t() const; | |||||
| /// @ingroup fp16_t conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int16_t | |||||
| /// @return Return int16_t value of fp16_t | |||||
| operator int16_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint16_t | |||||
| /// @return Return uint16_t value of fp16_t | |||||
| operator uint16_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int32_t | |||||
| /// @return Return int32_t value of fp16_t | |||||
| operator int32_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint32_t | |||||
| /// @return Return uint32_t value of fp16_t | |||||
| operator uint32_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int64_t | |||||
| /// @return Return int64_t value of fp16_t | |||||
| operator int64_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint64_t | |||||
| /// @return Return uint64_t value of fp16_t | |||||
| operator uint64_t() const; | |||||
| /// @ingroup fp16_t judgment method | |||||
| /// @param [in] fp fp16_t object to be judgement | |||||
| /// @brief whether a fp16_t is inifinite | |||||
| /// @return Returns 1:+INF -1:-INF 0:not INF | |||||
| int IsInf(); | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to float/fp32 | |||||
| /// @return Return float/fp32 value of fp16_t | |||||
| float ToFloat() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to double/fp64 | |||||
| /// @return Return double/fp64 value of fp16_t | |||||
| double ToDouble() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to int8_t | |||||
| /// @return Return int8_t value of fp16_t | |||||
| int8_t ToInt8() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to uint8_t | |||||
| /// @return Return uint8_t value of fp16_t | |||||
| uint8_t ToUInt8() const; | |||||
| /// @ingroup fp16_t conversion | |||||
| /// @brief Convert fp16_t to int16_t | |||||
| /// @return Return int16_t value of fp16_t | |||||
| int16_t ToInt16() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to uint16_t | |||||
| /// @return Return uint16_t value of fp16_t | |||||
| uint16_t ToUInt16() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to int32_t | |||||
| /// @return Return int32_t value of fp16_t | |||||
| int32_t ToInt32() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to uint32_t | |||||
| /// @return Return uint32_t value of fp16_t | |||||
| uint32_t ToUInt32() const; | |||||
| }; | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] val signature is negative | |||||
| /// @param [in|out] s sign of fp16_t object | |||||
| /// @param [in|out] e exponent of fp16_t object | |||||
| /// @param [in|out] m mantissa of fp16_t object | |||||
| /// @brief Extract the sign, exponent and mantissa of a fp16_t object | |||||
| void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m); | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] negative sign is negative | |||||
| /// @param [in|out] man mantissa to be reverse | |||||
| /// @brief Calculate a mantissa's complement (add ont to it's radix-minus-one complement) | |||||
| /// @return Return complement of man | |||||
| template<typename T> | |||||
| void ReverseMan(bool negative, T &man) { | |||||
| if (negative) { | |||||
| man = (~(man)) + 1; | |||||
| } | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] e_a exponent of one fp16_t/float number | |||||
| /// @param [in] m_a mantissa of one fp16_t/float number | |||||
| /// @param [in] e_b exponent of another fp16_t/float number | |||||
| /// @param [in] m_b mantissa of another fp16_t/float number | |||||
| /// @brief choose mantissa to be shift right whoes exponent is less than another one | |||||
| /// @return Return mantissawhoes exponent is less than another one | |||||
| template<typename T> | |||||
| T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) { | |||||
| return (e_a > e_b) ? m_b : m_a; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] man mantissa to be operate | |||||
| /// @param [in] shift right shift bits | |||||
| /// @brief right shift a mantissa | |||||
| /// @return Return right-shift mantissa | |||||
| template<typename T> | |||||
| T RightShift(T man, int16_t shift) { | |||||
| int bits = sizeof(T) * 8; // one byte have 8 bits | |||||
| T mask = (((T) 1u) << ((unsigned int) (bits - 1))); | |||||
| for (int i = 0; i < shift; i++) { | |||||
| man = ((man & mask) | (man >> 1)); | |||||
| } | |||||
| return man; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] e_a exponent of one temp fp16_t number | |||||
| /// @param [in] m_a mantissa of one temp fp16_t number | |||||
| /// @param [in] e_b exponent of another temp fp16_t number | |||||
| /// @param [in] m_b mantissa of another temp fp16_t number | |||||
| /// @brief Get mantissa sum of two temp fp16_t numbers, T support types: uint16_t/uint32_t/uint64_t | |||||
| /// @return Return mantissa sum | |||||
| template<typename T> | |||||
| T GetManSum(int16_t e_a, const T &m_a, int16_t e_b, const T &m_b) { | |||||
| T sum = 0; | |||||
| if (e_a != e_b) { | |||||
| T m_tmp = 0; | |||||
| int16_t e_tmp = std::abs(e_a - e_b); | |||||
| if (e_a > e_b) { | |||||
| m_tmp = m_b; | |||||
| m_tmp = RightShift(m_tmp, e_tmp); | |||||
| sum = m_a + m_tmp; | |||||
| } else { | |||||
| m_tmp = m_a; | |||||
| m_tmp = RightShift(m_tmp, e_tmp); | |||||
| sum = m_tmp + m_b; | |||||
| } | |||||
| } else { | |||||
| sum = m_a + m_b; | |||||
| } | |||||
| return sum; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] bit0 whether the last preserved bit is 1 before round | |||||
| /// @param [in] bit1 whether the abbreviation's highest bit is 1 | |||||
| /// @param [in] bitLeft whether the abbreviation's bits which not contain highest bit grater than 0 | |||||
| /// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t | |||||
| /// @param [in] shift abbreviation bits | |||||
| /// @brief Round fp16_t or float mantissa to nearest value | |||||
| /// @return Returns true if round 1,otherwise false; | |||||
| template<typename T> | |||||
| T ManRoundToNearest(bool bit0, bool bit1, bool bitLeft, T man, uint16_t shift = 0) { | |||||
| man = (man >> shift) + ((bit1 && (bitLeft || bit0)) ? 1 : 0); | |||||
| return man; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t | |||||
| /// @brief Get bit length of a uint32_t number | |||||
| /// @return Return bit length of man | |||||
| template<typename T> | |||||
| int16_t GetManBitLength(T man) { | |||||
| int16_t len = 0; | |||||
| while (man) { | |||||
| man >>= 1; | |||||
| len++; | |||||
| } | |||||
| return len; | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| #endif // GE_PARSER_COMMON_FP16_T_H_ | |||||
| @@ -0,0 +1,24 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserContext &GetParserContext() { | |||||
| static ParserContext context; | |||||
| return context; | |||||
| } | |||||
| } // namespace domi | |||||
| @@ -0,0 +1,494 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge{ | |||||
| namespace parser { | |||||
| const char *DATA = "Data"; | |||||
| const char *AIPPDATA = "AippData"; | |||||
| const char *CONVOLUTION = "Convolution"; | |||||
| const char *CORRELATION = "Correlation"; | |||||
| const char *CORRELATIONV2 = "Correlation_V2"; | |||||
| const char *DECONVOLUTION = "Deconvolution"; | |||||
| const char *POOLING = "Pooling"; | |||||
| const char *ELTWISE = "Eltwise"; | |||||
| const char *RELU = "ReLU"; | |||||
| const char *RELU6 = "ReLU6"; | |||||
| const char *SIGMOID = "Sigmoid"; | |||||
| const char *ABSVAL = "AbsVal"; | |||||
| const char *TANH = "TanH"; | |||||
| const char *PRELU = "PReLU"; | |||||
| const char *BATCHNORM = "BatchNorm"; | |||||
| const char *FUSIONBATCHNORM = "FusionBatchNorm"; | |||||
| const char *SCALE = "Scale"; | |||||
| const char *FULL_CONNECTION = "FullConnection"; | |||||
| const char *SOFTMAX = "Softmax"; | |||||
| const char *PLUS = "Plus"; | |||||
| const char *ACTIVATION = "Activation"; | |||||
| const char *FLATTEN = "Flatten"; | |||||
| const char *ADD = "Add"; | |||||
| const char *SUB = "Sub"; | |||||
| const char *MUL = "Mul"; | |||||
| const char *MATMUL = "MatMul"; | |||||
| const char *RSQRT = "Rsqrt"; | |||||
| const char *BIASADD = "BiasAdd"; | |||||
| const char *RESHAPE = "Reshape"; | |||||
| const char *REFORMAT = "ReFormat"; | |||||
| const char *DEPCONVOLUTION = "ConvolutionDepthwise"; | |||||
| const char *DROPOUT = "Dropout"; | |||||
| const char *DROPOUTGENMASK = "DropOutGenMask"; | |||||
| const char *DROPOUTDOMASK = "DropOutDoMask"; | |||||
| const char *CONCAT = "Concat"; | |||||
| const char *ROIPOOLING = "ROIPooling"; | |||||
| const char *PROPOSAL = "Proposal"; | |||||
| const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput"; | |||||
| const char *DETECTIONPOSTPROCESS = "Detectpostprocess"; | |||||
| const char *LRN = "LRN"; | |||||
| const char *TRANSDATA = "TransData"; | |||||
| const char *PERMUTE = "Permute"; | |||||
| const char *SSDNORMALIZE = "SSDNormalize"; | |||||
| const char *SSDPRIORBOX = "SSDPriorBox"; | |||||
| const char *NETOUTPUT = "NetOutput"; | |||||
| const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput"; | |||||
| const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput"; | |||||
| const char *CHANNELAXPY = "ChannelAxpy"; | |||||
| const char *PSROIPOOLING = "PSROIPooling"; | |||||
| const char *POWER = "Power"; | |||||
| const char *POW = "Pow"; | |||||
| const char *ROIALIGN = "ROIAlign"; | |||||
| const char *PYTHON = "Python"; | |||||
| const char *FREESPACEEXTRACT = "FreespaceExtract"; | |||||
| const char *SPATIALTF = "SpatialTransform"; | |||||
| const char *SHAPE = "Shape"; | |||||
| const char *SHAPEN = "ShapeN"; | |||||
| const char *ARGMAX = "ArgMax"; | |||||
| const char *GATHERND = "GatherNd"; | |||||
| const char *GATHER = "Gather"; | |||||
| const char *REALDIV = "RealDiv"; | |||||
| const char *PACK = "Pack"; | |||||
| const char *SLICE = "Slice"; | |||||
| const char *SLICED = "SliceD"; | |||||
| const char *FLOORDIV = "FloorDiv"; | |||||
| const char *SQUEEZE = "Squeeze"; | |||||
| const char *UNSQUEEZE = "Unsqueeze"; | |||||
| const char *STRIDEDSLICE = "StridedSlice"; | |||||
| const char *RANGE = "Range"; | |||||
| const char *RPNPROPOSALS = "RpnProposals"; | |||||
| const char *DECODEBBOX = "DecodeBbox"; | |||||
| const char *PAD = "Pad"; | |||||
| const char *PADV2 = "PadV2"; | |||||
| const char *MIRRORPAD = "MirrorPad"; | |||||
| const char *TILE = "Tile"; | |||||
| const char *SIZE = "Size"; | |||||
| const char *CLIPBOXES = "ClipBoxes"; | |||||
| const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions"; | |||||
| const char *SPLIT = "Split"; | |||||
| const char *SPLITV = "SplitV"; | |||||
| const char *EXPANDDIMS = "ExpandDims"; | |||||
| const char *EMPTY = "Empty"; | |||||
| const char *MEAN = "Mean"; | |||||
| const char *GREATER = "Greater"; | |||||
| const char *SWITCH = "Switch"; | |||||
| const char *SWITCHN = "SwitchN"; | |||||
| const char *MERGE = "Merge"; | |||||
| const char *SYMBOLICGRADIENT = "SymbolicGradient"; | |||||
| const char *REMOTECALL = "RemoteCall"; | |||||
| const char *_IF = "_If"; | |||||
| const char *STATELESSIF = "StatelessIf"; | |||||
| const char *IF = "If"; | |||||
| const char *CASE = "Case"; | |||||
| const char *_WHILE = "_While"; | |||||
| const char *WHILE = "While"; | |||||
| const char *STATELESSWHILE = "StatelessWhile"; | |||||
| const char *FOR = "For"; | |||||
| const char *PARTITIONEDCALL = "PartitionedCall"; | |||||
| const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall"; | |||||
| const char *FAKEPARAM = "FakeParam"; | |||||
| const char *TRANSPOSE = "Transpose"; | |||||
| const char *TRANSPOSED = "TransposeD"; | |||||
| const char *CAST = "Cast"; | |||||
| const char *REGION = "Region"; | |||||
| const char *YOLO = "Yolo"; | |||||
| const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput"; | |||||
| const char *FILL = "Fill"; | |||||
| const char *REVERSE = "Reverse"; | |||||
| const char *UNPACK = "Unpack"; | |||||
| const char *YOLO2REORG = "Yolo2Reorg"; | |||||
| const char *REDUCESUM = "ReduceSum"; | |||||
| const char *SUM = "Sum"; | |||||
| const char *CONSTANT = "Const"; | |||||
| const char *RESIZEBILINEAR = "ResizeBilinear"; | |||||
| const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; | |||||
| const char *MAXIMUM = "Maximum"; | |||||
| const char *FRAMEWORKOP = "FrameworkOp"; | |||||
| const char *ARG = "_Arg"; | |||||
| const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad"; | |||||
| const char *LSTM = "LSTM"; | |||||
| const char *HIGHWAY = "HighWay"; | |||||
| const char *RNN = "RNN"; | |||||
| const char *ATTENTIONDECODER = "AttentionDecoder"; | |||||
| const char *LOGICAL_NOT = "LogicalNot"; | |||||
| const char *LOGICAL_AND = "LogicalAnd"; | |||||
| const char *LOGICAL_OR = "LogicalOr"; | |||||
| const char *EQUAL = "Equal"; | |||||
| const char *NOTEQUAL = "NotEqual"; | |||||
| const char *INTERP = "Interp"; | |||||
| const char *SHUFFLECHANNEL = "ShuffleChannel"; | |||||
| const char *AIPP = "Aipp"; | |||||
| const char *MULTISHAPE = "MultiShape"; | |||||
| const char *RECIPROCAL = "Reciprocal"; | |||||
| const char *SELU = "Selu"; | |||||
| const char *ELU = "Elu"; | |||||
| const char *ACOSH = "Acosh"; | |||||
| const char *ASINH = "Asinh"; | |||||
| const char *MINIMUM = "Minimum"; | |||||
| const char *CLIP = "Clip"; | |||||
| const char *L2NORMALIZE = "L2Normalize"; | |||||
| const char *CROPANDRESIZE = "CropAndResize"; | |||||
| const char *UNUSEDCONST = "UnusedConst"; | |||||
| const char *SPARSETODENSE = "SparseToDense"; | |||||
| const char *NONMAXSUPPRESSION = "NonMaxSuppression"; | |||||
| const char *TOPKV2 = "TopKV2"; | |||||
| const char *INVERTPERMUTATION = "InvertPermutation"; | |||||
| const char *MULTINOMIAL = "Multinomial"; | |||||
| const char *REVERSESEQUENCE = "ReverseSequence"; | |||||
| const char *REDUCEPROD = "ReduceProd"; | |||||
| const char *REDUCEMAX = "ReduceMax"; | |||||
| const char *REDUCEMIN = "ReduceMin"; | |||||
| const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches"; | |||||
| const char *SQRT = "Sqrt"; | |||||
| const char *REDUCEALL = "ReduceAll"; | |||||
| const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor"; | |||||
| const char *SPACETOBATCHND = "SpaceToBatchND"; | |||||
| const char *BATCHTOSPACEND = "BatchToSpaceND"; | |||||
| const char *ASSERT = "Assert"; | |||||
| const char *GREATEREQUAL = "GreaterEqual"; | |||||
| const char *FLOOR = "Floor"; | |||||
| const char *RANDOMUNIFORM = "RandomUniform"; | |||||
| const char *BATCHMATMUL = "BatchMatMul"; | |||||
| const char *SPACETODEPTH = "SpaceToDepth"; | |||||
| const char *DEPTHTOSPACE = "DepthToSpace"; | |||||
| const char *RINT = "Rint"; | |||||
| const char *ATAN = "Atan"; | |||||
| const char *ATAN2 = "Atan2"; | |||||
| const char *ATANH = "Atanh"; | |||||
| const char *ACOS = "Acos"; | |||||
| const char *ASIN = "Asin"; | |||||
| const char *NEG = "Neg"; | |||||
| const char *LOG = "Log"; | |||||
| const char *TAN = "Tan"; | |||||
| const char *ROUND = "Round"; | |||||
| const char *UPSAMPLE = "Upsample"; | |||||
| const char *FLOORMOD = "FloorMod"; | |||||
| const char *LESS = "Less"; | |||||
| const char *LESSEQUAL = "LessEqual"; | |||||
| const char *ONEHOT = "OneHot"; | |||||
| const char *REFSWITCH = "RefSwitch"; | |||||
| const char *REFMERGE = "RefMerge"; | |||||
| const char *ENTER = "Enter"; | |||||
| const char *REFENTER = "RefEnter"; | |||||
| const char *LOOPCOND = "LoopCond"; | |||||
| const char *NEXTITERATION = "NextIteration"; | |||||
| const char *REFNEXTITERATION = "RefNextIteration"; | |||||
| const char *EXIT = "Exit"; | |||||
| const char *REFEXIT = "RefExit"; | |||||
| const char *CONTROLTRIGGER = "ControlTrigger"; | |||||
| const char *ZEROSLIKE = "ZerosLike"; | |||||
| const char *EXP = "Exp"; | |||||
| const char *WHERE = "Where"; | |||||
| const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars"; | |||||
| const char *SOFTPLUS = "Softplus"; | |||||
| const char *SOFTSIGN = "Softsign"; | |||||
| const char *COSH = "Cosh"; | |||||
| const char *SINH = "Sinh"; | |||||
| const char *SQUAREDDIFFERENCE = "SquaredDifference"; | |||||
| const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion | |||||
| const char *SSDPOSTPROCESSOR = "SSDPostProcessor"; | |||||
| const char *RETINANETBOXES = "RetinanetBoxes"; | |||||
| const char *RETINAMULTIANCHORS = "RetinaMultiAnchor"; | |||||
| const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes"; | |||||
| const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections"; | |||||
| const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor"; | |||||
| const char *RETINANETANCHORS = "RetinanetAnchors"; | |||||
| const char *FASTERRCNNMAP = "FasterRCNNMap"; | |||||
| const char *FASTERRCNNMAP1 = "FasterRCNNMap1"; | |||||
| const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor"; | |||||
| const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling"; | |||||
| const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor"; | |||||
| const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator"; | |||||
| const char *ROIINTERPOOLING = "ROIInterPooling"; | |||||
| const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow"; | |||||
| const char *EMBEDLOOKUP = "EmbedLookup"; | |||||
| const char *HASHLOOKUP = "HashLookup"; | |||||
| const char *LSH_PROJ = "LshProject"; | |||||
| const char *SVDF = "SVDF"; | |||||
| const char *SSDANCHORGENERATOR = "SSDAnchorGenerator"; | |||||
| const char *IDENTITY = "Identity"; | |||||
| const char *IDENTITYN = "IdentityN"; | |||||
| const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault"; | |||||
| const char *SELECT = "Select"; | |||||
| const char *GETSPAN = "GetSpan"; | |||||
| const char *STOPGRADIENT = "StopGradient"; | |||||
| const char *PREVENTGRADIENT = "PreventGradient"; | |||||
| const char *GUARANTEECONST = "GuaranteeConst"; | |||||
| const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs"; | |||||
| const char *BROADCASTARGS = "BroadcastArgs"; | |||||
| const char *CONFUSIONMATRIX = "ConfusionMatrix"; | |||||
| const char *RANK = "Rank"; | |||||
| const char *PLACEHOLDER = "PlaceHolder"; | |||||
| const char *END = "End"; | |||||
| const char *BASICLSTMCELL = "BasicLSTMCell"; | |||||
| const char *GETNEXT = "GetNext"; | |||||
| const char *INITDATA = "InitData"; | |||||
| const char *REFIDENTITY = "RefIdentity"; | |||||
| const char *BITCAST = "Bitcast"; | |||||
| /***************Ann special operator*************************/ | |||||
| const char *ANN_MEAN = "AnnMean"; | |||||
| const char *ANN_CONVOLUTION = "AnnConvolution"; | |||||
| const char *ANN_DEPCONVOLUTION = "AnnDepthConv"; | |||||
| const char *ANN_FULLCONNECTION = "AnnFullConnection"; | |||||
| const char *ANN_NETOUTPUT = "AnnNetOutput"; | |||||
| const char *ANN_DATA = "AnnData"; | |||||
| const char *ANN_RESHAPE = "AnnReshape"; | |||||
| const char *ANN_ADD = "AnnAdd"; | |||||
| const char *ANN_MUL = "AnnMul"; | |||||
| const char *ANN_SUB = "AnnSub"; | |||||
| const char *ANN_DIV = "AnnDiv"; | |||||
| const char *ANN_DEQUANTIZE = "AnnDequant"; | |||||
| const char *ANN_QUANTIZE = "AnnQuant"; | |||||
| const char *ANN_PAD = "AnnPad"; | |||||
| const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear"; | |||||
| /***************************************************/ | |||||
| /******************Training operator*************************/ | |||||
| const char *GATHERV2 = "GatherV2"; | |||||
| const char *CONVGRADFILTER = "Conv2DBackpropFilter"; | |||||
| const char *CONV2D = "Conv2D"; | |||||
| const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput"; | |||||
| const char *FUSEDBATCHNORM = "FusedBatchNorm"; | |||||
| const char *BIASADDGRAD = "BiasAddGrad"; | |||||
| const char *ACTIVATIONGRAD = "ReluGrad"; | |||||
| const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax"; | |||||
| const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax"; | |||||
| const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits"; | |||||
| const char *SNAPSHOT = "Snapshot"; | |||||
| const char *VAR = "Var"; | |||||
| const char *MEANGRAD = "MeanGrad"; | |||||
| const char *TRANSLATE = "Translate"; | |||||
| const char *ADDN = "AddN"; | |||||
| const char *L2LOSS = "L2Loss"; | |||||
| const char *MULTIPLY = "Multiply"; | |||||
| const char *HUBERLOSSGRAD = "HuberLossGrad"; | |||||
| const char *HUBERLOSS = "HuberLoss"; | |||||
| const char *NEGATIVE = "Negative"; | |||||
| const char *SSDCAST = "SSDCast"; | |||||
| const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy"; | |||||
| const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad"; | |||||
| const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion"; | |||||
| const char *CONCATFOUR2FIVE = "ConcatFour2Five"; | |||||
| const char *CONCATFIVE2FOUR = "ConcatFive2Four"; | |||||
| const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul"; | |||||
| const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean"; | |||||
| const char *VARIABLEV2 = "VariableV2"; | |||||
| const char *VARHANDLEOP = "VarHandleOp"; | |||||
| const char *TEMPORARYVARIABLE = "TemporaryVariable"; | |||||
| const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable"; | |||||
| const char *VARIABLE = "Variable"; | |||||
| const char *ASSIGN = "Assign"; | |||||
| const char *ASSIGNVARIABLEOP = "AssignVariableOp"; | |||||
| const char *ASSIGNADD = "AssignAdd"; | |||||
| const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp"; | |||||
| const char *ASSIGNSUB = "AssignSub"; | |||||
| const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp"; | |||||
| const char *APPLYMOMENTUM = "ApplyMomentum"; | |||||
| const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum"; | |||||
| const char *SGD = "SGD"; | |||||
| const char *NOOP = "NoOp"; | |||||
| const char *READVARIABLEOP = "ReadVariableOp"; | |||||
| const char *PARALLELCONCATSTART = "_ParallelConcatStart"; | |||||
| const char *CONSTANTOP = "Constant"; | |||||
| const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter"; | |||||
| const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput"; | |||||
| const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative"; | |||||
| const char *DROPOUTGRAD = "DropOutGrad"; | |||||
| const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision"; | |||||
| const char *APPLYRMSPROP = "ApplyRMSProp"; | |||||
| const char *RELU6GRAD = "Relu6Grad"; | |||||
| const char *AVGPOOLGRAD = "AvgPoolGrad"; | |||||
| const char *CONCATV2 = "ConcatV2"; | |||||
| const char *CONCATOFFSET = "ConcatOffset"; | |||||
| const char *LAYERNORMGRAD = "LayerNormGrad"; | |||||
| const char *LAYERNORM = "LayerNorm"; | |||||
| const char *LARS = "Lars"; | |||||
| const char *DYNAMICSTITCH = "DynamicStitch"; | |||||
| /***************************************************/ | |||||
| const char *SQUARE = "Square"; | |||||
| const char *HCOMBROADCAST = "HcomBroadcast"; | |||||
| const char *HCOMALLGATHER = "HcomAllGather"; | |||||
| const char *HCOMALLREDUCE = "HcomAllReduce"; | |||||
| const char *HCOMREDUCESCATTER = "HcomReduceScatter"; | |||||
| const char *HCOMSEND = "HcomSend"; | |||||
| const char *HCOMRECEIVE = "HcomReceive"; | |||||
| const char *HCOMREMOTEREAD = "HcomRemoteRead"; | |||||
| const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; | |||||
| const char *VARASSIGN = "VarAssign"; | |||||
| const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; | |||||
| const char *LogTimeStamp = "LogTimeStamp"; | |||||
| const char *ISVARIABLEINITIALIZED = "IsVariableInitialized"; | |||||
| const char *STREAMSWITCH = "StreamSwitch"; | |||||
| const char *STREAMSWITCHN = "StreamSwitchN"; | |||||
| const char *STREAMACTIVE = "StreamActive"; | |||||
| const char *MEMCPYASYNC = "MemcpyAsync"; | |||||
| const char *MEMCPYADDRASYNC = "MemcpyAddrAsync"; | |||||
| const char *STREAMMERGE = "StreamMerge"; | |||||
| const char *ENDGRAPH = "EndGraph"; | |||||
| const char *SEND = "Send"; | |||||
| const char *RECV = "Recv"; | |||||
| const char *ENDOFSEQUENCE = "EndOfSequence"; | |||||
| const char *LABELSET = "LabelSet"; | |||||
| const char *LABELGOTO = "LabelGoto"; | |||||
| const char *LABELGOTOEX = "LabelGotoEx"; | |||||
| const char *LABELSWITCH = "LabelSwitch"; | |||||
| const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex"; | |||||
| const char *ATOMICADDRCLEAN = "AtomicAddrClean"; | |||||
| const char *ABS_GRAD = "AbsGrad"; | |||||
| const char *ACCUMULATE_N_V2 = "AccumulateNV2"; | |||||
| const char *ACOS_GRAD = "AcosGrad"; | |||||
| const char *ACOSH_GRAD = "AcoshGrad"; | |||||
| const char *ANY = "Any"; | |||||
| const char *APPROXIMATE_EQUAL = "ApproximateEqual"; | |||||
| const char *ASIN_GRAD = "AsinGrad"; | |||||
| const char *ASINH_GRAD = "AsinhGrad"; | |||||
| const char *ATAN_GRAD = "AtanGrad"; | |||||
| const char *BROADCAST_TO = "BroadcastTo"; | |||||
| const char *ELU_GRAD = "EluGrad"; | |||||
| const char *ADD_V2 = "AddV2"; | |||||
| const char *DATAFORMATDIMMAP = "DataFormatDimMap"; | |||||
| const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute"; | |||||
| const char *BESSELI0E = "BesselI0e"; | |||||
| const char *BESSELI1E = "BesselI1e"; | |||||
| const char *APPLYADADELTA = "ApplyAdadelta"; | |||||
| const char *APPLYADAGRAD = "ApplyAdagrad"; | |||||
| const char *APPLYADAGRADDA = "ApplyAdagradDA"; | |||||
| const char *APPLYADAM = "ApplyAdam"; | |||||
| const char *APPLYADAMAX = "ApplyAdaMax"; | |||||
| const char *APPLYADDSIGN = "ApplyAddSign"; | |||||
| const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp"; | |||||
| const char *APPLYFTRL = "ApplyFtrl"; | |||||
| const char *APPLYFTRLV2 = "ApplyFtrlV2"; | |||||
| const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent"; | |||||
| const char *APPLYPOWERSIGN = "ApplyPowerSign"; | |||||
| const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad"; | |||||
| const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent"; | |||||
| const char *DEQUANTIZE = "Dequantize"; | |||||
| const char *FOCAL_LOSS = "FocalLoss"; | |||||
| const char *FOCAL_LOSS_GRAD = "FocalLossGrad"; | |||||
| const char *SMOOTHL1_LOSS = "SmoothL1Loss"; | |||||
| const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad"; | |||||
| const char *REDUCEMEAN = "ReduceMean"; | |||||
| const char *CONCAT_V2 = "ConcatV2"; | |||||
| const char *ONEHOT_V2 = "OneHotV2"; | |||||
| const char *SLICE_V2 = "SliceV2"; | |||||
| const char *TILE_V2 = "TileV2"; | |||||
| const char *SUM_V2 = "SumV2"; | |||||
| // Common type when the operator has the same name | |||||
| const char *DETECTIONOUTPUT = "DetectionOutput"; | |||||
| // Custom operator | |||||
| const char *CUSTOMOP = "CustomOp"; | |||||
| const char *CUSTOMOP_NCHW = "CustomOpNchw"; | |||||
| const char *CUSTOMOP_NHWC = "CustomOpNhwc"; | |||||
| const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0"; | |||||
| // Depthwise 4d_2_6d,6d_2_4d | |||||
| const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d"; | |||||
| const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d"; | |||||
| const char *SQRTGRAD = "SqrtGrad"; | |||||
| const char *SIGMOIDGRAD = "SigmoidGrad"; | |||||
| const char *TRANSSHAPE = "TransShape"; | |||||
| // Horovod operator | |||||
| const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce"; | |||||
| const char *HVDCALLBACKALLGATHER = "HorovodAllgather"; | |||||
| const char *HVDCALLBACKBROADCAST = "HorovodBroadcast"; | |||||
| const char *HVDWAIT = "HorovodWait"; | |||||
| /// | |||||
| /// @brief Magic number of model file | |||||
| /// | |||||
| const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number | |||||
| /// | |||||
| /// @brief Model head length | |||||
| /// | |||||
| const uint32_t MODEL_FILE_HEAD_LEN = 256; | |||||
| const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief alpha default value | |||||
| /// | |||||
| const float ALPHA_DEFAULT_VALUE = 1.0; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief beta default value | |||||
| /// | |||||
| const float BETA_DEFAULT_VALUE = 0.0; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief Input node type | |||||
| /// | |||||
| const std::string INPUT_TYPE = "Input"; | |||||
| const std::string DUMMY_DATA = "DummyData"; | |||||
| // for fusion op plugin | |||||
| const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | |||||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | |||||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief DATA node type | |||||
| /// | |||||
| const std::string DATA_TYPE = "Data"; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief Frame operator type | |||||
| /// | |||||
| const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief Convolution node type | |||||
| /// | |||||
| const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,83 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/pass_manager.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| const vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { return names_to_graph_passes_; } | |||||
| Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { | |||||
| GE_CHECK_NOTNULL(pass); | |||||
| names_to_graph_passes_.emplace_back(pass_name, pass); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PassManager::Run(const ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| return Run(graph, names_to_graph_passes_); | |||||
| } | |||||
| Status PassManager::Run(const ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &names_to_passes) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| bool not_changed = true; | |||||
| for (auto &pass_pair : names_to_passes) { | |||||
| const auto &pass = pass_pair.second; | |||||
| const auto &pass_name = pass_pair.first; | |||||
| GE_CHECK_NOTNULL(pass); | |||||
| PARSER_TIMESTAMP_START(PassRun); | |||||
| Status status = pass->Run(graph); | |||||
| if (status == SUCCESS) { | |||||
| not_changed = false; | |||||
| } else if (status != NOT_CHANGED) { | |||||
| GELOGE(status, "Pass Run failed on graph %s", graph->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| for (const auto &subgraph :graph->GetAllSubgraphs()) { | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str()); | |||||
| string subgraph_pass_name = pass_name + "::" + graph->GetName(); | |||||
| PARSER_TIMESTAMP_START(PassRunSubgraph); | |||||
| status = pass->Run(subgraph); | |||||
| PARSER_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str()); | |||||
| if (status == SUCCESS) { | |||||
| not_changed = false; | |||||
| } else if (status != NOT_CHANGED) { | |||||
| GELOGE(status, "Pass Run failed on subgraph %s", subgraph->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| PARSER_TIMESTAMP_END(PassRun, pass_name.c_str()); | |||||
| } | |||||
| return not_changed ? NOT_CHANGED : SUCCESS; | |||||
| } | |||||
| PassManager::~PassManager() { | |||||
| for (auto &pass_pair : names_to_graph_passes_) { | |||||
| auto &pass = pass_pair.second; | |||||
| GE_DELETE_NEW_SINGLE(pass); | |||||
| } | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_PASS_MANAGER_H_ | |||||
| #define PARSER_COMMON_PASS_MANAGER_H_ | |||||
| #include <vector> | |||||
| #include "inc/graph_pass.h" | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief pass manager | |||||
| /// @author | |||||
| /// | |||||
| class PassManager { | |||||
| public: | |||||
| /// | |||||
| /// get graph passes | |||||
| /// @author | |||||
| /// | |||||
| const vector<std::pair<std::string, GraphPass *>> &GraphPasses() const; | |||||
| /// | |||||
| /// Add graph pass | |||||
| /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys. | |||||
| /// @author | |||||
| /// | |||||
| Status AddPass(const string &pass_name, GraphPass *pass); | |||||
| /// | |||||
| /// Optimize graph with added pass | |||||
| /// @param [inout] graph graph to be optimized | |||||
| /// @return SUCCESS optimize successfully | |||||
| /// @return NOT_CHANGED not optimized | |||||
| /// @return others optimize failed | |||||
| /// @author | |||||
| /// | |||||
| Status Run(const ge::ComputeGraphPtr &graph); | |||||
| /// | |||||
| /// Optimize graph with specified pass | |||||
| /// @param [inout] graph graph to be optimized | |||||
| /// @param [in] passes passes to be used | |||||
| /// @return SUCCESS optimize successfully | |||||
| /// @return NOT_CHANGED not optimized | |||||
| /// @return others optimized failed | |||||
| /// @author | |||||
| /// | |||||
| static Status Run(const ge::ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &passes); | |||||
| ~PassManager(); | |||||
| private: | |||||
| vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_; | |||||
| }; | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_PASS_MANAGER_H_ | |||||
| @@ -0,0 +1,287 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/pre_checker.h" | |||||
| #include <nlohmann/json.hpp> | |||||
| #include "common/model_saver.h" | |||||
| #include "common/op_map.h" | |||||
| #include "common/util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "omg/omg.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "parser/common/model_saver.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| // Keys in JSON file | |||||
| namespace { | |||||
| const char *const kKeyName = "name"; | |||||
| const char *const kKeyResult = "result"; | |||||
| const char *const kKeyTotal = "total"; | |||||
| const char *const kKeyPass = "pass"; | |||||
| const char *const kKeyFail = "fail"; | |||||
| const char *const kKeyOp = "op"; | |||||
| const char *const kKeyOpName = "name"; | |||||
| const char *const kKeyOpType = "type"; | |||||
| const char *const kKeyOpResult = "result"; | |||||
| const char *const kKeyCause = "cause"; | |||||
| const char *const kKeyCauseCode = "code"; | |||||
| const char *const kKeyCauseMessage = "message"; | |||||
| // Checking result and support warning later | |||||
| const char *const kResultSuccess = "success"; | |||||
| const char *const kResultFailed = "failed"; | |||||
| } // namespace | |||||
| PreChecker::PreChecker() : fmk_op_types_(nullptr) { Init(); } | |||||
| void PreChecker::Init() { | |||||
| model_name_.clear(); | |||||
| op_map_.clear(); | |||||
| ops_.clear(); | |||||
| fmk_op_types_ = nullptr; | |||||
| // Currently only Caffe and tensorflow are supported | |||||
| domi::FrameworkType fmk_type = GetParserContext().type; | |||||
| if (fmk_type == domi::CAFFE) | |||||
| fmk_op_types_ = &caffe_op_map; | |||||
| else if (fmk_type == domi::TENSORFLOW) | |||||
| fmk_op_types_ = &tensorflow_op_map; | |||||
| else | |||||
| return; | |||||
| } | |||||
| PreChecker::~PreChecker() {} | |||||
| FMK_FUNC_HOST_VISIBILITY PreChecker &PreChecker::Instance() { | |||||
| static PreChecker instance; | |||||
| return instance; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY void PreChecker::SetModelName(const string &name) { model_name_ = name; } | |||||
| FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddOp(OpId id, const string &name, const string &type) { | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(op_map_.find(id) != op_map_.end(), "Id already exists."); | |||||
| Info info; | |||||
| info.id = id; | |||||
| info.name = name; | |||||
| info.type = type; | |||||
| op_map_[id] = info; | |||||
| ops_.push_back(id); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PreChecker::CheckName(OpId id) { | |||||
| auto iter = op_map_.find(id); | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); | |||||
| Info &info = iter->second; | |||||
| for (auto &v : op_map_) { | |||||
| // If the name is duplicate, an error is logged | |||||
| if (id != v.first && info.name == v.second.name) { | |||||
| Cause cause; | |||||
| cause.code = NAME_REPEATED; | |||||
| cause.message = "The name is repeated."; | |||||
| GELOGI("Name %s repeated.", info.name.c_str()); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name}); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "Add cause failed."); | |||||
| break; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY Status PreChecker::CheckType(OpId id, bool is_tensorflow) { | |||||
| auto iter = op_map_.find(id); | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); | |||||
| Info &info = iter->second; | |||||
| string type = info.type; | |||||
| // If the user explicitly specifies the mapping relationship of the operator type through | |||||
| // the -- OP_name_map parameter, the type specified by the user is used. | |||||
| auto op_map_iter = GetParserContext().op_conf_map.find(type); | |||||
| if (op_map_iter != GetParserContext().op_conf_map.end()) { | |||||
| type = op_map_iter->second; | |||||
| } | |||||
| // Judge whether the type is supported | |||||
| GE_RETURN_WITH_LOG_IF_ERROR( | |||||
| CheckTypeSupported(info.id, type, info.name, is_tensorflow), "Check type supported failed."); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddCause(OpId id, ErrorCode code, const string &msg) { | |||||
| Cause cause; | |||||
| cause.code = code; | |||||
| cause.message = msg; | |||||
| return AddCause(id, cause); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY void PreChecker::RefreshErrorMessageByName(const string &op_name, ErrorCode code, | |||||
| const string &msg) { | |||||
| for (const auto &op : op_map_) { | |||||
| if (op.second.name == op_name) { | |||||
| AddCause(op.second.id, code, msg); | |||||
| return; | |||||
| } | |||||
| } | |||||
| GELOGW("Node [%s] not founded in prechecking list.", op_name.c_str()); | |||||
| } | |||||
| Status PreChecker::AddCause(OpId id, const Cause &cause) { | |||||
| auto iter = op_map_.find(id); | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); | |||||
| Info &info = iter->second; | |||||
| // Avoid adding repeatedly | |||||
| for (Cause &c : info.causes) { | |||||
| if (c.code == cause.code && c.message == cause.message) { | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| info.causes.push_back(cause); | |||||
| return SUCCESS; | |||||
| } | |||||
| void PreChecker::Clear() { Init(); } | |||||
| Status PreChecker::Clear(OpId id, const string &message) { | |||||
| auto iter = op_map_.find(id); | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); | |||||
| Info &info = iter->second; | |||||
| info.causes.clear(); | |||||
| // Set additional information | |||||
| if (message != "") { | |||||
| Cause cause; | |||||
| cause.code = ErrorCode::OK; | |||||
| cause.message = message; | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() { | |||||
| for (auto id : ops_) { | |||||
| if (HasError(id)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| Status PreChecker::Save(string file) { | |||||
| uint32_t fail_num = 0; | |||||
| for (auto id : ops_) { | |||||
| if (HasError(id)) { | |||||
| fail_num++; | |||||
| } | |||||
| } | |||||
| // Initialization model related JSON information | |||||
| nlohmann::json model; | |||||
| model[kKeyName] = model_name_; | |||||
| model[kKeyResult] = HasError() ? kResultFailed : kResultSuccess; | |||||
| model[kKeyTotal] = ops_.size(); | |||||
| model[kKeyPass] = ops_.size() - fail_num; | |||||
| model[kKeyFail] = fail_num; | |||||
| // Constructing JSON information of operators in order of network | |||||
| for (auto id : ops_) { | |||||
| auto iter = op_map_.find(id); | |||||
| GE_CHK_BOOL_RET_STATUS(iter != op_map_.end(), FAILED, "don't find this op."); | |||||
| Info &info = iter->second; | |||||
| // Initialization operator general information | |||||
| nlohmann::json op = {{kKeyOpName, info.name}, {kKeyOpType, info.type}}; | |||||
| op[kKeyOpResult] = HasError(id) ? kResultFailed : kResultSuccess; | |||||
| // handle causes | |||||
| for (const Cause &cause : info.causes) { | |||||
| nlohmann::json cause_j = {{kKeyCauseCode, cause.code}, {kKeyCauseMessage, cause.message}}; | |||||
| op[kKeyCause].push_back(cause_j); | |||||
| } | |||||
| model[kKeyOp].push_back(op); | |||||
| } | |||||
| // Save JSON data to a file | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(ge::parser::ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow) { | |||||
| // Currently only partial framework type checking is supported | |||||
| if (fmk_op_types_ == nullptr) { | |||||
| std::string op_type; | |||||
| if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(type, op_type)) { | |||||
| Cause cause; | |||||
| cause.code = TYPE_UNSUPPORTED; | |||||
| cause.message = "The type is not supported."; | |||||
| GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||||
| if (!is_tensorflow) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | |||||
| } | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // Log error if type not found | |||||
| if (fmk_op_types_->find(type) == fmk_op_types_->end()) { | |||||
| Cause cause; | |||||
| cause.code = TYPE_UNSUPPORTED; | |||||
| cause.message = "The type is not supported."; | |||||
| GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); | |||||
| if (!is_tensorflow) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); | |||||
| } | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool PreChecker::HasError(OpId id) { | |||||
| auto iter = op_map_.find(id); | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); | |||||
| Info &info = iter->second; | |||||
| for (const Cause &cause : info.causes) { | |||||
| if (cause.code != ErrorCode::OK) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,194 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_PRE_CHECKER_H_ | |||||
| #define PARSER_COMMON_PRE_CHECKER_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| namespace ge { | |||||
| using std::map; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| using Status = domi::Status; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief pre_check | |||||
| * @author | |||||
| */ | |||||
| class PreChecker { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Operator unique identification | |||||
| */ | |||||
| using OpId = const void *; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief error code, 1~99:Error, 100~199:Waring。 | |||||
| */ | |||||
| enum ErrorCode { | |||||
| // no error | |||||
| OK = 0, | |||||
| // type unsupported | |||||
| TYPE_UNSUPPORTED = 1, | |||||
| // param invalid | |||||
| PARAM_INVALID = 2, | |||||
| // type ambiguous | |||||
| TYPE_AMBIGUOUS = 8, | |||||
| // name repeated | |||||
| NAME_REPEATED = 9 | |||||
| }; | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Operator error description | |||||
| */ | |||||
| struct Cause { | |||||
| // error code | |||||
| ErrorCode code; | |||||
| // error message | |||||
| string message; | |||||
| }; | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief instance interface | |||||
| */ | |||||
| static PreChecker &Instance(); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief set model name | |||||
| */ | |||||
| void SetModelName(const string &name); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief add op information | |||||
| */ | |||||
| Status AddOp(OpId id, const string &name, const string &type); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Judge whether the operator name is duplicate | |||||
| */ | |||||
| Status CheckName(OpId id); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief check operation type | |||||
| * 1、Check whether the operator type supports according to the global frameworktype | |||||
| * 2、Check if the operator type is ambiguous | |||||
| */ | |||||
| Status CheckType(OpId id, bool is_tensorflow = false); | |||||
| void RefreshErrorMessageByName(const string &op_name, ErrorCode code, const string& msg); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add custom error description | |||||
| */ | |||||
| Status AddCause(OpId id, ErrorCode code, const string &msg); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Add custom error description | |||||
| */ | |||||
| Status AddCause(OpId id, const Cause &cause); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Clear all operator information | |||||
| */ | |||||
| void Clear(); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Clear the error information of the specified operator | |||||
| */ | |||||
| Status Clear(OpId id, const string &message = ""); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Determine if an error has been detected | |||||
| */ | |||||
| bool HasError(); | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Save inspection results(JSON) | |||||
| */ | |||||
| Status Save(string file); | |||||
| private: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief operation information | |||||
| */ | |||||
| struct Info { | |||||
| // Operator identifier | |||||
| OpId id; | |||||
| // Operator name | |||||
| string name; | |||||
| // Operator type | |||||
| string type; | |||||
| // Error description, which may contain multiple (for example, both name and type are illegal) | |||||
| vector<Cause> causes; | |||||
| }; | |||||
| PreChecker(); | |||||
| ~PreChecker(); | |||||
| PreChecker(const PreChecker &); | |||||
| PreChecker &operator=(const PreChecker &); | |||||
| // Initialize internal data | |||||
| void Init(); | |||||
| // Judge whether the type is supported | |||||
| Status CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow); | |||||
| // Determine if an error has been detected | |||||
| bool HasError(OpId id); | |||||
| private: | |||||
| // model name | |||||
| string model_name_; | |||||
| // Save operator check results | |||||
| map<OpId, Info> op_map_; | |||||
| // Save operator list in original order | |||||
| vector<OpId> ops_; | |||||
| // save frame related operator types | |||||
| map<string, string> *fmk_op_types_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_PRE_CHECKER_H_ | |||||
| @@ -0,0 +1,190 @@ | |||||
| syntax = "proto3"; | |||||
| package ge.proto; | |||||
| enum DataType | |||||
| { | |||||
| DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. | |||||
| DT_FLOAT = 1; // float type | |||||
| DT_FLOAT16 = 2; // fp16 type | |||||
| DT_INT8 = 3; // int8 type | |||||
| DT_UINT8 = 4; // uint8 type | |||||
| DT_INT16 = 5; // int16 type | |||||
| DT_UINT16 = 6; // uint16 type | |||||
| DT_INT32 = 7; // | |||||
| DT_INT64 = 8; // int64 type | |||||
| DT_UINT32 = 9; // unsigned int32 | |||||
| DT_UINT64 = 10; // unsigned int64 | |||||
| DT_BOOL = 11; // bool type | |||||
| DT_DOUBLE = 12; // double type | |||||
| DT_STRING = 13; // string type | |||||
| DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ | |||||
| DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ | |||||
| DT_COMPLEX64 = 16; // complex64 type | |||||
| DT_COMPLEX128 = 17; // complex128 type | |||||
| DT_QINT8 = 18; // qint8 type | |||||
| DT_QINT16 = 19; // qint16 type | |||||
| DT_QINT32 = 20; // qint32 type | |||||
| DT_QUINT8 = 21; // quint8 type | |||||
| DT_QUINT16 = 22; // quint16 type | |||||
| DT_RESOURCE = 23; // resource type | |||||
| DT_STRING_REF = 24; // string_ref type | |||||
| DT_DUAL = 25; /**< dual output type */ | |||||
| } | |||||
| message AttrDef | |||||
| { | |||||
| message ListValue | |||||
| { | |||||
| enum ListValueType{ | |||||
| VT_LIST_NONE = 0; | |||||
| VT_LIST_STRING = 1; | |||||
| VT_LIST_INT = 2; | |||||
| VT_LIST_FLOAT = 3; | |||||
| VT_LIST_BOOL = 4; | |||||
| VT_LIST_BYTES = 5; | |||||
| VT_LIST_TENSOR_DESC = 6; | |||||
| VT_LIST_TENSOR = 7; | |||||
| VT_LIST_GRAPH = 8; | |||||
| VT_LIST_NAMED_ATTRS = 9; | |||||
| VT_LIST_DATA_TYPE = 10; | |||||
| } | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3; // "list(int)" | |||||
| repeated float f = 4; // "list(float)" | |||||
| repeated bool b = 5; // "list(bool)" | |||||
| repeated bytes bt = 7; | |||||
| repeated TensorDescriptor td = 8; | |||||
| repeated TensorDef t = 9; | |||||
| repeated GraphDef g = 10; | |||||
| repeated NamedAttrs na = 11; | |||||
| repeated int64 dt = 12; // list ge::DataType | |||||
| ListValueType val_type = 20; | |||||
| } | |||||
| message ListListInt{ | |||||
| message ListInt{ | |||||
| repeated int64 list_i = 1; // list int | |||||
| } | |||||
| repeated ListInt list_list_i = 1; // list list int | |||||
| } | |||||
| oneof value | |||||
| { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; // Used to support attr nesting | |||||
| TensorDescriptor td = 11; // GeTensorDesc type | |||||
| TensorDef t = 12; // GeTensor type | |||||
| GraphDef g = 13; // Graph type | |||||
| ListListInt list_list_int = 14; // List List Int type | |||||
| int64 dt = 15; // ge::DataType | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs | |||||
| { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| // Shape / dimension description, using row-major order | |||||
| message ShapeDef | |||||
| { | |||||
| repeated int64 dim = 1; // Size of each dimension | |||||
| } | |||||
| // Multidimensional data description | |||||
| message TensorDescriptor | |||||
| { | |||||
| string name = 1; // Optional parameter, tensor name | |||||
| DataType dtype = 2; // tensor datatype | |||||
| ShapeDef shape = 3; // Shape / dimension | |||||
| string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" | |||||
| bool has_out_attr = 9; | |||||
| int64 size = 10; | |||||
| int64 weight_size = 11; | |||||
| bool reuse_input = 12; | |||||
| bool output_tensor = 13; | |||||
| string device_type = 14; | |||||
| bool input_tensor =15; | |||||
| int64 real_dim_cnt = 16; | |||||
| int64 reuse_input_index = 17; | |||||
| int64 data_offset = 18; | |||||
| int64 cmps_size = 19; | |||||
| string cmps_tab = 20; | |||||
| int64 cmps_tab_offset = 21; | |||||
| map<string, AttrDef> attr = 5; // Set of extra parameter fields | |||||
| } | |||||
| // GeTensor definition | |||||
| message TensorDef | |||||
| { | |||||
| TensorDescriptor desc = 1; // Tensor description | |||||
| bytes data = 2; // Tensor data | |||||
| } | |||||
| // Operator description | |||||
| message OpDef | |||||
| { | |||||
| string name = 1; // name | |||||
| string type = 2; // type | |||||
| repeated string input = 5; // input original op name + outgoing index. op_name:index | |||||
| map<string, AttrDef> attr = 10; // Set of operator parameter fields | |||||
| bool has_out_attr = 20; | |||||
| int64 id = 21; | |||||
| int64 stream_id =22; | |||||
| repeated string input_name = 23; | |||||
| repeated string src_name = 24; | |||||
| repeated int64 src_index = 25; | |||||
| repeated string dst_name = 26; | |||||
| repeated int64 dst_index = 27; | |||||
| repeated int64 input_i = 28; | |||||
| repeated int64 output_i = 29; | |||||
| repeated int64 workspace = 30; | |||||
| repeated int64 workspace_bytes = 31; | |||||
| repeated bool is_input_const = 32; | |||||
| repeated TensorDescriptor input_desc = 33; | |||||
| repeated TensorDescriptor output_desc = 34; | |||||
| repeated string subgraph_name = 35; | |||||
| } | |||||
| // Graph definition | |||||
| message GraphDef | |||||
| { | |||||
| string name = 1; // name | |||||
| repeated string input = 4; // Graph input | |||||
| repeated string output = 5; // Graph output | |||||
| repeated OpDef op = 6; // List of operators | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| // model definition | |||||
| message ModelDef | |||||
| { | |||||
| string name = 1; // name | |||||
| uint32 version = 2; // IR Proto verion | |||||
| string custom_version = 3; // User model version number, passed in by user | |||||
| repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef | |||||
| map<string, AttrDef> attr = 11; // Extended field | |||||
| } | |||||
| @@ -0,0 +1,136 @@ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| message InsertNewOps { | |||||
| repeated AippOpParams aipp_op = 1; | |||||
| repeated MultiShapeOpParams multi_shape_op = 2; | |||||
| } | |||||
| message AippOpParams { | |||||
| enum InputFormat { | |||||
| UNDEFINED = 0; | |||||
| YUV420SP_U8 = 1; | |||||
| XRGB8888_U8 = 2; | |||||
| RGB888_U8 = 3; | |||||
| YUV400_U8 = 4; | |||||
| NC1HWC0DI_FP16 = 5; | |||||
| NC1HWC0DI_S8 = 6; | |||||
| ARGB8888_U8 = 7; | |||||
| YUYV_U8 = 8; | |||||
| YUV422SP_U8 = 9; | |||||
| AYUV444_U8 = 10; | |||||
| RAW10 = 11; | |||||
| RAW12 = 12; | |||||
| RAW16 = 13; | |||||
| RAW24 = 14; | |||||
| RGB16 = 15; | |||||
| RGB20 = 16; | |||||
| RGB24 = 17; | |||||
| RGB8_IR = 18; | |||||
| RGB16_IR = 19; | |||||
| RGB24_IR = 20; | |||||
| } | |||||
| enum AippMode { | |||||
| undefined = 0; | |||||
| static = 1; | |||||
| dynamic = 2; | |||||
| } | |||||
| // AIPP模式,区分静态AIPP和动态AIPP | |||||
| AippMode aipp_mode = 1; | |||||
| // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | |||||
| uint32 related_input_rank = 2; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | |||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | |||||
| // 配置值 <= Data算子输出边的个数。 | |||||
| repeated uint32 input_edge_idx = 3; | |||||
| // [Begin] 动态AIPP参数,配置静态AIPP时无效 | |||||
| uint32 max_src_image_size = 4; | |||||
| // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 | |||||
| bool support_rotation = 5; | |||||
| // [End] 动态AIPP参数 | |||||
| // [Begin] 静态AIPP参数,配置动态AIPP时无效 | |||||
| InputFormat input_format = 51; | |||||
| bool csc_switch = 52; | |||||
| float cpadding_value = 53; | |||||
| bool rbuv_swap_switch = 54; | |||||
| bool ax_swap_switch = 55; | |||||
| bool single_line_mode = 56; | |||||
| int32 src_image_size_w = 57; | |||||
| int32 src_image_size_h = 58; | |||||
| bool crop = 59; | |||||
| int32 load_start_pos_w = 60; | |||||
| int32 load_start_pos_h = 61; | |||||
| int32 crop_size_w = 62; | |||||
| int32 crop_size_h = 63; | |||||
| bool resize = 64; | |||||
| int32 resize_output_w = 65; | |||||
| int32 resize_output_h = 66; | |||||
| bool padding = 67; | |||||
| int32 left_padding_size = 68; | |||||
| int32 right_padding_size = 69; | |||||
| int32 top_padding_size = 70; | |||||
| int32 bottom_padding_size = 71; | |||||
| int32 mean_chn_0 = 10; | |||||
| int32 mean_chn_1 = 11; | |||||
| int32 mean_chn_2 = 12; | |||||
| int32 mean_chn_3 = 19; | |||||
| float min_chn_0 = 13; | |||||
| float min_chn_1 = 14; | |||||
| float min_chn_2 = 15; | |||||
| float min_chn_3 = 20; | |||||
| repeated float var_reci_chn_0 = 16; | |||||
| repeated float var_reci_chn_1 = 17; | |||||
| repeated float var_reci_chn_2 = 18; | |||||
| repeated float var_reci_chn_3 = 21; | |||||
| repeated int32 matrix_r0c0 = 30; | |||||
| repeated int32 matrix_r0c1 = 31; | |||||
| repeated int32 matrix_r0c2 = 32; | |||||
| repeated int32 matrix_r1c0 = 33; | |||||
| repeated int32 matrix_r1c1 = 34; | |||||
| repeated int32 matrix_r1c2 = 35; | |||||
| repeated int32 matrix_r2c0 = 36; | |||||
| repeated int32 matrix_r2c1 = 37; | |||||
| repeated int32 matrix_r2c2 = 38; | |||||
| repeated int32 output_bias_0 = 39; | |||||
| repeated int32 output_bias_1 = 40; | |||||
| repeated int32 output_bias_2 = 41; | |||||
| repeated int32 input_bias_0 = 42; | |||||
| repeated int32 input_bias_1 = 43; | |||||
| repeated int32 input_bias_2 = 44; | |||||
| // [End] 静态AIPP参数 | |||||
| // The n number that is used for raw/rgbir data into f16 transformation. | |||||
| // The transformation equation is x/(2^n). If set to 0, no transform is performed. | |||||
| uint32 raw_rgbir_to_f16_n = 45; | |||||
| } | |||||
| message MultiShapeOpParams { | |||||
| enum MultiShapeMode { | |||||
| batch = 0; //动态batch | |||||
| resolution = 1; //动态分辨率,扩展用 | |||||
| } | |||||
| MultiShapeMode mode = 1; //算子模式 | |||||
| uint32 related_input_rank = 2; //新增算子插入到哪个输入 | |||||
| repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 | |||||
| } | |||||
| @@ -0,0 +1,396 @@ | |||||
| /* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. | |||||
| * | |||||
| * This program is free software; you can redistribute it and/or modify | |||||
| * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
| * Apache License for more details at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package domi; | |||||
| enum TargetType | |||||
| { | |||||
| MINI = 0; | |||||
| TINY = 1; | |||||
| LITE = 2; | |||||
| } | |||||
| // offline model | |||||
| message ModelDef { | |||||
| string name = 1; | |||||
| uint32 version = 2; | |||||
| uint64 memory_size = 10; | |||||
| uint32 stream_num = 11; | |||||
| uint32 event_num = 12; | |||||
| uint64 weight_size = 13; | |||||
| uint32 label_num = 15; | |||||
| repeated OpDef op = 20; | |||||
| TargetType target_type = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| }; | |||||
| // operator define | |||||
| message OpDef { | |||||
| string name = 1; | |||||
| string type = 2; | |||||
| uint32 id = 3; | |||||
| uint32 stream_id = 4; | |||||
| repeated string input_name = 5; | |||||
| repeated string src_name = 8; | |||||
| repeated int32 src_index = 9; | |||||
| repeated int64 input = 10; | |||||
| repeated int64 output = 11; | |||||
| repeated TensorDescriptor input_desc = 12; | |||||
| repeated TensorDescriptor output_desc = 13; | |||||
| repeated WeightDef weights = 14; | |||||
| repeated string dst_name = 15; | |||||
| repeated int32 dst_index = 16; | |||||
| repeated int64 workspace = 20; | |||||
| repeated uint32 workspace_bytes = 21; | |||||
| repeated string weight_name = 22; | |||||
| repeated bool is_input_const = 23; | |||||
| map<string, AttrDef> attr = 30; | |||||
| QuantizeFactorParams quantize_factor = 31; | |||||
| oneof op_params { | |||||
| // start at 100 here | |||||
| SendOpParams sender_param = 100; | |||||
| RecvOpParams receiver_param = 200; | |||||
| ConvolutionOpParams convolution_param = 300; | |||||
| PoolingOpParams pooling_param = 400; | |||||
| EltwiseOpParams eltwise_param = 500; | |||||
| BatchNormOpParams batchnorm_param = 600; | |||||
| ScaleOpParams scale_param = 700; | |||||
| FullConnectionOpParams full_connection_param = 800; | |||||
| SoftmaxOpParams softmax_param = 900; | |||||
| ActivationOpParams activation_param = 1000; | |||||
| ReshapeOpParams reshape_param = 1100; | |||||
| } | |||||
| }; | |||||
| message SendOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| message RecvOpParams { | |||||
| uint32 event_id = 1; | |||||
| }; | |||||
| enum QuantizeScaleType | |||||
| { | |||||
| VECTOR_SCALE = 0; | |||||
| SCALAR_SCALE = 1; | |||||
| } | |||||
| enum QuantizeScaleMode | |||||
| { | |||||
| NORMAL_MODE = 0; | |||||
| SQRT_MODE = 1; | |||||
| } | |||||
| enum QuantizeAlgorithm | |||||
| { | |||||
| NON_OFFSET_ALGO = 0; | |||||
| HALF_OFFSET_ALGO = 1; | |||||
| ALL_OFFSET_ALGO = 2; | |||||
| } | |||||
| message QuantizeFactor | |||||
| { | |||||
| QuantizeScaleMode scale_mode = 1; | |||||
| bytes scale_value = 2; | |||||
| int64 scale_offset = 3; | |||||
| bytes offset_data_value = 4; | |||||
| int64 offset_data_offset = 5; | |||||
| bytes offset_weight_value = 6; | |||||
| int64 offset_weight_offset = 7; | |||||
| bytes offset_pad_value = 8; | |||||
| int64 offset_pad_offset = 9; | |||||
| }; | |||||
| message QuantizeCalcFactor | |||||
| { | |||||
| bytes offsetw = 1; | |||||
| int64 offsetw_offset = 2; | |||||
| bytes offsetd = 3; | |||||
| int64 offsetd_offset = 4; | |||||
| bytes scalereq = 5; | |||||
| int64 scaledreq_offset = 6; | |||||
| bytes offsetdnext = 7; | |||||
| int64 offsetdnext_offset = 8; | |||||
| } | |||||
| message QuantizeFactorParams | |||||
| { | |||||
| QuantizeAlgorithm quantize_algo = 1; | |||||
| QuantizeScaleType scale_type = 2; | |||||
| QuantizeFactor quantize_param = 3; | |||||
| QuantizeFactor dequantize_param = 4; | |||||
| QuantizeFactor requantize_param = 5; | |||||
| QuantizeCalcFactor quantizecalc_param = 6; | |||||
| }; | |||||
| message ConvolutionOpParams { | |||||
| int32 mode = 1; | |||||
| int32 algo = 2; | |||||
| int32 pad_mode = 3; | |||||
| uint32 group = 4; | |||||
| uint32 num_output = 5; | |||||
| repeated uint32 pad = 10; | |||||
| repeated uint32 stride = 11; | |||||
| repeated uint32 dilation = 12; | |||||
| repeated uint32 kernel = 13; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| WeightDef filter = 40; | |||||
| WeightDef bias = 41; | |||||
| bool relu_flag = 62; | |||||
| repeated uint32 adj = 70; | |||||
| repeated uint32 target_shape = 71; | |||||
| repeated uint32 before_pad = 72; | |||||
| }; | |||||
| message PoolingOpParams { | |||||
| int32 mode = 1; | |||||
| int32 nan_opt = 2; | |||||
| int32 pad_mode = 3; | |||||
| bool global_pooling = 4; | |||||
| repeated uint32 window = 10; | |||||
| repeated uint32 pad = 11; | |||||
| repeated uint32 stride = 12; | |||||
| bool ceil_mode = 13; | |||||
| int32 data_mode = 14; | |||||
| float alpha = 20; | |||||
| float beta = 21; | |||||
| repeated uint32 before_pad = 22; | |||||
| }; | |||||
| message EltwiseOpParams { | |||||
| int32 mode = 1; | |||||
| repeated float coeff = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| repeated WeightDef weight = 5; | |||||
| bool relu_flag = 6; | |||||
| }; | |||||
| message ActivationOpParams { | |||||
| int32 mode = 1; | |||||
| float coef = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message BatchNormOpParams { | |||||
| int32 mode = 1; | |||||
| float alpha = 2; | |||||
| float beta = 3; | |||||
| double epsilon = 4;//optinal,[default = 1e-5] | |||||
| bool use_global_stats = 5; //optinal,by default true,testing mode | |||||
| float moving_average_fraction = 6; //optinal,[default = .999]; | |||||
| WeightDef estimated_mean = 7; | |||||
| WeightDef estimated_variance = 8; | |||||
| WeightDef scale = 9; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message ScaleOpParams { | |||||
| WeightDef scale = 1; | |||||
| WeightDef bias = 2; | |||||
| }; | |||||
| message ReshapeOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| ShapeDef shape = 3; | |||||
| int32 axis = 4; | |||||
| int32 num_axes = 5; | |||||
| int32 format = 6; | |||||
| }; | |||||
| message SoftmaxOpParams { | |||||
| int32 algo = 1; | |||||
| int32 mode = 2; | |||||
| float alpha = 3; | |||||
| float beta = 4; | |||||
| }; | |||||
| message FullConnectionOpParams { | |||||
| WeightDef filter = 1; | |||||
| WeightDef bias = 2; | |||||
| uint32 num_output = 3; | |||||
| bool relu_flag = 12; | |||||
| }; | |||||
| message FlattenOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 start_axis = 3; | |||||
| int32 end_axis = 4; | |||||
| } | |||||
| message AddLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulLimitedOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| int32 axis = 3; | |||||
| bool broadcast = 4; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message AddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message MulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message SubOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| repeated WeightDef weight = 10; | |||||
| }; | |||||
| message BiasAddOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| WeightDef bias = 10; | |||||
| }; | |||||
| message MatMulOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| bool transposeX = 3; | |||||
| bool transposeW = 4; | |||||
| WeightDef filter = 10; | |||||
| WeightDef bias = 12; | |||||
| }; | |||||
| message RsqrtOpParams { | |||||
| float alpha = 1; | |||||
| float beta = 2; | |||||
| }; | |||||
| message WeightDef { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| ShapeDef shape = 3; | |||||
| bytes data = 4; | |||||
| int64 data_offset = 5; | |||||
| uint32 cmps_size = 6; | |||||
| bytes cmps_tab = 7; | |||||
| int64 cmps_tab_offset = 10; | |||||
| CompressInfo cmps_info = 8; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 11; | |||||
| } | |||||
| message ShapeDef { | |||||
| repeated int64 dim = 1; | |||||
| } | |||||
| enum DeviceType { | |||||
| NPU = 0; // In default, we will use NPU. | |||||
| CPU = 1; // CPU | |||||
| } | |||||
| message AllOffsetQuantizeInfo { | |||||
| float scale = 1; | |||||
| int32 offset = 2; | |||||
| } | |||||
| message TensorDescriptor { | |||||
| int32 format = 1; | |||||
| int32 data_type = 2; | |||||
| repeated int64 dim = 3; | |||||
| uint32 size = 4; | |||||
| bool reuse_input = 5; | |||||
| bool output_tensor = 7; | |||||
| DeviceType device_type = 8; | |||||
| bool input_tensor = 9; | |||||
| uint32 real_dim_cnt = 10; | |||||
| uint32 reuse_input_index = 11; | |||||
| AllOffsetQuantizeInfo alloffset_quantize_info = 12; | |||||
| } | |||||
| message CompressInfo { | |||||
| int32 blockRow = 1; // block row | |||||
| int32 blockCol = 2; // block col | |||||
| int32 fractalK = 3; // fractal K | |||||
| int32 fractalN = 4; // fractal N | |||||
| int32 lastFractalK = 5; // K of last fractal | |||||
| int32 lastFractalN = 6; // N of last fractal | |||||
| int32 cubeSize = 7; // cube's length | |||||
| int32 loadDir = 8; // data load directtiono 0:col load 1:row load | |||||
| } | |||||
| message AttrDef { | |||||
| message ListValue { | |||||
| repeated string s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated uint32 u = 6 [packed = true]; // "list(uint)" | |||||
| repeated bytes bt = 7; | |||||
| } | |||||
| oneof value { | |||||
| string s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| uint32 u = 6; // "uint32" | |||||
| bytes bt = 7; | |||||
| ListValue list = 1; // any "list(...)" | |||||
| NamedAttrs func = 10; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NamedAttrs { | |||||
| string name = 1; | |||||
| map<string, AttrDef> attr = 2; | |||||
| } | |||||
| @@ -0,0 +1,62 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "AttrValueProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "tensor.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing the value for an attr used to configure an Op. | |||||
| // Comment indicates the corresponding attr type. Only the field matching the | |||||
| // attr type may be filled. | |||||
| message AttrValue { | |||||
| // LINT.IfChange | |||||
| message ListValue { | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated DataType type = 6 [packed = true]; // "list(type)" | |||||
| repeated TensorShapeProto shape = 7; // "list(shape)" | |||||
| repeated TensorProto tensor = 8; // "list(tensor)" | |||||
| repeated NameAttrList func = 9; // "list(attr)" | |||||
| } | |||||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) | |||||
| oneof value { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| DataType type = 6; // "type" | |||||
| TensorShapeProto shape = 7; // "shape" | |||||
| TensorProto tensor = 8; // "tensor" | |||||
| ListValue list = 1; // any "list(...)" | |||||
| // "func" represents a function. func.name is a function's name or | |||||
| // a primitive op's name. func.attr.first is the name of an attr | |||||
| // defined for that function. func.attr.second is the value for | |||||
| // that attr in the instantiation. | |||||
| NameAttrList func = 10; | |||||
| // This is a placeholder only used in nodes defined inside a | |||||
| // function. It indicates the attr value will be supplied when | |||||
| // the function is instantiated. For example, let us suppose a | |||||
| // node "N" in function "FN". "N" has an attr "A" with value | |||||
| // placeholder = "foo". When FN is instantiated with attr "foo" | |||||
| // set to "bar", the instantiated node N's attr A will have been | |||||
| // given the value "bar". | |||||
| string placeholder = 9; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NameAttrList { | |||||
| string name = 1; | |||||
| map<string, AttrValue> attr = 2; | |||||
| } | |||||
| @@ -0,0 +1,100 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "FunctionProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "node_def.proto"; | |||||
| import "op_def.proto"; | |||||
| // A library is a set of named functions. | |||||
| message FunctionDefLibrary { | |||||
| repeated FunctionDef function = 1; | |||||
| repeated GradientDef gradient = 2; | |||||
| } | |||||
| // A function can be instantiated when the runtime can bind every attr | |||||
| // with a value. When a GraphDef has a call to a function, it must | |||||
| // have binding for every attr defined in the signature. | |||||
| // * device spec, etc. | |||||
| message FunctionDef { | |||||
| // The definition of the function's name, arguments, return values, | |||||
| // attrs etc. | |||||
| OpDef signature = 1; | |||||
| // Attributes specific to this function definition. | |||||
| map<string, AttrValue> attr = 5; | |||||
| // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. | |||||
| reserved 2; | |||||
| // In both of the following fields, there is the need to specify an | |||||
| // output that is used as either the input to another node (in | |||||
| // `node_def`) or as a return value of the function (in `ret`). | |||||
| // Unlike the NodeDefs in GraphDef, we need to be able to specify a | |||||
| // list in some cases (instead of just single outputs). Also, we | |||||
| // need to be able to deal with lists of unknown length (so the | |||||
| // output index may not be known at function definition time). So | |||||
| // we use the following format instead: | |||||
| // * "fun_in" where "fun_in" is the name of a function input arg in | |||||
| // the `signature` field above. This represents that input, whether | |||||
| // it is a single tensor or a list. | |||||
| // * "fun_in:0" gives the first element of a function input arg (a | |||||
| // non-list input is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // * "node:out" where "node" is the name of a node in `node_def` and | |||||
| // "out" is the name one of its op's output arguments (the name | |||||
| // comes from the OpDef of the node's op). This represents that | |||||
| // node's output, whether it is a single tensor or a list. | |||||
| // Note: We enforce that an op's output arguments are never | |||||
| // renamed in the backwards-compatibility test. | |||||
| // * "node:out:0" gives the first element of a node output arg (a | |||||
| // non-list output is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // | |||||
| // NOT CURRENTLY SUPPORTED (but may be in the future): | |||||
| // * "node:out:-1" gives last element in a node output list | |||||
| // * "node:out:1:" gives a list with all but the first element in a | |||||
| // node output list | |||||
| // * "node:out::-1" gives a list with all but the last element in a | |||||
| // node output list | |||||
| // The body of the function. Unlike the NodeDefs in a GraphDef, attrs | |||||
| // may have values of type `placeholder` and the `input` field uses | |||||
| // the "output" format above. | |||||
| // By convention, "op" in node_def is resolved by consulting with a | |||||
| // user-defined library first. If not resolved, "func" is assumed to | |||||
| // be a builtin op. | |||||
| repeated NodeDef node_def = 3; | |||||
| // A mapping from the output arg names from `signature` to the | |||||
| // outputs from `node_def` that should be returned by the function. | |||||
| map<string, string> ret = 4; | |||||
| } | |||||
| // GradientDef defines the gradient function of a function defined in | |||||
| // a function library. | |||||
| // | |||||
| // A gradient function g (specified by gradient_func) for a function f | |||||
| // (specified by function_name) must follow the following: | |||||
| // | |||||
| // The function 'f' must be a numerical function which takes N inputs | |||||
| // and produces M outputs. Its gradient function 'g', which is a | |||||
| // function taking N + M inputs and produces N outputs. | |||||
| // | |||||
| // I.e. if we have | |||||
| // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||||
| // then, g is | |||||
| // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||||
| // dL/dy1, dL/dy2, ..., dL/dy_M), | |||||
| // where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||||
| // loss function). dL/dx_i is the partial derivative of L with respect | |||||
| // to x_i. | |||||
| message GradientDef { | |||||
| string function_name = 1; // The function name. | |||||
| string gradient_func = 2; // The gradient function's name. | |||||
| } | |||||
| @@ -0,0 +1,56 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "GraphProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "node_def.proto"; | |||||
| import "function.proto"; | |||||
| import "versions.proto"; | |||||
| // Represents the graph of operations | |||||
| message GraphDef { | |||||
| repeated NodeDef node = 1; | |||||
| // Compatibility versions of the graph. See core/public/version.h for version | |||||
| // history. The GraphDef version is distinct from the TensorFlow version, and | |||||
| // each release of TensorFlow will support a range of GraphDef versions. | |||||
| VersionDef versions = 4; | |||||
| // Deprecated single version field; use versions above instead. Since all | |||||
| // GraphDef changes before "versions" was introduced were forward | |||||
| // compatible, this field is entirely ignored. | |||||
| int32 version = 3 [deprecated = true]; | |||||
| // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||||
| // | |||||
| // "library" provides user-defined functions. | |||||
| // | |||||
| // Naming: | |||||
| // * library.function.name are in a flat namespace. | |||||
| // NOTE: We may need to change it to be hierarchical to support | |||||
| // different orgs. E.g., | |||||
| // { "/google/nn", { ... }}, | |||||
| // { "/google/vision", { ... }} | |||||
| // { "/org_foo/module_bar", { ... }} | |||||
| // map<string, FunctionDefLib> named_lib; | |||||
| // * If node[i].op is the name of one function in "library", | |||||
| // node[i] is deemed as a function call. Otherwise, node[i].op | |||||
| // must be a primitive operation supported by the runtime. | |||||
| // | |||||
| // | |||||
| // Function call semantics: | |||||
| // | |||||
| // * The callee may start execution as soon as some of its inputs | |||||
| // are ready. The caller may want to use Tuple() mechanism to | |||||
| // ensure all inputs are ready in the same time. | |||||
| // | |||||
| // * The consumer of return values may start executing as soon as | |||||
| // the return values the consumer depends on are ready. The | |||||
| // consumer may want to use Tuple() mechanism to ensure the | |||||
| // consumer does not start until all return values of the callee | |||||
| // function are ready. | |||||
| FunctionDefLibrary library = 2; | |||||
| }; | |||||
| @@ -0,0 +1,63 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "NodeProto"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| message NodeDef { | |||||
| // The name given to this operator. Used for naming inputs, | |||||
| // logging, visualization, etc. Unique within a single GraphDef. | |||||
| // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". | |||||
| string name = 1; | |||||
| // The operation name. There may be custom parameters in attrs. | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| string op = 2; | |||||
| // Each input is "node:src_output" with "node" being a string name and | |||||
| // "src_output" indicating which output tensor to use from "node". If | |||||
| // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs | |||||
| // may optionally be followed by control inputs that have the format | |||||
| // "^node". | |||||
| repeated string input = 3; | |||||
| // A (possibly partial) specification for the device on which this | |||||
| // node should be placed. | |||||
| // The expected syntax for this string is as follows: | |||||
| // | |||||
| // DEVICE_SPEC ::= PARTIAL_SPEC | |||||
| // | |||||
| // PARTIAL_SPEC ::= ("/" CONSTRAINT) * | |||||
| // CONSTRAINT ::= ("job:" JOB_NAME) | |||||
| // | ("replica:" [1-9][0-9]*) | |||||
| // | ("task:" [1-9][0-9]*) | |||||
| // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) | |||||
| // | |||||
| // Valid values for this string include: | |||||
| // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) | |||||
| // * "/job:worker/device:GPU:3" (partial specification) | |||||
| // * "" (no specification) | |||||
| // | |||||
| // If the constraints do not resolve to a single device (or if this | |||||
| // field is empty or not present), the runtime will attempt to | |||||
| // choose a device automatically. | |||||
| string device = 4; | |||||
| // Operation-specific graph-construction-time configuration. | |||||
| // Note that this should include all attrs defined in the | |||||
| // corresponding OpDef, including those with a value matching | |||||
| // the default -- this allows the default to change and makes | |||||
| // NodeDefs easier to interpret on their own. However, if | |||||
| // an attr with a default is not specified in this list, the | |||||
| // default will be used. | |||||
| // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and | |||||
| // one of the names from the corresponding OpDef's attr field). | |||||
| // The values must have a type matching the corresponding OpDef | |||||
| // attr's type field. | |||||
| // Add some examples here showing best practices. | |||||
| map<string, AttrValue> attr = 5; | |||||
| }; | |||||
| @@ -0,0 +1,164 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "OpDefProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "types.proto"; | |||||
| // Defines an operation. A NodeDef in a GraphDef specifies an Op by | |||||
| // using the "op" field which should match the name of a OpDef. | |||||
| // LINT.IfChange | |||||
| message OpDef { | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". | |||||
| string name = 1; | |||||
| // For describing inputs and outputs. | |||||
| message ArgDef { | |||||
| // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". | |||||
| string name = 1; | |||||
| // Human readable description. | |||||
| string description = 2; | |||||
| // Describes the type of one or more tensors that are accepted/produced | |||||
| // by this input/output arg. The only legal combinations are: | |||||
| // * For a single tensor: either the "type" field is set or the | |||||
| // "type_attr" field is set to the name of an attr with type "type". | |||||
| // * For a sequence of tensors with the same type: the "number_attr" | |||||
| // field will be set to the name of an attr with type "int", and | |||||
| // either the "type" or "type_attr" field will be set as for | |||||
| // single tensors. | |||||
| // * For a sequence of tensors, the "type_list_attr" field will be set | |||||
| // to the name of an attr with type "list(type)". | |||||
| DataType type = 3; | |||||
| string type_attr = 4; // if specified, attr must have type "type" | |||||
| string number_attr = 5; // if specified, attr must have type "int" | |||||
| // If specified, attr must have type "list(type)", and none of | |||||
| // type, type_attr, and number_attr may be specified. | |||||
| string type_list_attr = 6; | |||||
| // For inputs: if true, the inputs are required to be refs. | |||||
| // By default, inputs can be either refs or non-refs. | |||||
| // For outputs: if true, outputs are refs, otherwise they are not. | |||||
| bool is_ref = 16; | |||||
| }; | |||||
| // Description of the input(s). | |||||
| repeated ArgDef input_arg = 2; | |||||
| // Description of the output(s). | |||||
| repeated ArgDef output_arg = 3; | |||||
| // Description of the graph-construction-time configuration of this | |||||
| // Op. That is to say, this describes the attr fields that will | |||||
| // be specified in the NodeDef. | |||||
| message AttrDef { | |||||
| // A descriptive name for the argument. May be used, e.g. by the | |||||
| // Python client, as a keyword argument name, and so should match | |||||
| // the regexp "[a-z][a-z0-9_]+". | |||||
| string name = 1; | |||||
| // One of the type names from attr_value.proto ("string", "list(string)", | |||||
| // "int", etc.). | |||||
| string type = 2; | |||||
| // A reasonable default for this attribute if the user does not supply | |||||
| // a value. If not specified, the user must supply a value. | |||||
| AttrValue default_value = 3; | |||||
| // Human-readable description. | |||||
| string description = 4; | |||||
| // --- Constraints --- | |||||
| // These constraints are only in effect if specified. Default is no | |||||
| // constraints. | |||||
| // For type == "int", this is a minimum value. For "list(___)" | |||||
| // types, this is the minimum length. | |||||
| bool has_minimum = 5; | |||||
| int64 minimum = 6; | |||||
| // The set of allowed values. Has type that is the "list" version | |||||
| // of the "type" field above (uses the "list" field of AttrValue). | |||||
| // If type == "type" or "list(type)" above, then the "type" field | |||||
| // of "allowed_values.list" has the set of allowed DataTypes. | |||||
| // If type == "string" or "list(string)", then the "s" field of | |||||
| // "allowed_values.list" has the set of allowed strings. | |||||
| AttrValue allowed_values = 7; | |||||
| } | |||||
| repeated AttrDef attr = 4; | |||||
| // Optional deprecation based on GraphDef versions. | |||||
| OpDeprecation deprecation = 8; | |||||
| // One-line human-readable description of what the Op does. | |||||
| string summary = 5; | |||||
| // Additional, longer human-readable description of what the Op does. | |||||
| string description = 6; | |||||
| // ------------------------------------------------------------------------- | |||||
| // Which optimizations this operation can participate in. | |||||
| // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) | |||||
| bool is_commutative = 18; | |||||
| // If is_aggregate is true, then this operation accepts N >= 2 | |||||
| // inputs and produces 1 output all of the same type. Should be | |||||
| // associative and commutative, and produce output with the same | |||||
| // shape as the input. The optimizer may replace an aggregate op | |||||
| // taking input from multiple devices with a tree of aggregate ops | |||||
| // that aggregate locally within each device (and possibly within | |||||
| // groups of nearby devices) before communicating. | |||||
| bool is_aggregate = 16; // for things like add | |||||
| // Other optimizations go here, like | |||||
| // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. | |||||
| // ------------------------------------------------------------------------- | |||||
| // Optimization constraints. | |||||
| // Ops are marked as stateful if their behavior depends on some state beyond | |||||
| // their input tensors (e.g. variable reading op) or if they have | |||||
| // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops | |||||
| // must always produce the same output for the same input and have | |||||
| // no side-effects. | |||||
| // | |||||
| // By default Ops may be moved between devices. Stateful ops should | |||||
| // either not be moved, or should only be moved if that state can also | |||||
| // be moved (e.g. via some sort of save / restore). | |||||
| // Stateful ops are guaranteed to never be optimized away by Common | |||||
| // Subexpression Elimination (CSE). | |||||
| bool is_stateful = 17; // for things like variables, queue | |||||
| // ------------------------------------------------------------------------- | |||||
| // Non-standard options. | |||||
| // By default, all inputs to an Op must be initialized Tensors. Ops | |||||
| // that may initialize tensors for the first time should set this | |||||
| // field to true, to allow the Op to take an uninitialized Tensor as | |||||
| // input. | |||||
| bool allows_uninitialized_input = 19; // for Assign, etc. | |||||
| }; | |||||
| // LINT.ThenChange( | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) | |||||
| // Information about version-dependent deprecation of an op | |||||
| message OpDeprecation { | |||||
| // First GraphDef version at which the op is disallowed. | |||||
| int32 version = 1; | |||||
| // Explanation of why it was deprecated and what to use instead. | |||||
| string explanation = 2; | |||||
| }; | |||||
| // A collection of OpDefs | |||||
| message OpList { | |||||
| repeated OpDef op = 1; | |||||
| }; | |||||
| @@ -0,0 +1,29 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "ResourceHandle"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Protocol buffer representing a handle to a tensorflow resource. Handles are | |||||
| // not valid across executions, but can be serialized back and forth from within | |||||
| // a single run. | |||||
| message ResourceHandleProto { | |||||
| // Unique name for the device containing the resource. | |||||
| string device = 1; | |||||
| // Container in which this resource is placed. | |||||
| string container = 2; | |||||
| // Unique name of this resource. | |||||
| string name = 3; | |||||
| // Hash code for the type of the resource. Is only valid in the same device | |||||
| // and in the same execution. | |||||
| uint64 hash_code = 4; | |||||
| // For debug-only, the name of the type pointed to by this handle, if | |||||
| // available. | |||||
| string maybe_type_name = 5; | |||||
| }; | |||||
| @@ -0,0 +1,94 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "resource_handle.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing a tensor. | |||||
| message TensorProto { | |||||
| DataType dtype = 1; | |||||
| // Shape of the tensor. | |||||
| TensorShapeProto tensor_shape = 2; | |||||
| // Only one of the representations below is set, one of "tensor_contents" and | |||||
| // the "xxx_val" attributes. We are not using oneof because as oneofs cannot | |||||
| // contain repeated fields it would require another extra set of messages. | |||||
| // Version number. | |||||
| // | |||||
| // In version 0, if the "repeated xxx" representations contain only one | |||||
| // element, that element is repeated to fill the shape. This makes it easy | |||||
| // to represent a constant Tensor with a single value. | |||||
| int32 version_number = 3; | |||||
| // Serialized raw tensor content from either Tensor::AsProtoTensorContent or | |||||
| // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation | |||||
| // can be used for all tensor types. The purpose of this representation is to | |||||
| // reduce serialization overhead during RPC call by avoiding serialization of | |||||
| // many repeated small items. | |||||
| bytes tensor_content = 4; | |||||
| // Type specific representations that make it easy to create tensor protos in | |||||
| // all languages. Only the representation corresponding to "dtype" can | |||||
| // be set. The values hold the flattened representation of the tensor in | |||||
| // row major order. | |||||
| // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll | |||||
| // have some pointless zero padding for each value here. | |||||
| repeated int32 half_val = 13 [packed = true]; | |||||
| // DT_FLOAT. | |||||
| repeated float float_val = 5 [packed = true]; | |||||
| // DT_DOUBLE. | |||||
| repeated double double_val = 6 [packed = true]; | |||||
| // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. | |||||
| repeated int32 int_val = 7 [packed = true]; | |||||
| // DT_STRING | |||||
| repeated bytes string_val = 8; | |||||
| // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th single precision complex. | |||||
| repeated float scomplex_val = 9 [packed = true]; | |||||
| // DT_INT64 | |||||
| repeated int64 int64_val = 10 [packed = true]; | |||||
| // DT_BOOL | |||||
| repeated bool bool_val = 11 [packed = true]; | |||||
| // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th double precision complex. | |||||
| repeated double dcomplex_val = 12 [packed = true]; | |||||
| // DT_RESOURCE | |||||
| repeated ResourceHandleProto resource_handle_val = 14; | |||||
| // DT_VARIANT | |||||
| repeated VariantTensorDataProto variant_val = 15; | |||||
| // DT_UINT32 | |||||
| repeated uint32 uint32_val = 16 [packed = true]; | |||||
| // DT_UINT64 | |||||
| repeated uint64 uint64_val = 17 [packed = true]; | |||||
| }; | |||||
| // Protocol buffer representing the serialization format of DT_VARIANT tensors. | |||||
| message VariantTensorDataProto { | |||||
| // Name of the type of objects being serialized. | |||||
| string type_name = 1; | |||||
| // Portions of the object that are not Tensors. | |||||
| bytes metadata = 2; | |||||
| // Tensors contained within objects being serialized. | |||||
| repeated TensorProto tensors = 3; | |||||
| } | |||||
| @@ -0,0 +1,45 @@ | |||||
| // Protocol buffer representing the shape of tensors. | |||||
| syntax = "proto3"; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorShapeProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| package domi.tensorflow; | |||||
| // Dimensions of a tensor. | |||||
| message TensorShapeProto { | |||||
| // One dimension of the tensor. | |||||
| message Dim { | |||||
| // Size of the tensor in that dimension. | |||||
| // This value must be >= -1, but values of -1 are reserved for "unknown" | |||||
| // shapes (values of -1 mean "unknown" dimension). Certain wrappers | |||||
| // that work with TensorShapeProto may fail at runtime when deserializing | |||||
| // a TensorShapeProto containing a dim value of -1. | |||||
| int64 size = 1; | |||||
| // Optional name of the tensor dimension. | |||||
| string name = 2; | |||||
| }; | |||||
| // Dimensions of the tensor, such as {"input", 30}, {"output", 40} | |||||
| // for a 30 x 40 2D tensor. If an entry has size -1, this | |||||
| // corresponds to a dimension of unknown size. The names are | |||||
| // optional. | |||||
| // | |||||
| // The order of entries in "dim" matters: It indicates the layout of the | |||||
| // values in the tensor in-memory representation. | |||||
| // | |||||
| // The first entry in "dim" is the outermost dimension used to layout the | |||||
| // values, the last entry is the innermost dimension. This matches the | |||||
| // in-memory layout of RowMajor Eigen tensors. | |||||
| // | |||||
| // If "dim.size()" > 0, "unknown_rank" must be false. | |||||
| repeated Dim dim = 2; | |||||
| // If true, the number of dimensions in the shape is unknown. | |||||
| // | |||||
| // If true, "dim.size()" must be 0. | |||||
| bool unknown_rank = 3; | |||||
| }; | |||||
| @@ -0,0 +1,74 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TypesProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // LINT.IfChange | |||||
| enum DataType { | |||||
| // Not a legal value for DataType. Used to indicate a DataType field | |||||
| // has not been set. | |||||
| DT_INVALID = 0; | |||||
| // Data types that all computation devices are expected to be | |||||
| // capable to support. | |||||
| DT_FLOAT = 1; | |||||
| DT_DOUBLE = 2; | |||||
| DT_INT32 = 3; | |||||
| DT_UINT8 = 4; | |||||
| DT_INT16 = 5; | |||||
| DT_INT8 = 6; | |||||
| DT_STRING = 7; | |||||
| DT_COMPLEX64 = 8; // Single-precision complex | |||||
| DT_INT64 = 9; | |||||
| DT_BOOL = 10; | |||||
| DT_QINT8 = 11; // Quantized int8 | |||||
| DT_QUINT8 = 12; // Quantized uint8 | |||||
| DT_QINT32 = 13; // Quantized int32 | |||||
| DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. | |||||
| DT_QINT16 = 15; // Quantized int16 | |||||
| DT_QUINT16 = 16; // Quantized uint16 | |||||
| DT_UINT16 = 17; | |||||
| DT_COMPLEX128 = 18; // Double-precision complex | |||||
| DT_HALF = 19; | |||||
| DT_RESOURCE = 20; | |||||
| DT_VARIANT = 21; // Arbitrary C++ data types | |||||
| DT_UINT32 = 22; | |||||
| DT_UINT64 = 23; | |||||
| // Do not use! These are only for parameters. Every enum above | |||||
| // should have a corresponding value below (verified by types_test). | |||||
| DT_FLOAT_REF = 101; | |||||
| DT_DOUBLE_REF = 102; | |||||
| DT_INT32_REF = 103; | |||||
| DT_UINT8_REF = 104; | |||||
| DT_INT16_REF = 105; | |||||
| DT_INT8_REF = 106; | |||||
| DT_STRING_REF = 107; | |||||
| DT_COMPLEX64_REF = 108; | |||||
| DT_INT64_REF = 109; | |||||
| DT_BOOL_REF = 110; | |||||
| DT_QINT8_REF = 111; | |||||
| DT_QUINT8_REF = 112; | |||||
| DT_QINT32_REF = 113; | |||||
| DT_BFLOAT16_REF = 114; | |||||
| DT_QINT16_REF = 115; | |||||
| DT_QUINT16_REF = 116; | |||||
| DT_UINT16_REF = 117; | |||||
| DT_COMPLEX128_REF = 118; | |||||
| DT_HALF_REF = 119; | |||||
| DT_RESOURCE_REF = 120; | |||||
| DT_VARIANT_REF = 121; | |||||
| DT_UINT32_REF = 122; | |||||
| DT_UINT64_REF = 123; | |||||
| } | |||||
| // LINT.ThenChange( | |||||
| // https://www.tensorflow.org/code/tensorflow/c/c_api.h, | |||||
| // https://www.tensorflow.org/code/tensorflow/go/tensor.go, | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/types.h, | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, | |||||
| // https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, | |||||
| // https://www.tensorflow.org/code/tensorflow/python/framework/function.py) | |||||
| @@ -0,0 +1,31 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "VersionsProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Version information for a piece of serialized data | |||||
| // | |||||
| // There are different types of versions for each type of data | |||||
| // (GraphDef, etc.), but they all have the same common shape | |||||
| // described here. | |||||
| // | |||||
| // Each consumer has "consumer" and "min_producer" versions (specified | |||||
| // elsewhere). A consumer is allowed to consume this data if | |||||
| // | |||||
| // producer >= min_producer | |||||
| // consumer >= min_consumer | |||||
| // consumer not in bad_consumers | |||||
| // | |||||
| message VersionDef { | |||||
| // The version of the code that produced this data. | |||||
| int32 producer = 1; | |||||
| // Any consumer below this version is not allowed to consume this data. | |||||
| int32 min_consumer = 2; | |||||
| // Specific consumer versions which are disallowed (e.g. due to bugs). | |||||
| repeated int32 bad_consumers = 3; | |||||
| }; | |||||
| @@ -0,0 +1,528 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/proto_file_parser.h" | |||||
| #include <iostream> | |||||
| #include <fstream> | |||||
| #include <sstream> | |||||
| #include <vector> | |||||
| #include <random> | |||||
| #include <sys/types.h> | |||||
| #include <unistd.h> | |||||
| #include "common/string_util.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| using std::ifstream; | |||||
| using std::vector; | |||||
| using std::string; | |||||
| namespace { | |||||
| const char kMinNum = '0'; | |||||
| const char kMaxNum = '9'; | |||||
| const int kMinLineWordSize = 3; | |||||
| const int kMinMessageLineWords = 2; | |||||
| const int kMaxIdentifier = 536870912; // 2^29 - 1 | |||||
| const int kTmpFileNameLen = 16; | |||||
| const int kMinRandomNum = 0; | |||||
| const int kMaxRandomNum = 9; | |||||
| const int kDecimalMulti = 10; | |||||
| const int kOpenRetValue = 0; | |||||
| const int kMessageNameIndex = 2; | |||||
| const char *const kTmpPath = "/tmp"; | |||||
| const char *const kMessage = "message"; | |||||
| const char *const kLayerParameter = "LayerParameter"; | |||||
| const char *const kNetParameter = "NetParameter"; | |||||
| const char *const kStartBrace = "{"; | |||||
| const char *const kCloseBrace = "}"; | |||||
| const char *const kOptional = "optional"; | |||||
| const char *const kRepeated = "repeated"; | |||||
| const char *const kRequired = "required"; | |||||
| bool GetIdentifier(const std::string &line, int &identifier) { | |||||
| int size = line.size(); | |||||
| auto pos = line.find("="); | |||||
| if (pos == std::string::npos) { | |||||
| return false; | |||||
| } | |||||
| for (int i = pos + 1; i < size; i++) { | |||||
| if (line[i] == ';') { | |||||
| break; | |||||
| } | |||||
| if (line[i] >= kMinNum && line[i] <= kMaxNum) { | |||||
| identifier = identifier * kDecimalMulti + line[i] - kMinNum; | |||||
| } | |||||
| if (identifier > kMaxIdentifier || identifier < 0) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (identifier == 0) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void GetName(const std::string &op_info, string &op_name) { | |||||
| op_name.assign(op_info); | |||||
| auto pos = op_name.find("="); | |||||
| if (pos != string::npos) { | |||||
| op_name = op_name.substr(0, pos); | |||||
| } | |||||
| } | |||||
| void GetOpParamInfo(const std::string &line, std::vector<std::string> &op_param_info) { | |||||
| std::istringstream string_stream(line); | |||||
| std::string temp; | |||||
| while (std::getline(string_stream, temp, ' ')) { | |||||
| if (temp.empty()) { | |||||
| continue; | |||||
| } | |||||
| op_param_info.emplace_back(std::move(temp)); | |||||
| } | |||||
| } | |||||
| string GetMessageName(const std::string &line) { | |||||
| std::vector<std::string> op_param_info; | |||||
| GetOpParamInfo(line, op_param_info); | |||||
| string message_name; | |||||
| if (op_param_info.size() < kMinMessageLineWords) { | |||||
| message_name = ""; | |||||
| return message_name; | |||||
| } | |||||
| message_name = op_param_info[1]; | |||||
| auto pos = message_name.find(kStartBrace); | |||||
| if (pos != string::npos) { | |||||
| message_name = message_name.substr(0, pos); | |||||
| } | |||||
| return message_name; | |||||
| } | |||||
| string CreatTmpName(int len) { | |||||
| std::uniform_int_distribution<int> u(kMinRandomNum, kMaxRandomNum); | |||||
| std::default_random_engine e; | |||||
| e.seed(time(0)); | |||||
| string tmp_name = ""; | |||||
| for (int i = 0; i < len; i++) { | |||||
| tmp_name += std::to_string(u(e)); | |||||
| } | |||||
| return tmp_name; | |||||
| } | |||||
| bool SaveIdentifierOpMapInfo(const string &line, std::map<int, std::pair<string, string>> &identifier_op_map, | |||||
| std::map<std::string, std::pair<int, string>> &op_identifier_map) { | |||||
| std::vector<std::string> op_param_info; | |||||
| GetOpParamInfo(line, op_param_info); | |||||
| int info_size = op_param_info.size(); | |||||
| if (info_size < kMinLineWordSize) { | |||||
| GELOGE(ge::FAILED, "Words size of line[%s] is less than kMinLineWordSize[%d].", line.c_str(), kMinLineWordSize); | |||||
| return false; | |||||
| } | |||||
| if (op_param_info[0] != kOptional && op_param_info[0] != kRepeated && op_param_info[0] != kRequired) { | |||||
| GELOGE(ge::FAILED, "Split line[%s] failed.", line.c_str()); | |||||
| return false; | |||||
| } | |||||
| // get identifier | |||||
| int identifier = 0; | |||||
| bool ret = GetIdentifier(line, identifier); | |||||
| if (!ret) { | |||||
| GELOGE(ge::FAILED, "Get identifier of line[%s] failed.", line.c_str()); | |||||
| return false; | |||||
| } | |||||
| // get op_name | |||||
| string name; | |||||
| GetName(op_param_info[kMessageNameIndex], name); | |||||
| identifier_op_map[identifier] = std::make_pair(op_param_info[1], name); | |||||
| op_identifier_map[name] = std::make_pair(identifier, op_param_info[1]); | |||||
| return true; | |||||
| } | |||||
| bool CheckRealPath(const char *file_path) { | |||||
| string dest_path = ge::parser::RealPath(file_path); | |||||
| if (dest_path.empty()) { | |||||
| GELOGW("Path [%s] is not real existed.", file_path); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace | |||||
| namespace ge { | |||||
| ProtoFileParser::~ProtoFileParser() { | |||||
| if (!fusion_proto_path.empty() && CheckRealPath(fusion_proto_path.c_str())) { | |||||
| (void)remove(fusion_proto_path.c_str()); | |||||
| } | |||||
| } | |||||
| std::string ProtoFileParser::GetFusionProtoFile() { | |||||
| return fusion_proto_path; | |||||
| } | |||||
| Status ProtoFileParser::CreatProtoFile() { | |||||
| if (fusion_proto_path.empty()) { | |||||
| fusion_proto_path.assign(kTmpPath); | |||||
| fusion_proto_path += "/" + CreatTmpName(kTmpFileNameLen); | |||||
| } | |||||
| int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP); | |||||
| if (fd < kOpenRetValue) { | |||||
| GELOGE(FAILED, "creat tmp proto file[%s] failed.", fusion_proto_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| close(fd); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::ParseProtoFile(const string &proto_file, | |||||
| std::map<int, std::pair<string, string>> &identifier_op_map, | |||||
| std::map<std::string, std::pair<int, string>> &op_identifier_map) { | |||||
| ifstream read_file; | |||||
| read_file.open(proto_file, std::ios::in); | |||||
| if (read_file.fail()) { | |||||
| GELOGE(FAILED, "ifsream open proto file[%s] failed.", proto_file.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::string line; | |||||
| bool save_flag = false; | |||||
| while (std::getline(read_file, line)) { | |||||
| if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) { | |||||
| save_flag = true; | |||||
| continue; | |||||
| } | |||||
| if (save_flag && line.find(kCloseBrace) != std::string::npos) { | |||||
| save_flag = false; | |||||
| break; | |||||
| } | |||||
| if (save_flag) { | |||||
| if (line.find(kRepeated) == std::string::npos && line.find(kOptional) == std::string::npos && | |||||
| line.find(kRequired) == std::string::npos) { | |||||
| continue; | |||||
| } | |||||
| bool ret = SaveIdentifierOpMapInfo(line, identifier_op_map, op_identifier_map); | |||||
| if (!ret) { | |||||
| read_file.close(); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| read_file.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp) { | |||||
| ifstream read_custom; | |||||
| read_custom.open(custom_proto_file, std::ios::in); | |||||
| if (read_custom.fail()) { | |||||
| GELOGE(FAILED, "ifsream open custom proto file[%s] failed.", custom_proto_file); | |||||
| return FAILED; | |||||
| } | |||||
| std::string line_custom; | |||||
| bool custom_in_layer = false; | |||||
| while (std::getline(read_custom, line_custom)) { | |||||
| if (line_custom.find(kMessage) != std::string::npos && line_custom.find(kLayerParameter) != std::string::npos) { | |||||
| custom_in_layer = true; | |||||
| continue; | |||||
| } | |||||
| if (!custom_in_layer) { | |||||
| continue; | |||||
| } | |||||
| if (line_custom.find(kCloseBrace) != std::string::npos) { | |||||
| custom_in_layer = false; | |||||
| break; | |||||
| } | |||||
| // exclude remark lines | |||||
| if (line_custom.find(kRepeated) == std::string::npos && line_custom.find(kOptional) == std::string::npos && | |||||
| line_custom.find(kRequired) == std::string::npos) { | |||||
| continue; | |||||
| } | |||||
| // exclude repeated lines | |||||
| if (custom_repeat_line_map_.count(line_custom) == 0) { | |||||
| write_tmp << line_custom << '\n'; | |||||
| } | |||||
| } | |||||
| read_custom.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp) { | |||||
| ifstream read_custom; | |||||
| read_custom.open(custom_proto_file, std::ios::in); | |||||
| if (read_custom.fail()) { | |||||
| GELOGE(FAILED, "ifsream open custom proto file[%s] failed.", custom_proto_file); | |||||
| return FAILED; | |||||
| } | |||||
| std::string line_custom; | |||||
| bool custom_in_message = false; | |||||
| while (std::getline(read_custom, line_custom)) { | |||||
| if (line_custom.find(kMessage) != std::string::npos) { | |||||
| std::string message_name = GetMessageName(line_custom); | |||||
| if (message_name != kLayerParameter && message_name != kNetParameter) { | |||||
| custom_in_message = true; | |||||
| write_tmp << line_custom << '\n'; | |||||
| } else { | |||||
| custom_in_message = false; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| // exclude repeated messages | |||||
| if (custom_in_message) { | |||||
| write_tmp << line_custom << '\n'; | |||||
| } | |||||
| } | |||||
| read_custom.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::WriteCaffeProtoFile(const char *custom_proto_file, | |||||
| std::ifstream &read_caffe, | |||||
| std::ofstream &write_tmp) { | |||||
| std::string line_caffe; | |||||
| bool caffe_in_layer = false; | |||||
| bool caffe_in_unrepeated_message = true; | |||||
| string tmp_message_name; | |||||
| while (std::getline(read_caffe, line_caffe)) { | |||||
| if (line_caffe.find(kMessage) != std::string::npos) { | |||||
| tmp_message_name.assign(GetMessageName(line_caffe)); | |||||
| if (custom_repeat_message_map_.count(tmp_message_name) > 0) { | |||||
| caffe_in_unrepeated_message = false; | |||||
| } else { | |||||
| caffe_in_unrepeated_message = true; | |||||
| if (tmp_message_name == kLayerParameter) { | |||||
| caffe_in_layer = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!caffe_in_unrepeated_message) { | |||||
| continue; | |||||
| } | |||||
| if (caffe_in_layer && line_caffe.find(kCloseBrace) != std::string::npos) { | |||||
| if (AddCustomAndConflictLayer(custom_proto_file, write_tmp) != SUCCESS) { | |||||
| GELOGE(FAILED, "Add conflict and new layer line from custom proto to dest proto failed."); | |||||
| return FAILED; | |||||
| } | |||||
| caffe_in_layer = false; | |||||
| } | |||||
| // exclude conflict lines | |||||
| if (caffe_in_layer && caffe_conflict_line_map_.count(line_caffe) > 0) { | |||||
| GELOGD("pass line: %s", line_caffe.c_str()); | |||||
| continue; | |||||
| } | |||||
| write_tmp << line_caffe << '\n'; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::WriteProtoFile(const char *caffe_proto_file, | |||||
| const char *custom_proto_file) { | |||||
| std::ifstream read_caffe; | |||||
| std::ofstream write_tmp; | |||||
| read_caffe.open(caffe_proto_file, std::ios::in); | |||||
| if (read_caffe.fail()) { | |||||
| GELOGE(FAILED, "ifsream open proto file[%s] failed.", caffe_proto_file); | |||||
| return FAILED; | |||||
| } | |||||
| write_tmp.open(fusion_proto_path, std::ios::out); | |||||
| if (write_tmp.fail()) { | |||||
| GELOGE(FAILED, "ofstream open proto file[%s] failed.", fusion_proto_path.c_str()); | |||||
| read_caffe.close(); | |||||
| return FAILED; | |||||
| } | |||||
| if (WriteCaffeProtoFile(custom_proto_file, read_caffe, write_tmp) != SUCCESS) { | |||||
| read_caffe.close(); | |||||
| write_tmp.close(); | |||||
| return FAILED; | |||||
| } | |||||
| if (AddCustomAndConflictMessage(custom_proto_file, write_tmp) != SUCCESS) { | |||||
| GELOGE(FAILED, "Add conflict and new message from custom proto to dest proto failed."); | |||||
| read_caffe.close(); | |||||
| write_tmp.close(); | |||||
| return FAILED; | |||||
| } | |||||
| read_caffe.close(); | |||||
| write_tmp.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::FindConflictLine(const char *proto_file, int identifier, | |||||
| std::string &dest_line) { | |||||
| ifstream read_file; | |||||
| read_file.open(proto_file, std::ios::in); | |||||
| if (read_file.fail()) { | |||||
| GELOGE(FAILED, "open file[%s] failed.", proto_file); | |||||
| return FAILED; | |||||
| } | |||||
| std::string line; | |||||
| bool save_flag = false; | |||||
| while (std::getline(read_file, line)) { | |||||
| if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) { | |||||
| save_flag = true; | |||||
| continue; | |||||
| } | |||||
| if (save_flag && line.find(kCloseBrace) != std::string::npos) { | |||||
| save_flag = false; | |||||
| break; | |||||
| } | |||||
| int tmp_identifier = 0; | |||||
| if (save_flag && GetIdentifier(line, tmp_identifier) && tmp_identifier == identifier) { | |||||
| dest_line.assign(line); | |||||
| read_file.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| read_file.close(); | |||||
| GELOGE(FAILED, "find line according to identifier[%d] failed.", identifier); | |||||
| return FAILED; | |||||
| } | |||||
| void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | |||||
| std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map, | |||||
| std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) { | |||||
| for (auto iter = custom_op_identifier_map.begin(); iter != custom_op_identifier_map.end(); ++iter) { | |||||
| if (caffe_op_identifier_map.count(iter->first) > 0) { | |||||
| string message_name = iter->first; | |||||
| auto caffe_pair = caffe_op_identifier_map[iter->first]; | |||||
| auto custom_pair = custom_op_identifier_map[iter->first]; | |||||
| if (caffe_pair.first != custom_pair.first || caffe_pair.second != custom_pair.second) { | |||||
| // consider conflict op and name and type; | |||||
| GELOGD("Find conflict op: caffe_identifier[%d], custom_identifier[%d], op_name[%s].", | |||||
| caffe_pair.first, custom_pair.first, message_name.c_str()); | |||||
| std::string caffe_conflict_line; | |||||
| (void)FindConflictLine(caffe_proto_file, caffe_pair.first, caffe_conflict_line); | |||||
| GELOGD("conflict: %s", caffe_conflict_line.c_str()); | |||||
| caffe_conflict_line_map_[caffe_conflict_line]++; | |||||
| } else { | |||||
| // consider repeat op and name and type; could be removed | |||||
| std::string custom_repeat_line; | |||||
| (void)FindConflictLine(custom_proto_file, caffe_pair.first, custom_repeat_line); | |||||
| custom_repeat_line_map_[custom_repeat_line]++; | |||||
| GELOGD("repeat: %s", custom_repeat_line.c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, | |||||
| std::map<int, std::pair<string, string>> caffe_identifier_op_map, | |||||
| std::map<int, std::pair<string, string>> custom_identifier_op_map) { | |||||
| for (auto iter = custom_identifier_op_map.begin(); iter != custom_identifier_op_map.end(); ++iter) { | |||||
| if (caffe_identifier_op_map.count(iter->first) > 0) { | |||||
| int identifier = iter->first; | |||||
| auto caffe_pair = caffe_identifier_op_map[iter->first]; | |||||
| auto custom_pair = custom_identifier_op_map[iter->first]; | |||||
| if (caffe_pair.first != custom_pair.first || caffe_pair.second != custom_pair.second) { | |||||
| // consider conflict op and name and type; | |||||
| GELOGD("Find conflict op: caffe_op[%s], custom_op[%s], identifier[%d].", | |||||
| caffe_pair.first.c_str(), custom_pair.first.c_str(), | |||||
| identifier); | |||||
| std::string caffe_conflict_line; | |||||
| (void)FindConflictLine(caffe_proto_file, identifier, caffe_conflict_line); | |||||
| GELOGD("conflict: %s", caffe_conflict_line.c_str()); | |||||
| caffe_conflict_line_map_[caffe_conflict_line]++; | |||||
| } else { | |||||
| // consider repeat op and name and type; | |||||
| std::string custom_repeat_line; | |||||
| (void)FindConflictLine(custom_proto_file, identifier, custom_repeat_line); | |||||
| custom_repeat_line_map_[custom_repeat_line]++; | |||||
| GELOGD("repeat: %s", custom_repeat_line.c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| Status ProtoFileParser::RecordProtoMessage(const string &proto_file) { | |||||
| ifstream read_file; | |||||
| read_file.open(proto_file, std::ios::in); | |||||
| if (read_file.fail()) { | |||||
| GELOGE(FAILED, "ifsream open proto file[%s] failed.", proto_file.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::string line; | |||||
| while (std::getline(read_file, line)) { | |||||
| if (line.find(kMessage) != std::string::npos) { | |||||
| std::string message_name = GetMessageName(line); | |||||
| if (message_name != kLayerParameter && message_name != kNetParameter) { | |||||
| custom_repeat_message_map_[message_name]++; | |||||
| } | |||||
| } | |||||
| } | |||||
| read_file.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ProtoFileParser::CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file, | |||||
| std::string &dest_proto_file) { | |||||
| GE_CHECK_NOTNULL(caffe_proto_file); | |||||
| GE_CHECK_NOTNULL(custom_proto_file); | |||||
| if (!CheckRealPath(caffe_proto_file) || !CheckRealPath(custom_proto_file)) { | |||||
| GELOGE(FAILED, "caffe proto[%s] and custom proto[%s] are not all existed.", | |||||
| caffe_proto_file, custom_proto_file); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("Start fusion custom and caffe proto to file."); | |||||
| std::map<int, std::pair<string, string>> caffe_identifier_op_map; | |||||
| std::map<int, std::pair<string, string>> custom_identifier_op_map; | |||||
| std::map<std::string, std::pair<int, string>> caffe_op_identifier_map; | |||||
| std::map<std::string, std::pair<int, string>> custom_op_identifier_map; | |||||
| (void)ParseProtoFile(caffe_proto_file, caffe_identifier_op_map, caffe_op_identifier_map); | |||||
| (void)ParseProtoFile(custom_proto_file, custom_identifier_op_map, custom_op_identifier_map); | |||||
| (void)RecordProtoMessage(custom_proto_file); | |||||
| // check identifier or op_type is same | |||||
| CheckConflictIdentifier(caffe_proto_file, custom_proto_file, | |||||
| caffe_identifier_op_map, custom_identifier_op_map); | |||||
| CheckConflictOp(caffe_proto_file, custom_proto_file, | |||||
| caffe_op_identifier_map, custom_op_identifier_map); | |||||
| if (CreatProtoFile() != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| if (WriteProtoFile(caffe_proto_file, custom_proto_file) != SUCCESS) { | |||||
| GELOGE(FAILED, "Combine caffe proto and custom proto to dest proto file failed."); | |||||
| return FAILED; | |||||
| } | |||||
| dest_proto_file.assign(fusion_proto_path); | |||||
| GELOGI("Fusion custom and caffe proto to file[%s] success.", dest_proto_file.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,63 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PROTO_FILE_PARSE_UTIL_ | |||||
| #define PROTO_FILE_PARSE_UTIL_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "common/types.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| namespace ge { | |||||
| class ProtoFileParser { | |||||
| public: | |||||
| ProtoFileParser(){}; | |||||
| ProtoFileParser(const char *dest_path){ | |||||
| fusion_proto_path = dest_path; | |||||
| } | |||||
| ~ProtoFileParser(); | |||||
| Status CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file, | |||||
| std::string &dest_proto_file); | |||||
| std::string GetFusionProtoFile(); | |||||
| private: | |||||
| Status CreatProtoFile(); | |||||
| Status ParseProtoFile(const std::string &proto_file, | |||||
| std::map<int, std::pair<std::string, std::string> > &identifier_op_map, | |||||
| std::map<std::string, std::pair<int, std::string> > &op_identifier_map); | |||||
| Status WriteCaffeProtoFile(const char *custom_proto_file, | |||||
| std::ifstream &read_caffe, | |||||
| std::ofstream &write_tmp); | |||||
| Status WriteProtoFile(const char *caffe_proto_file, const char *custom_proto_file); | |||||
| Status FindConflictLine(const char *proto_file, int identifier, | |||||
| std::string &dest_line); | |||||
| Status AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp); | |||||
| Status AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp); | |||||
| void CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | |||||
| std::map<std::string, std::pair<int, std::string>> &caffe_op_identifier_map, | |||||
| std::map<std::string, std::pair<int, std::string>> &custom_op_identifier_map); | |||||
| void CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, | |||||
| std::map<int, std::pair<std::string, std::string>> caffe_identifier_op_map, | |||||
| std::map<int, std::pair<std::string, std::string>> custom_identifier_op_map); | |||||
| Status RecordProtoMessage(const std::string &proto_file); | |||||
| std::map<std::string, int> caffe_conflict_line_map_; | |||||
| std::map<std::string, int> custom_repeat_line_map_; | |||||
| std::map<std::string, int> custom_repeat_message_map_; | |||||
| std::string fusion_proto_path; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PROTO_FILE_PARSE_UTIL_ | |||||
| @@ -0,0 +1,132 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/op_map.h" | |||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "parser/tensorflow/tensorflow_custom_parser_adapter.h" | |||||
| #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" | |||||
| namespace ge { | |||||
| using PARSER_CREATOR_FN = std::function<std::shared_ptr<OpParser>(void)>; | |||||
| FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() { | |||||
| static OpRegistrationTbe instance; | |||||
| return &instance; | |||||
| } | |||||
| bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { | |||||
| static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{CAFFE, &caffe_op_map}}; | |||||
| if (is_train) { | |||||
| op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; | |||||
| } else { | |||||
| op_map[domi::TENSORFLOW] = &tensorflow_op_map; | |||||
| } | |||||
| if (op_map.find(reg_data.GetFrameworkType()) != op_map.end()) { | |||||
| std::map<std::string, std::string> *fmk_op_map = op_map[reg_data.GetFrameworkType()]; | |||||
| auto ori_optype_set = reg_data.GetOriginOpTypeSet(); | |||||
| for (auto &tmp : ori_optype_set) { | |||||
| if ((*fmk_op_map).find(tmp) != (*fmk_op_map).end()) { | |||||
| GELOGW("Op type does not need to be changed, om_optype:%s, orignal type:%s.", (*fmk_op_map)[tmp].c_str(), | |||||
| tmp.c_str()); | |||||
| continue; | |||||
| } else { | |||||
| (*fmk_op_map)[tmp] = reg_data.GetOmOptype(); | |||||
| GELOGD("First register in parser initialize, original type: %s, om_optype: %s, imply type: %s.", tmp.c_str(), | |||||
| reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| bool ret = RegisterParser(reg_data); | |||||
| return ret; | |||||
| } | |||||
| bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
| if (reg_data.GetFrameworkType() == domi::TENSORFLOW) { | |||||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | |||||
| if (factory == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Get op parser factory for tf failed."); | |||||
| return false; | |||||
| } | |||||
| if (reg_data.GetParseParamFn() != nullptr || reg_data.GetParseParamByOperatorFn() != nullptr) { | |||||
| bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype()); | |||||
| if (is_registed) { | |||||
| GELOGW("Parse param func has already register for op:%s.", reg_data.GetOmOptype().c_str()); | |||||
| return false; | |||||
| } | |||||
| std::shared_ptr<TensorFlowCustomParserAdapter> tf_parser_adapter = | |||||
| ge::MakeShared<TensorFlowCustomParserAdapter>(); | |||||
| if (tf_parser_adapter == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Create tf parser adapter failed."); | |||||
| return false; | |||||
| } | |||||
| OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | |||||
| domi::TENSORFLOW, reg_data.GetOmOptype(), [=]() -> std::shared_ptr<OpParser> { return tf_parser_adapter; }); | |||||
| } | |||||
| if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) { | |||||
| bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype(), true); | |||||
| if (is_registed) { | |||||
| GELOGW("Parse param func has already register for fusion op:%s.", reg_data.GetOmOptype().c_str()); | |||||
| return false; | |||||
| } | |||||
| GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str()); | |||||
| std::shared_ptr<TensorFlowFusionCustomParserAdapter> tf_fusion_parser_adapter = | |||||
| ge::MakeShared<TensorFlowFusionCustomParserAdapter>(); | |||||
| if (tf_fusion_parser_adapter == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed."); | |||||
| return false; | |||||
| } | |||||
| OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | |||||
| domi::TENSORFLOW, reg_data.GetOmOptype(), | |||||
| [=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true); | |||||
| } | |||||
| } else { | |||||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); | |||||
| if (factory == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Get op parser factory for %s failed.", | |||||
| TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | |||||
| return false; | |||||
| } | |||||
| bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype()); | |||||
| if (is_registed) { | |||||
| GELOGW("Parse param func has already register for op:%s.", reg_data.GetOmOptype().c_str()); | |||||
| return false; | |||||
| } | |||||
| PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType()); | |||||
| if (func == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Get custom parser adapter failed for fmk type %s.", | |||||
| TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | |||||
| return false; | |||||
| } | |||||
| OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func); | |||||
| GELOGD("Register custom parser adapter for op %s of fmk type %s success.", reg_data.GetOmOptype().c_str(), | |||||
| TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_REGISTER_TBE_H_ | |||||
| #define PARSER_COMMON_REGISTER_TBE_H_ | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| class OpRegistrationTbe { | |||||
| public: | |||||
| static OpRegistrationTbe *Instance(); | |||||
| bool Finalize(const OpRegistrationData ®_data, bool is_train = false); | |||||
| private: | |||||
| bool RegisterParser(const OpRegistrationData ®_data); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_REGISTER_TBE_H_ | |||||
| @@ -0,0 +1,212 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tbe_plugin_loader.h" | |||||
| #include <dirent.h> | |||||
| #include <sys/stat.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | |||||
| #include <cstring> | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/string_util.h" | |||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| namespace ge { | |||||
| std::map<string, string> TBEPluginLoader::options_ = {}; | |||||
| namespace { | |||||
| const std::string FRAMEWORK_TYPE = "ge.frameworkType"; | |||||
| } | |||||
| // Get Singleton Instance | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginLoader &TBEPluginLoader::Instance() { | |||||
| static TBEPluginLoader instance_ptr_; | |||||
| return instance_ptr_; | |||||
| } | |||||
| Status TBEPluginLoader::ClearHandles_() { | |||||
| Status ret = SUCCESS; | |||||
| for (const auto &handle : handles_vec_) { | |||||
| if (dlclose(handle) != 0) { | |||||
| ret = FAILED; | |||||
| GELOGW("Failed to close handle: %s", dlerror()); | |||||
| } | |||||
| } | |||||
| handles_vec_.clear(); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginLoader::Finalize() { | |||||
| Status ret = ClearHandles_(); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPluginSo( | |||||
| const std::map<string, string> &options) { | |||||
| vector<string> file_list; | |||||
| string caffe_parser_path; | |||||
| std::string plugin_path; | |||||
| options_ = options; | |||||
| GetCustomOpPath(plugin_path); | |||||
| // Whether there are files in the plugin so path | |||||
| GetPluginSoFileList(plugin_path, file_list, caffe_parser_path); | |||||
| // No file | |||||
| if (file_list.empty()) { | |||||
| // Print log | |||||
| GELOGW("Can not find any plugin file in plugin_path: %s", plugin_path.c_str()); | |||||
| } | |||||
| GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); | |||||
| // Load other so files except lib_caffe_parser.so in the plugin so path | |||||
| for (auto elem : file_list) { | |||||
| StringUtils::Trim(elem); | |||||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); | |||||
| if (handle == nullptr) { | |||||
| GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | |||||
| } else if (find(handles_vec_.begin(), handles_vec_.end(), handle) == handles_vec_.end()) { | |||||
| // Close dl when the program exist, not close here | |||||
| GELOGI("Plugin load %s success.", elem.c_str()); | |||||
| handles_vec_.push_back(handle); | |||||
| } else { | |||||
| GELOGI("Plugin so has already been loaded, no need to load again."); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { | |||||
| GELOGI("Enter get custom op path schedule"); | |||||
| std::string fmk_type; | |||||
| domi::FrameworkType type = domi::TENSORFLOW; | |||||
| auto it = options_.find(FRAMEWORK_TYPE); | |||||
| if (it != options_.end()) { | |||||
| type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); | |||||
| } | |||||
| fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); | |||||
| GELOGI("Framework type is %s.", fmk_type.c_str()); | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| std::string path = path_env; | |||||
| customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); | |||||
| GELOGI("Get custom so path from env : %s", path_env); | |||||
| return; | |||||
| } | |||||
| std::string path_base = GetPath(); | |||||
| GELOGI("path_base is %s", path_base.c_str()); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); | |||||
| } | |||||
| string TBEPluginLoader::GetPath() { | |||||
| Dl_info dl_info; | |||||
| if (dladdr(reinterpret_cast<void *>(&TBEPluginLoader::GetPath), &dl_info) == 0) { | |||||
| GELOGW("Failed to read so path!"); | |||||
| return string(); | |||||
| } else { | |||||
| string so_path = dl_info.dli_fname; | |||||
| char path[PATH_MAX] = {0}; | |||||
| if (so_path.length() >= PATH_MAX) { | |||||
| GELOGW("File path is too long!"); | |||||
| return string(); | |||||
| } | |||||
| if (realpath(so_path.c_str(), path) == nullptr) { | |||||
| GELOGW("Failed to get realpath of %s", so_path.c_str()); | |||||
| return string(); | |||||
| } | |||||
| so_path = path; | |||||
| so_path = so_path.substr(0, so_path.rfind('/') + 1); | |||||
| return so_path; | |||||
| } | |||||
| } | |||||
| void TBEPluginLoader::GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path) { | |||||
| // Support to split multiple so directories by ":" | |||||
| vector<string> v_path = StringUtils::Split(path, ':'); | |||||
| for (size_t i = 0; i < v_path.size(); ++i) { | |||||
| FindParserSo(v_path[i], file_list, caffe_parser_path); | |||||
| GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); | |||||
| } | |||||
| } | |||||
| void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) { | |||||
| // Path, change to absolute path | |||||
| string real_path = ge::parser::RealPath(path.c_str()); | |||||
| // Plugin path does not exist | |||||
| if (real_path.empty()) { | |||||
| GELOGW("RealPath is empty."); | |||||
| return; | |||||
| } | |||||
| struct stat stat_buf; | |||||
| if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) { | |||||
| GELOGW("%s is not a dir.", real_path.c_str()); | |||||
| return; | |||||
| } | |||||
| struct dirent *dent(0); | |||||
| DIR *dir = opendir(real_path.c_str()); | |||||
| // Plugin path does not exist | |||||
| if (dir == nullptr) { | |||||
| GELOGW("Open directory %s failed.", real_path.c_str()); | |||||
| return; | |||||
| } | |||||
| while ((dent = readdir(dir)) != nullptr) { | |||||
| if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue; | |||||
| string name = dent->d_name; | |||||
| string full_name = real_path + "/" + name; | |||||
| const string so_suff = ".so"; | |||||
| const string caffe_parser_so_suff = "lib_caffe_parser.so"; | |||||
| const string aicpu_so_suff = "_aicpu.so"; | |||||
| const string aicpu_host_so_suff = "_online.so"; | |||||
| if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { | |||||
| ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, | |||||
| aicpu_host_so_suff); | |||||
| } else { | |||||
| FindParserSo(full_name, file_list, caffe_parser_path); | |||||
| } | |||||
| } | |||||
| closedir(dir); | |||||
| } | |||||
| void TBEPluginLoader::ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | |||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | |||||
| const string &aicpu_host_so_suff) { | |||||
| if (full_name.size() >= caffe_parser_so_suff.size() && | |||||
| full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), | |||||
| caffe_parser_so_suff) == 0) { | |||||
| caffe_parser_path = full_name; | |||||
| } else { | |||||
| // Save parser so path into file_list vector | |||||
| file_list.push_back(full_name); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | |||||
| #define PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | |||||
| #include <dlfcn.h> | |||||
| #include <functional> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <type_traits> | |||||
| #include <typeinfo> | |||||
| #include <vector> | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| #include "external/register/register.h" | |||||
| namespace ge { | |||||
| using SoHandlesVec = std::vector<void *>; | |||||
| class TBEPluginLoader { | |||||
| public: | |||||
| Status Finalize(); | |||||
| // Get TBEPluginManager singleton instance | |||||
| static TBEPluginLoader& Instance(); | |||||
| void LoadPluginSo(const std::map<string, string> &options); | |||||
| static string GetPath(); | |||||
| private: | |||||
| TBEPluginLoader() = default; | |||||
| ~TBEPluginLoader() = default; | |||||
| Status ClearHandles_(); | |||||
| static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | |||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | |||||
| const string &aicpu_host_so_suff); | |||||
| static void GetCustomOpPath(std::string &customop_path); | |||||
| static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | |||||
| static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path); | |||||
| SoHandlesVec handles_vec_; | |||||
| static std::map<string, string> options_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif //PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | |||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/thread_pool.h" | |||||
| #include <atomic> | |||||
| #include <functional> | |||||
| #include <queue> | |||||
| #include <stdexcept> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "register/register_types.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { | |||||
| idle_thrd_num_ = size < 1 ? 1 : size; | |||||
| for (uint32_t i = 0; i < idle_thrd_num_; ++i) { | |||||
| pool_.emplace_back(ThreadFunc, this); | |||||
| } | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { | |||||
| is_stoped_.store(true); | |||||
| { | |||||
| std::unique_lock<std::mutex> lock{m_lock_}; | |||||
| cond_var_.notify_all(); | |||||
| } | |||||
| for (std::thread &thd : pool_) { | |||||
| if (thd.joinable()) { | |||||
| try { | |||||
| thd.join(); | |||||
| } catch (const std::system_error &) { | |||||
| GELOGW("system_error"); | |||||
| } catch (...) { | |||||
| GELOGW("exception"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { | |||||
| if (thread_pool == nullptr) { | |||||
| return; | |||||
| } | |||||
| while (!thread_pool->is_stoped_) { | |||||
| std::function<void()> task; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock{thread_pool->m_lock_}; | |||||
| thread_pool->cond_var_.wait( | |||||
| lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); | |||||
| if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { | |||||
| return; | |||||
| } | |||||
| task = std::move(thread_pool->tasks_.front()); | |||||
| thread_pool->tasks_.pop(); | |||||
| } | |||||
| --thread_pool->idle_thrd_num_; | |||||
| task(); | |||||
| ++thread_pool->idle_thrd_num_; | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,83 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_THREAD_POOL_H_ | |||||
| #define PARSER_COMMON_THREAD_POOL_H_ | |||||
| #include <atomic> | |||||
| #include <condition_variable> | |||||
| #include <functional> | |||||
| #include <future> | |||||
| #include <memory> | |||||
| #include <queue> | |||||
| #include <stdexcept> | |||||
| #include <thread> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| namespace ge { | |||||
| using ThreadTask = std::function<void()>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { | |||||
| public: | |||||
| explicit ThreadPool(uint32_t size = 4); | |||||
| ~ThreadPool(); | |||||
| template <class Func, class... Args> | |||||
| auto commit(Func &&func, Args &&... args) -> std::future<decltype(func(args...))> { | |||||
| GELOGD("commit run task enter."); | |||||
| using retType = decltype(func(args...)); | |||||
| std::future<retType> fail_future; | |||||
| if (is_stoped_.load()) { | |||||
| GELOGE(ge::FAILED, "thread pool has been stopped."); | |||||
| return fail_future; | |||||
| } | |||||
| auto bindFunc = std::bind(std::forward<Func>(func), std::forward<Args>(args)...); | |||||
| auto task = ge::MakeShared<std::packaged_task<retType()>>(bindFunc); | |||||
| if (task == nullptr) { | |||||
| GELOGE(ge::FAILED, "Make shared failed."); | |||||
| return fail_future; | |||||
| } | |||||
| std::future<retType> future = task->get_future(); | |||||
| { | |||||
| std::lock_guard<std::mutex> lock{m_lock_}; | |||||
| tasks_.emplace([task]() { (*task)(); }); | |||||
| } | |||||
| cond_var_.notify_one(); | |||||
| GELOGD("commit run task end"); | |||||
| return future; | |||||
| } | |||||
| static void ThreadFunc(ThreadPool *thread_pool); | |||||
| private: | |||||
| std::vector<std::thread> pool_; | |||||
| std::queue<ThreadTask> tasks_; | |||||
| std::mutex m_lock_; | |||||
| std::condition_variable cond_var_; | |||||
| std::atomic<bool> is_stoped_; | |||||
| std::atomic<uint32_t> idle_thrd_num_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_THREAD_POOL_H_ | |||||
| @@ -0,0 +1,307 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_TUPLE_H_ | |||||
| #define GE_COMMON_TUPLE_H_ | |||||
| #include <algorithm> | |||||
| #include <iostream> | |||||
| #include <string> | |||||
| #include <type_traits> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "framework/common/debug/log.h" | |||||
| namespace ge { | |||||
| template <typename ValueType> | |||||
| class Tuple { | |||||
| public: | |||||
| Tuple() = default; | |||||
| inline ~Tuple() { | |||||
| delete[] data_heap_; | |||||
| data_heap_ = nullptr; | |||||
| } | |||||
| /// | |||||
| /// @brief copy constructor from another tuple | |||||
| /// @param s the source tuple | |||||
| /// | |||||
| inline Tuple(const Tuple<ValueType> &s) { this->assign(s.begin(), s.end()); } | |||||
| /// | |||||
| /// @brief constructor from initializer list | |||||
| /// @param init the initializer_list | |||||
| /// | |||||
| inline Tuple(const std::initializer_list<ValueType> &init) { this->assign(init.begin(), init.end()); } | |||||
| /// | |||||
| /// @brief constructor from vector | |||||
| /// @param init the vector | |||||
| /// | |||||
| inline Tuple(const std::vector<ValueType> &init) { // NOLINT(runtime/explicit) | |||||
| this->assign(init.begin(), init.end()); | |||||
| } | |||||
| /// | |||||
| /// @brief move constructor from Tuple | |||||
| /// @param src the source shape | |||||
| /// | |||||
| inline Tuple(Tuple<ValueType> &&src) { // NOLINT(runtime/explicit) | |||||
| this->swap(src); | |||||
| } | |||||
| /// | |||||
| /// @brief construct the Tuple from content of iterator | |||||
| /// @param begin the beginning of iterator | |||||
| /// @param end end the end of the iterator | |||||
| /// @tparam RandomAccessIterator iterator type | |||||
| /// | |||||
| template <typename RandomAccessIterator> | |||||
| inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) { | |||||
| this->assign(begin, end); | |||||
| } | |||||
| /// | |||||
| /// @brief Assign content to tuple from iterator. | |||||
| /// @param begin the beginning of iterator | |||||
| /// @param end end the end of the iterator | |||||
| /// @tparam RandomAccessIterator iterator type | |||||
| /// | |||||
| template <typename RandomAccessIterator> | |||||
| inline void assign(const RandomAccessIterator &begin, const RandomAccessIterator &end) { | |||||
| this->SetDim(end - begin); | |||||
| (void)std::copy(begin, end, this->begin()); | |||||
| } | |||||
| /// | |||||
| /// @brief Swap current object with other | |||||
| /// @param other another object to be swapped. | |||||
| /// | |||||
| inline void swap(Tuple<ValueType> &other) { // NOLINT(*) | |||||
| std::swap(ndim_, other.ndim_); | |||||
| std::swap(num_heap_allocated_, other.num_heap_allocated_); | |||||
| std::swap(data_stack_, other.data_stack_); | |||||
| std::swap(data_heap_, other.data_heap_); | |||||
| } | |||||
| /// | |||||
| /// @brief assignment from another tuple. | |||||
| /// @param src source tuple | |||||
| /// @return reference of self | |||||
| /// | |||||
| inline Tuple<ValueType> &operator=(const Tuple<ValueType> &src) { | |||||
| if (&src != this) { | |||||
| this->assign(src.begin(), src.end()); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief assignment from rvalue of another tuple. | |||||
| /// @param src source tuple | |||||
| /// @return reference of self | |||||
| /// | |||||
| inline Tuple<ValueType> &operator=(Tuple<ValueType> &&src) { | |||||
| if (&src != this) { | |||||
| Tuple<ValueType>(std::move(src)).swap(*this); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief assignment from initializer list | |||||
| /// @param init the source initializer list | |||||
| /// @return reference of self | |||||
| /// | |||||
| inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) { | |||||
| this->assign(init.begin(), init.end()); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @return whether two tuple equals | |||||
| /// @param s the tuple to compare against | |||||
| /// | |||||
| inline bool operator==(const Tuple<ValueType> &s) const { | |||||
| if (ndim_ != s.ndim_) return false; | |||||
| return std::equal(begin(), end(), s.begin()); | |||||
| } | |||||
| /// | |||||
| /// @return whether two tuple not equal | |||||
| /// @param s the tuple to compare against | |||||
| /// | |||||
| inline bool operator!=(const Tuple<ValueType> &s) const { return !(*this == s); } | |||||
| /// | |||||
| /// @return the begin data pointer to content of the tuple | |||||
| /// | |||||
| inline const ValueType *begin() const { return ndim_ <= STACK_CACHE_NUM ? data_stack_ : data_heap_; } | |||||
| /// | |||||
| /// @return the begin data pointer to content of the tuple | |||||
| /// | |||||
| inline ValueType *begin() { return ndim_ <= STACK_CACHE_NUM ? data_stack_ : data_heap_; } | |||||
| /// | |||||
| /// @return the data pointer to end of the tuple | |||||
| /// | |||||
| inline const ValueType *end() const { | |||||
| return ndim_ <= STACK_CACHE_NUM ? (data_stack_ + ndim_) : (data_heap_ + ndim_); | |||||
| } | |||||
| /// | |||||
| /// @return the data pointer to end the tuple | |||||
| /// | |||||
| inline ValueType *end() { return ndim_ <= STACK_CACHE_NUM ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } | |||||
| /// | |||||
| /// @return number of dimension of the tuple | |||||
| /// | |||||
| inline uint32_t ndim() const { return ndim_; } | |||||
| /// | |||||
| /// @brief get corresponding index | |||||
| /// @param i dimension index | |||||
| /// @return the corresponding dimension size | |||||
| /// | |||||
| inline ValueType &operator[](size_t i) { return begin()[i]; } | |||||
| /// | |||||
| /// @brief get corresponding index | |||||
| /// @param i dimension index | |||||
| /// @return the corresponding dimension size | |||||
| /// | |||||
| inline const ValueType &operator[](size_t i) const { return begin()[i]; } | |||||
| /// | |||||
| /// @brief allow output string of tuple to ostream | |||||
| /// @param os the output stream | |||||
| /// @param t the tuple | |||||
| /// @return the ostream | |||||
| /// | |||||
| friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) { | |||||
| os << '['; | |||||
| const ValueType *begin = t.begin(); | |||||
| const ValueType *end = t.end(); | |||||
| for (const ValueType *it = begin; it != end; ++it) { | |||||
| if (it != begin) os << ','; | |||||
| os << *it; | |||||
| } | |||||
| os << ']'; | |||||
| return os; | |||||
| } | |||||
| /// | |||||
| /// @brief read tuple from the istream | |||||
| /// @param is the input stream | |||||
| /// @param t The tuple | |||||
| /// @return the istream | |||||
| /// | |||||
| friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) { | |||||
| // get ( | |||||
| if (!HandleLeftBracket(is, t)) { | |||||
| return is; | |||||
| } | |||||
| // Handle empty tuple | |||||
| while (isspace(is.peek())) { | |||||
| (void)is.get(); | |||||
| } | |||||
| if (IsRightBracket(is.peek())) { | |||||
| (void)is.get(); | |||||
| return is; | |||||
| } | |||||
| // Handle non-empty tuple | |||||
| ValueType idx; | |||||
| std::vector<ValueType> tmp; | |||||
| while (is >> idx) { | |||||
| tmp.push_back(idx); | |||||
| char ch; | |||||
| do { | |||||
| ch = static_cast<char>(is.get()); | |||||
| } while (isspace(ch)); | |||||
| if (std::is_integral<ValueType>::value && ch == 'L') { | |||||
| ch = static_cast<char>(is.get()); | |||||
| } | |||||
| if (ch == ',') { | |||||
| while (true) { | |||||
| ch = static_cast<char>(is.peek()); | |||||
| if (isspace(ch)) { | |||||
| (void)is.get(); | |||||
| continue; | |||||
| } | |||||
| if (IsRightBracket(ch)) { | |||||
| (void)is.get(); | |||||
| break; | |||||
| } | |||||
| break; | |||||
| } | |||||
| if (IsRightBracket(ch)) break; | |||||
| } else if (IsRightBracket(ch)) { | |||||
| break; | |||||
| } else { | |||||
| is.setstate(std::ios::failbit); | |||||
| return is; | |||||
| } | |||||
| } | |||||
| t.assign(tmp.begin(), tmp.end()); | |||||
| return is; | |||||
| } | |||||
| // stack cache size | |||||
| static const uint32_t STACK_CACHE_NUM = 4; | |||||
| // in stack space used to store shape when it is small | |||||
| ValueType data_stack_[STACK_CACHE_NUM]; | |||||
| // space to store shape when dimension is big | |||||
| ValueType *data_heap_{nullptr}; | |||||
| uint32_t ndim_{0}; | |||||
| protected: | |||||
| // number of cells allocated in data_heap_ | |||||
| uint32_t num_heap_allocated_{0}; | |||||
| // internal function to change the dimension | |||||
| inline void SetDim(uint32_t ndim) { | |||||
| if (ndim > STACK_CACHE_NUM && ndim > num_heap_allocated_) { | |||||
| if (data_heap_ != nullptr) { | |||||
| delete[] data_heap_; | |||||
| data_heap_ = nullptr; | |||||
| } | |||||
| data_heap_ = new (std::nothrow) ValueType[ndim](); | |||||
| if (data_heap_ == nullptr) { | |||||
| GELOGW("data_heap_ is nullptr."); | |||||
| } | |||||
| num_heap_allocated_ = ndim; | |||||
| } | |||||
| ndim_ = ndim; | |||||
| } | |||||
| static inline bool IsLeftBracket(char ch) { return ch == '(' || ch == '['; } | |||||
| static inline bool IsRightBracket(char ch) { return ch == ')' || ch == ']'; } | |||||
| friend bool HandleLeftBracket(std::istream &is, Tuple<ValueType> &t) { | |||||
| while (true) { | |||||
| char ch = is.peek(); | |||||
| if (isdigit(ch) || (ch == '-')) { | |||||
| ValueType idx; | |||||
| if (is >> idx) { | |||||
| t.assign(&idx, &idx + 1); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| (void)is.get(); | |||||
| if (IsLeftBracket(ch)) { | |||||
| break; | |||||
| } | |||||
| if (!isspace(ch)) { | |||||
| is.setstate(std::ios::failbit); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| }; | |||||
| using UintTuple = Tuple<uint32_t>; | |||||
| using IntTuple = Tuple<int64_t>; | |||||
| using FloatTuple = Tuple<float>; | |||||
| using BoolTuple = Tuple<bool>; | |||||
| using StringTuple = Tuple<std::string>; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_TUPLE_H_ | |||||
| @@ -0,0 +1,53 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_TYPES_MAP_H | |||||
| #define GE_TYPES_MAP_H | |||||
| #include "external/graph/types.h" | |||||
| #include "proto/tensorflow/graph.pb.h" | |||||
| namespace ge { | |||||
| // Correspondence between data_type in GE and tensorflow | |||||
| static map<int32_t, int32_t> GE_TENSORFLOW_DATA_TYPE_MAP = { | |||||
| {ge::DataType::DT_UNDEFINED, domi::tensorflow::DT_INVALID}, | |||||
| {ge::DataType::DT_FLOAT, domi::tensorflow::DT_FLOAT}, | |||||
| {ge::DataType::DT_FLOAT16, domi::tensorflow::DT_HALF}, | |||||
| {ge::DataType::DT_INT8, domi::tensorflow::DT_INT8}, | |||||
| {ge::DataType::DT_INT16, domi::tensorflow::DT_INT16}, | |||||
| {ge::DataType::DT_UINT16, domi::tensorflow::DT_UINT16}, | |||||
| {ge::DataType::DT_UINT8, domi::tensorflow::DT_UINT8}, | |||||
| {ge::DataType::DT_INT32, domi::tensorflow::DT_INT32}, | |||||
| {ge::DataType::DT_INT64, domi::tensorflow::DT_INT64}, | |||||
| {ge::DataType::DT_UINT32, domi::tensorflow::DT_UINT32}, | |||||
| {ge::DataType::DT_UINT64, domi::tensorflow::DT_UINT64}, | |||||
| {ge::DataType::DT_STRING, domi::tensorflow::DT_STRING}, | |||||
| {ge::DataType::DT_RESOURCE, domi::tensorflow::DT_RESOURCE}, | |||||
| {ge::DataType::DT_BOOL, domi::tensorflow::DT_BOOL}, | |||||
| {ge::DataType::DT_DOUBLE, domi::tensorflow::DT_DOUBLE}, | |||||
| {ge::DataType::DT_COMPLEX64, domi::tensorflow::DT_COMPLEX64}, | |||||
| {ge::DataType::DT_COMPLEX128, domi::tensorflow::DT_COMPLEX128}, | |||||
| {ge::DataType::DT_QINT8, domi::tensorflow::DT_QINT8}, | |||||
| {ge::DataType::DT_QINT16, domi::tensorflow::DT_QINT16}, | |||||
| {ge::DataType::DT_QINT32, domi::tensorflow::DT_QINT32}, | |||||
| {ge::DataType::DT_QUINT8, domi::tensorflow::DT_QUINT8}, | |||||
| {ge::DataType::DT_QUINT16, domi::tensorflow::DT_QUINT16}, | |||||
| {ge::DataType::DT_DUAL, domi::tensorflow::DT_INVALID}, | |||||
| {ge::DataType::DT_DUAL_SUB_INT8, domi::tensorflow::DT_INVALID}, | |||||
| {ge::DataType::DT_DUAL_SUB_UINT8, domi::tensorflow::DT_INVALID}, | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_TYPES_MAP_H | |||||
| @@ -0,0 +1,32 @@ | |||||
| set(PROTO_LIST | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/graph.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/node_def.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/tensor_shape.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/attr_value.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/function.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/op_def.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/resource_handle.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/tensor.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/types.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/versions.proto" | |||||
| "${TOP_DIR}/inc/register/proto/tensorflow/graph_library.proto" | |||||
| ) | |||||
| protobuf_generate_py(ge PROTO_SRCS ${PROTO_LIST}) | |||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||||
| ############ func2graph/util ############ | |||||
| add_custom_target(util ALL | |||||
| DEPENDS ${PROTO_SRCS} | |||||
| COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/util | |||||
| && cp -r ${PROTO_SRCS} ${CMAKE_CURRENT_BINARY_DIR}/util | |||||
| ) | |||||
| set(INSTALL_BASE_DIR "") | |||||
| set(INSTALL_LIBRARY_DIR lib) | |||||
| install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/util OPTIONAL | |||||
| DESTINATION ${INSTALL_LIBRARY_DIR}/func2graph | |||||
| ) | |||||
| @@ -0,0 +1,279 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # less required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| #!/usr/bin/env python | |||||
| # -*- coding:utf-8 -*- | |||||
| import os | |||||
| import sys | |||||
| import getopt | |||||
| from google.protobuf import text_format | |||||
| import tensorflow as tf | |||||
| from tensorflow.python.framework import function_def_to_graph | |||||
| from tensorflow.python.framework.errors_impl import NotFoundError | |||||
| from tensorflow.python.platform import gfile | |||||
| from tensorflow.core.framework import graph_pb2 | |||||
| from tensorflow.core.framework import tensor_shape_pb2 | |||||
| from tensorflow.core.framework import types_pb2 | |||||
| from tensorflow.core.framework import versions_pb2 | |||||
| from tensorflow.python.eager import context | |||||
| from tensorflow.python.framework import importer | |||||
| from tensorflow.python.framework import ops | |||||
| from tensorflow.python.framework import versions | |||||
| sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util")) | |||||
| import graph_library_pb2 | |||||
| def _get_num_args(arg_def, node_def): | |||||
| if arg_def.number_attr: | |||||
| return node_def.attr[arg_def.number_attr].i | |||||
| elif arg_def.type_list_attr: | |||||
| return len(node_def.attr[arg_def.type_list_attr].list.type) | |||||
| elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: | |||||
| return 1 | |||||
| else: | |||||
| raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) | |||||
| def is_function(fname): | |||||
| """Checks for a function definition with `fname` in the current context.""" | |||||
| if context.executing_eagerly(): | |||||
| return context.context().has_function(fname) | |||||
| else: | |||||
| return ops.get_default_graph()._is_function(fname) | |||||
| def create_arg_for_input_nodes(fdef, graph_def, input_shapes): | |||||
| for i, arg_def in enumerate(fdef.signature.input_arg): | |||||
| node_def = graph_def.node.add() | |||||
| node_def.name = arg_def.name | |||||
| node_def.op = "_Arg" | |||||
| node_def.attr["T"].type = arg_def.type | |||||
| node_def.attr["index"].i = i | |||||
| if input_shapes and input_shapes[i] is not None: | |||||
| input_shape = input_shapes[i] | |||||
| if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto): | |||||
| input_shape = input_shape.as_proto() | |||||
| node_def.attr["shape"].shape.CopyFrom(input_shape) | |||||
| arg_attrs = fdef.arg_attr[i].attr | |||||
| for k in arg_attrs: | |||||
| # Only copy internal attributes. Normal attributes for nodes cannot be | |||||
| # applied to these Arg nodes. | |||||
| if k.startswith("_"): | |||||
| node_def.attr[k].CopyFrom(arg_attrs[k]) | |||||
| return | |||||
| def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): | |||||
| for i, arg_def in enumerate(fdef.signature.output_arg): | |||||
| node_def = graph_def.node.add() | |||||
| node_def.name = '{}_Retval'.format(arg_def.name) | |||||
| node_def.op = "_Retval" | |||||
| node_def.attr["T"].type = arg_def.type | |||||
| node_def.attr["index"].i = i | |||||
| node_def.attr["op_def"].s = ops.get_default_graph()._get_op_def(node_def.op).SerializeToString() | |||||
| ret_name = fdef.ret[arg_def.name] | |||||
| node_def.input.append(nested_to_flat_tensor_name[ret_name]) | |||||
| return | |||||
| def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): | |||||
| flattened_index = 0 | |||||
| for arg_def in op_def.output_arg: | |||||
| num_args = _get_num_args(arg_def, node_def) | |||||
| for i in range(num_args): | |||||
| # Map tensor names from "node_name:output_arg_name:index" to | |||||
| # "node_name:flattened_index". | |||||
| nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i) | |||||
| if flattened_index == 0: | |||||
| flat_name = node_def.name | |||||
| else: | |||||
| flat_name = "{}:{}".format(node_def.name, flattened_index) | |||||
| nested_to_flat_tensor_name[nested_name] = flat_name | |||||
| flattened_index += 1 | |||||
| control_name = "^" + node_def.name | |||||
| nested_to_flat_tensor_name[control_name] = control_name | |||||
| return | |||||
| def build_tensor_name(fdef, default_graph): | |||||
| nested_to_flat_tensor_name = {} | |||||
| for arg_def in fdef.signature.input_arg: | |||||
| nested_to_flat_tensor_name[arg_def.name] = arg_def.name | |||||
| control_name = '^{}'.format(arg_def.name) | |||||
| nested_to_flat_tensor_name[control_name] = control_name | |||||
| global op_def | |||||
| for node_def in fdef.node_def: | |||||
| f = default_graph._functions.get(node_def.op, None) | |||||
| if f is not None and hasattr(f, "signature"): | |||||
| op_def = f.signature | |||||
| if node_def.op not in copied_functions: | |||||
| # Since this function is referenced as an op type, we have no choice but | |||||
| # to copy it into the GraphDef if we want downstream tools to process | |||||
| # it. | |||||
| graph_def.library.function.add().CopyFrom(f.definition) | |||||
| copied_functions.add(node_def.op) | |||||
| else: | |||||
| op_def = ops.get_default_graph()._get_op_def(node_def.op) | |||||
| for attr in op_def.attr: | |||||
| if attr.type == "func": | |||||
| fname = node_def.attr[attr.name].func.name | |||||
| if not is_function(fname): | |||||
| raise ValueError("%s function not found." % fname) | |||||
| elif attr.type == "list(func)": | |||||
| for fn in node_def.attr[attr.name].list.func: | |||||
| fname = fn.name | |||||
| if not is_function(fname): | |||||
| raise ValueError("%s function not found." % fname) | |||||
| # Iterate over output_args in op_def to build the map. | |||||
| # Index of the output tensor in the flattened list of *all* output | |||||
| # tensors of the op. | |||||
| updat_input_index(node_def, op_def, nested_to_flat_tensor_name) | |||||
| return nested_to_flat_tensor_name | |||||
| def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True): | |||||
| graph_def = graph_pb2.GraphDef() | |||||
| graph_def.versions.CopyFrom( | |||||
| versions_pb2.VersionDef( | |||||
| producer=versions.GRAPH_DEF_VERSION, | |||||
| min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) | |||||
| default_graph = ops.get_default_graph() | |||||
| copied_functions = set() | |||||
| # Copy *all* functions from outer graph to `graph_def` so that both direct | |||||
| # and indirect references are safely handled. | |||||
| if copy_functions: | |||||
| default_graph._copy_functions_to_graph_def(graph_def, 0) | |||||
| for function_name in default_graph._functions.keys(): | |||||
| copied_functions.add(function_name) | |||||
| if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): | |||||
| raise ValueError("Length of input_shapes must match the number of " + | |||||
| "input_args. len(input_shapes): {} len(input_arg): {}". | |||||
| format(len(input_shapes), len(fdef.signature.input_arg))) | |||||
| # 1. Create _Arg for input nodes. | |||||
| create_arg_for_input_nodes(fdef, graph_def, input_shapes) | |||||
| # 2. Copy all body NodeDefs to the GraphDef. | |||||
| graph_def.node.extend(fdef.node_def) | |||||
| # 3. Perform the renaming. | |||||
| # Build the tensor name mapping then flatten the tensor names. | |||||
| # See comment on `FunctionDef.node_def` on how the tensor naming in | |||||
| # FunctionDefs is different from GraphDefs. | |||||
| nested_to_flat_tensor_name = build_tensor_name(fdef, default_graph) | |||||
| # Update inputs of all nodes in graph. | |||||
| for node_def in graph_def.node: | |||||
| for i in range(len(node_def.input)): | |||||
| node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] | |||||
| # Create _Retval for output nodes. | |||||
| create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name) | |||||
| return graph_def, nested_to_flat_tensor_name | |||||
| def convert_graphs(filename): | |||||
| try: | |||||
| with tf.io.gfile.GFile(filename, 'rb') as f: | |||||
| graph_def = tf.compat.v1.GraphDef() | |||||
| graph_def.ParseFromString(f.read()) | |||||
| tf.import_graph_def(graph_def, name='') | |||||
| if len(graph_def.library.function) == 0: | |||||
| print("INFO: The input model does not contain a functionDef and does not require conversion.") | |||||
| return | |||||
| try: | |||||
| convert_subgraphs(graph_def, filename) | |||||
| except Exception as e: | |||||
| print("ERROR: Convert subgraphs failed.", e) | |||||
| return | |||||
| print("INFO: Convert to subgraphs successfully.") | |||||
| except NotFoundError: | |||||
| print('ERROR: model file {} does not exist'.format(filename)) | |||||
| return | |||||
| def convert_subgraphs(graph_def, filename): | |||||
| graph_def_library = graph_library_pb2.GraphDefLibrary() | |||||
| for i, fdef in enumerate(graph_def.library.function): | |||||
| sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False) | |||||
| print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name)) | |||||
| sub_graph_name = '{}.pb'.format(fdef.signature.name) | |||||
| result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename))) | |||||
| tf.io.write_graph(sub_graph, result_path, sub_graph_name, as_text=False) | |||||
| data = sub_graph.SerializeToString() | |||||
| ge_graph_def = graph_library_pb2.GeGraphDef() | |||||
| ge_graph_def.name = fdef.signature.name | |||||
| ge_graph_def.graph.ParseFromString(data) | |||||
| graph_def_library.graph_def.append(ge_graph_def) | |||||
| print(graph_def_library.graph_def[i]) | |||||
| # Write to prototxt | |||||
| try: | |||||
| graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename))) | |||||
| print("graph_def_file: ", graph_def_file) | |||||
| with open(graph_def_file, "w") as f: | |||||
| print(graph_def_library, file=f) | |||||
| except IOError: | |||||
| print("Could not open file. Creating a new one.") | |||||
| def usage(): | |||||
| print( | |||||
| ''' | |||||
| Based on tensorflow 1.15 or later, Python 3 | |||||
| Convert the tensorflow functionDefs in the input model file to single GraphDefs, | |||||
| and save the result to the "results" directory and graph_def_library.pbtxt in | |||||
| the input file directory. | |||||
| The name of the sub graph is same as the name of the corresponding functionDef. | |||||
| Usage: func2grpah.py <command> | |||||
| Available commands: | |||||
| model (-m) Input model file. | |||||
| version (-v) Prints the version of this software. | |||||
| help (-h) Prints help for commands. | |||||
| ''' | |||||
| ) | |||||
| if __name__ == '__main__': | |||||
| model = '' | |||||
| try: | |||||
| opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model=']) | |||||
| for opt_name, opt_value in opts: | |||||
| if opt_name in ('-m', '--model'): | |||||
| model = opt_value | |||||
| print("INFO: Input model file is", model) | |||||
| convert_graphs(model) | |||||
| elif opt_name in ('-h', '--help'): | |||||
| usage() | |||||
| break | |||||
| elif opt_name in ('-v', '--version'): | |||||
| print("version 1.0.0") | |||||
| break | |||||
| except getopt.GetoptError: | |||||
| print("ERROR: Input parameters is invalid, use '--help' to view the help.") | |||||
| if (len(sys.argv) == 1): | |||||
| print("INFO: Please specify the input parameters, and use '--help' to view the help.") | |||||
| @@ -0,0 +1,9 @@ | |||||
| LOCAL_PATH := $(call my-dir) | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := func2graph/util | |||||
| LOCAL_MODULE_CLASS := FOLDER | |||||
| include $(LOCAL_PATH)/proto_python_rule.mk | |||||
| @@ -0,0 +1,62 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "AttrValueProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "tensor.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing the value for an attr used to configure an Op. | |||||
| // Comment indicates the corresponding attr type. Only the field matching the | |||||
| // attr type may be filled. | |||||
| message AttrValue { | |||||
| // LINT.IfChange | |||||
| message ListValue { | |||||
| repeated bytes s = 2; // "list(string)" | |||||
| repeated int64 i = 3 [packed = true]; // "list(int)" | |||||
| repeated float f = 4 [packed = true]; // "list(float)" | |||||
| repeated bool b = 5 [packed = true]; // "list(bool)" | |||||
| repeated DataType type = 6 [packed = true]; // "list(type)" | |||||
| repeated TensorShapeProto shape = 7; // "list(shape)" | |||||
| repeated TensorProto tensor = 8; // "list(tensor)" | |||||
| repeated NameAttrList func = 9; // "list(attr)" | |||||
| } | |||||
| // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) | |||||
| oneof value { | |||||
| bytes s = 2; // "string" | |||||
| int64 i = 3; // "int" | |||||
| float f = 4; // "float" | |||||
| bool b = 5; // "bool" | |||||
| DataType type = 6; // "type" | |||||
| TensorShapeProto shape = 7; // "shape" | |||||
| TensorProto tensor = 8; // "tensor" | |||||
| ListValue list = 1; // any "list(...)" | |||||
| // "func" represents a function. func.name is a function's name or | |||||
| // a primitive op's name. func.attr.first is the name of an attr | |||||
| // defined for that function. func.attr.second is the value for | |||||
| // that attr in the instantiation. | |||||
| NameAttrList func = 10; | |||||
| // This is a placeholder only used in nodes defined inside a | |||||
| // function. It indicates the attr value will be supplied when | |||||
| // the function is instantiated. For example, let us suppose a | |||||
| // node "N" in function "FN". "N" has an attr "A" with value | |||||
| // placeholder = "foo". When FN is instantiated with attr "foo" | |||||
| // set to "bar", the instantiated node N's attr A will have been | |||||
| // given the value "bar". | |||||
| string placeholder = 9; | |||||
| } | |||||
| } | |||||
| // A list of attr names and their values. The whole list is attached | |||||
| // with a string name. E.g., MatMul[T=float]. | |||||
| message NameAttrList { | |||||
| string name = 1; | |||||
| map<string, AttrValue> attr = 2; | |||||
| } | |||||
| @@ -0,0 +1,100 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "FunctionProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "node_def.proto"; | |||||
| import "op_def.proto"; | |||||
| // A library is a set of named functions. | |||||
| message FunctionDefLibrary { | |||||
| repeated FunctionDef function = 1; | |||||
| repeated GradientDef gradient = 2; | |||||
| } | |||||
| // A function can be instantiated when the runtime can bind every attr | |||||
| // with a value. When a GraphDef has a call to a function, it must | |||||
| // have binding for every attr defined in the signature. | |||||
| // * device spec, etc. | |||||
| message FunctionDef { | |||||
| // The definition of the function's name, arguments, return values, | |||||
| // attrs etc. | |||||
| OpDef signature = 1; | |||||
| // Attributes specific to this function definition. | |||||
| map<string, AttrValue> attr = 5; | |||||
| // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. | |||||
| reserved 2; | |||||
| // In both of the following fields, there is the need to specify an | |||||
| // output that is used as either the input to another node (in | |||||
| // `node_def`) or as a return value of the function (in `ret`). | |||||
| // Unlike the NodeDefs in GraphDef, we need to be able to specify a | |||||
| // list in some cases (instead of just single outputs). Also, we | |||||
| // need to be able to deal with lists of unknown length (so the | |||||
| // output index may not be known at function definition time). So | |||||
| // we use the following format instead: | |||||
| // * "fun_in" where "fun_in" is the name of a function input arg in | |||||
| // the `signature` field above. This represents that input, whether | |||||
| // it is a single tensor or a list. | |||||
| // * "fun_in:0" gives the first element of a function input arg (a | |||||
| // non-list input is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // * "node:out" where "node" is the name of a node in `node_def` and | |||||
| // "out" is the name one of its op's output arguments (the name | |||||
| // comes from the OpDef of the node's op). This represents that | |||||
| // node's output, whether it is a single tensor or a list. | |||||
| // Note: We enforce that an op's output arguments are never | |||||
| // renamed in the backwards-compatibility test. | |||||
| // * "node:out:0" gives the first element of a node output arg (a | |||||
| // non-list output is considered a list of length 1 for these | |||||
| // purposes). | |||||
| // | |||||
| // NOT CURRENTLY SUPPORTED (but may be in the future): | |||||
| // * "node:out:-1" gives last element in a node output list | |||||
| // * "node:out:1:" gives a list with all but the first element in a | |||||
| // node output list | |||||
| // * "node:out::-1" gives a list with all but the last element in a | |||||
| // node output list | |||||
| // The body of the function. Unlike the NodeDefs in a GraphDef, attrs | |||||
| // may have values of type `placeholder` and the `input` field uses | |||||
| // the "output" format above. | |||||
| // By convention, "op" in node_def is resolved by consulting with a | |||||
| // user-defined library first. If not resolved, "func" is assumed to | |||||
| // be a builtin op. | |||||
| repeated NodeDef node_def = 3; | |||||
| // A mapping from the output arg names from `signature` to the | |||||
| // outputs from `node_def` that should be returned by the function. | |||||
| map<string, string> ret = 4; | |||||
| } | |||||
| // GradientDef defines the gradient function of a function defined in | |||||
| // a function library. | |||||
| // | |||||
| // A gradient function g (specified by gradient_func) for a function f | |||||
| // (specified by function_name) must follow the following: | |||||
| // | |||||
| // The function 'f' must be a numerical function which takes N inputs | |||||
| // and produces M outputs. Its gradient function 'g', which is a | |||||
| // function taking N + M inputs and produces N outputs. | |||||
| // | |||||
| // I.e. if we have | |||||
| // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), | |||||
| // then, g is | |||||
| // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, | |||||
| // dL/dy1, dL/dy2, ..., dL/dy_M), | |||||
| // where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the | |||||
| // loss function). dL/dx_i is the partial derivative of L with respect | |||||
| // to x_i. | |||||
| message GradientDef { | |||||
| string function_name = 1; // The function name. | |||||
| string gradient_func = 2; // The gradient function's name. | |||||
| } | |||||
| @@ -0,0 +1,56 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "GraphProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "node_def.proto"; | |||||
| import "function.proto"; | |||||
| import "versions.proto"; | |||||
| // Represents the graph of operations | |||||
| message GraphDef { | |||||
| repeated NodeDef node = 1; | |||||
| // Compatibility versions of the graph. See core/public/version.h for version | |||||
| // history. The GraphDef version is distinct from the TensorFlow version, and | |||||
| // each release of TensorFlow will support a range of GraphDef versions. | |||||
| VersionDef versions = 4; | |||||
| // Deprecated single version field; use versions above instead. Since all | |||||
| // GraphDef changes before "versions" was introduced were forward | |||||
| // compatible, this field is entirely ignored. | |||||
| int32 version = 3 [deprecated = true]; | |||||
| // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. | |||||
| // | |||||
| // "library" provides user-defined functions. | |||||
| // | |||||
| // Naming: | |||||
| // * library.function.name are in a flat namespace. | |||||
| // NOTE: We may need to change it to be hierarchical to support | |||||
| // different orgs. E.g., | |||||
| // { "/google/nn", { ... }}, | |||||
| // { "/google/vision", { ... }} | |||||
| // { "/org_foo/module_bar", { ... }} | |||||
| // map<string, FunctionDefLib> named_lib; | |||||
| // * If node[i].op is the name of one function in "library", | |||||
| // node[i] is deemed as a function call. Otherwise, node[i].op | |||||
| // must be a primitive operation supported by the runtime. | |||||
| // | |||||
| // | |||||
| // Function call semantics: | |||||
| // | |||||
| // * The callee may start execution as soon as some of its inputs | |||||
| // are ready. The caller may want to use Tuple() mechanism to | |||||
| // ensure all inputs are ready in the same time. | |||||
| // | |||||
| // * The consumer of return values may start executing as soon as | |||||
| // the return values the consumer depends on are ready. The | |||||
| // consumer may want to use Tuple() mechanism to ensure the | |||||
| // consumer does not start until all return values of the callee | |||||
| // function are ready. | |||||
| FunctionDefLibrary library = 2; | |||||
| }; | |||||
| @@ -0,0 +1,14 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| import "graph.proto"; | |||||
| message GeGraphDef { | |||||
| string name = 1; | |||||
| GraphDef graph = 2; | |||||
| } | |||||
| message GraphDefLibrary { | |||||
| repeated GeGraphDef graph_def = 1; | |||||
| }; | |||||
| @@ -0,0 +1,63 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "NodeProto"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| message NodeDef { | |||||
| // The name given to this operator. Used for naming inputs, | |||||
| // logging, visualization, etc. Unique within a single GraphDef. | |||||
| // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". | |||||
| string name = 1; | |||||
| // The operation name. There may be custom parameters in attrs. | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| string op = 2; | |||||
| // Each input is "node:src_output" with "node" being a string name and | |||||
| // "src_output" indicating which output tensor to use from "node". If | |||||
| // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs | |||||
| // may optionally be followed by control inputs that have the format | |||||
| // "^node". | |||||
| repeated string input = 3; | |||||
| // A (possibly partial) specification for the device on which this | |||||
| // node should be placed. | |||||
| // The expected syntax for this string is as follows: | |||||
| // | |||||
| // DEVICE_SPEC ::= PARTIAL_SPEC | |||||
| // | |||||
| // PARTIAL_SPEC ::= ("/" CONSTRAINT) * | |||||
| // CONSTRAINT ::= ("job:" JOB_NAME) | |||||
| // | ("replica:" [1-9][0-9]*) | |||||
| // | ("task:" [1-9][0-9]*) | |||||
| // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) | |||||
| // | |||||
| // Valid values for this string include: | |||||
| // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) | |||||
| // * "/job:worker/device:GPU:3" (partial specification) | |||||
| // * "" (no specification) | |||||
| // | |||||
| // If the constraints do not resolve to a single device (or if this | |||||
| // field is empty or not present), the runtime will attempt to | |||||
| // choose a device automatically. | |||||
| string device = 4; | |||||
| // Operation-specific graph-construction-time configuration. | |||||
| // Note that this should include all attrs defined in the | |||||
| // corresponding OpDef, including those with a value matching | |||||
| // the default -- this allows the default to change and makes | |||||
| // NodeDefs easier to interpret on their own. However, if | |||||
| // an attr with a default is not specified in this list, the | |||||
| // default will be used. | |||||
| // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and | |||||
| // one of the names from the corresponding OpDef's attr field). | |||||
| // The values must have a type matching the corresponding OpDef | |||||
| // attr's type field. | |||||
| // Add some examples here showing best practices. | |||||
| map<string, AttrValue> attr = 5; | |||||
| }; | |||||
| @@ -0,0 +1,164 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "OpDefProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "attr_value.proto"; | |||||
| import "types.proto"; | |||||
| // Defines an operation. A NodeDef in a GraphDef specifies an Op by | |||||
| // using the "op" field which should match the name of a OpDef. | |||||
| // LINT.IfChange | |||||
| message OpDef { | |||||
| // Op names starting with an underscore are reserved for internal use. | |||||
| // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". | |||||
| string name = 1; | |||||
| // For describing inputs and outputs. | |||||
| message ArgDef { | |||||
| // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". | |||||
| string name = 1; | |||||
| // Human readable description. | |||||
| string description = 2; | |||||
| // Describes the type of one or more tensors that are accepted/produced | |||||
| // by this input/output arg. The only legal combinations are: | |||||
| // * For a single tensor: either the "type" field is set or the | |||||
| // "type_attr" field is set to the name of an attr with type "type". | |||||
| // * For a sequence of tensors with the same type: the "number_attr" | |||||
| // field will be set to the name of an attr with type "int", and | |||||
| // either the "type" or "type_attr" field will be set as for | |||||
| // single tensors. | |||||
| // * For a sequence of tensors, the "type_list_attr" field will be set | |||||
| // to the name of an attr with type "list(type)". | |||||
| DataType type = 3; | |||||
| string type_attr = 4; // if specified, attr must have type "type" | |||||
| string number_attr = 5; // if specified, attr must have type "int" | |||||
| // If specified, attr must have type "list(type)", and none of | |||||
| // type, type_attr, and number_attr may be specified. | |||||
| string type_list_attr = 6; | |||||
| // For inputs: if true, the inputs are required to be refs. | |||||
| // By default, inputs can be either refs or non-refs. | |||||
| // For outputs: if true, outputs are refs, otherwise they are not. | |||||
| bool is_ref = 16; | |||||
| }; | |||||
| // Description of the input(s). | |||||
| repeated ArgDef input_arg = 2; | |||||
| // Description of the output(s). | |||||
| repeated ArgDef output_arg = 3; | |||||
| // Description of the graph-construction-time configuration of this | |||||
| // Op. That is to say, this describes the attr fields that will | |||||
| // be specified in the NodeDef. | |||||
| message AttrDef { | |||||
| // A descriptive name for the argument. May be used, e.g. by the | |||||
| // Python client, as a keyword argument name, and so should match | |||||
| // the regexp "[a-z][a-z0-9_]+". | |||||
| string name = 1; | |||||
| // One of the type names from attr_value.proto ("string", "list(string)", | |||||
| // "int", etc.). | |||||
| string type = 2; | |||||
| // A reasonable default for this attribute if the user does not supply | |||||
| // a value. If not specified, the user must supply a value. | |||||
| AttrValue default_value = 3; | |||||
| // Human-readable description. | |||||
| string description = 4; | |||||
| // --- Constraints --- | |||||
| // These constraints are only in effect if specified. Default is no | |||||
| // constraints. | |||||
| // For type == "int", this is a minimum value. For "list(___)" | |||||
| // types, this is the minimum length. | |||||
| bool has_minimum = 5; | |||||
| int64 minimum = 6; | |||||
| // The set of allowed values. Has type that is the "list" version | |||||
| // of the "type" field above (uses the "list" field of AttrValue). | |||||
| // If type == "type" or "list(type)" above, then the "type" field | |||||
| // of "allowed_values.list" has the set of allowed DataTypes. | |||||
| // If type == "string" or "list(string)", then the "s" field of | |||||
| // "allowed_values.list" has the set of allowed strings. | |||||
| AttrValue allowed_values = 7; | |||||
| } | |||||
| repeated AttrDef attr = 4; | |||||
| // Optional deprecation based on GraphDef versions. | |||||
| OpDeprecation deprecation = 8; | |||||
| // One-line human-readable description of what the Op does. | |||||
| string summary = 5; | |||||
| // Additional, longer human-readable description of what the Op does. | |||||
| string description = 6; | |||||
| // ------------------------------------------------------------------------- | |||||
| // Which optimizations this operation can participate in. | |||||
| // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) | |||||
| bool is_commutative = 18; | |||||
| // If is_aggregate is true, then this operation accepts N >= 2 | |||||
| // inputs and produces 1 output all of the same type. Should be | |||||
| // associative and commutative, and produce output with the same | |||||
| // shape as the input. The optimizer may replace an aggregate op | |||||
| // taking input from multiple devices with a tree of aggregate ops | |||||
| // that aggregate locally within each device (and possibly within | |||||
| // groups of nearby devices) before communicating. | |||||
| bool is_aggregate = 16; // for things like add | |||||
| // Other optimizations go here, like | |||||
| // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. | |||||
| // ------------------------------------------------------------------------- | |||||
| // Optimization constraints. | |||||
| // Ops are marked as stateful if their behavior depends on some state beyond | |||||
| // their input tensors (e.g. variable reading op) or if they have | |||||
| // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops | |||||
| // must always produce the same output for the same input and have | |||||
| // no side-effects. | |||||
| // | |||||
| // By default Ops may be moved between devices. Stateful ops should | |||||
| // either not be moved, or should only be moved if that state can also | |||||
| // be moved (e.g. via some sort of save / restore). | |||||
| // Stateful ops are guaranteed to never be optimized away by Common | |||||
| // Subexpression Elimination (CSE). | |||||
| bool is_stateful = 17; // for things like variables, queue | |||||
| // ------------------------------------------------------------------------- | |||||
| // Non-standard options. | |||||
| // By default, all inputs to an Op must be initialized Tensors. Ops | |||||
| // that may initialize tensors for the first time should set this | |||||
| // field to true, to allow the Op to take an uninitialized Tensor as | |||||
| // input. | |||||
| bool allows_uninitialized_input = 19; // for Assign, etc. | |||||
| }; | |||||
| // LINT.ThenChange( | |||||
| // https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) | |||||
| // Information about version-dependent deprecation of an op | |||||
| message OpDeprecation { | |||||
| // First GraphDef version at which the op is disallowed. | |||||
| int32 version = 1; | |||||
| // Explanation of why it was deprecated and what to use instead. | |||||
| string explanation = 2; | |||||
| }; | |||||
| // A collection of OpDefs | |||||
| message OpList { | |||||
| repeated OpDef op = 1; | |||||
| }; | |||||
| @@ -0,0 +1,29 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "ResourceHandle"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| // Protocol buffer representing a handle to a tensorflow resource. Handles are | |||||
| // not valid across executions, but can be serialized back and forth from within | |||||
| // a single run. | |||||
| message ResourceHandleProto { | |||||
| // Unique name for the device containing the resource. | |||||
| string device = 1; | |||||
| // Container in which this resource is placed. | |||||
| string container = 2; | |||||
| // Unique name of this resource. | |||||
| string name = 3; | |||||
| // Hash code for the type of the resource. Is only valid in the same device | |||||
| // and in the same execution. | |||||
| uint64 hash_code = 4; | |||||
| // For debug-only, the name of the type pointed to by this handle, if | |||||
| // available. | |||||
| string maybe_type_name = 5; | |||||
| }; | |||||
| @@ -0,0 +1,94 @@ | |||||
| syntax = "proto3"; | |||||
| package domi.tensorflow; | |||||
| option cc_enable_arenas = true; | |||||
| option java_outer_classname = "TensorProtos"; | |||||
| option java_multiple_files = true; | |||||
| option java_package = "org.tensorflow.framework"; | |||||
| import "resource_handle.proto"; | |||||
| import "tensor_shape.proto"; | |||||
| import "types.proto"; | |||||
| // Protocol buffer representing a tensor. | |||||
| message TensorProto { | |||||
| DataType dtype = 1; | |||||
| // Shape of the tensor. | |||||
| TensorShapeProto tensor_shape = 2; | |||||
| // Only one of the representations below is set, one of "tensor_contents" and | |||||
| // the "xxx_val" attributes. We are not using oneof because as oneofs cannot | |||||
| // contain repeated fields it would require another extra set of messages. | |||||
| // Version number. | |||||
| // | |||||
| // In version 0, if the "repeated xxx" representations contain only one | |||||
| // element, that element is repeated to fill the shape. This makes it easy | |||||
| // to represent a constant Tensor with a single value. | |||||
| int32 version_number = 3; | |||||
| // Serialized raw tensor content from either Tensor::AsProtoTensorContent or | |||||
| // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation | |||||
| // can be used for all tensor types. The purpose of this representation is to | |||||
| // reduce serialization overhead during RPC call by avoiding serialization of | |||||
| // many repeated small items. | |||||
| bytes tensor_content = 4; | |||||
| // Type specific representations that make it easy to create tensor protos in | |||||
| // all languages. Only the representation corresponding to "dtype" can | |||||
| // be set. The values hold the flattened representation of the tensor in | |||||
| // row major order. | |||||
| // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll | |||||
| // have some pointless zero padding for each value here. | |||||
| repeated int32 half_val = 13 [packed = true]; | |||||
| // DT_FLOAT. | |||||
| repeated float float_val = 5 [packed = true]; | |||||
| // DT_DOUBLE. | |||||
| repeated double double_val = 6 [packed = true]; | |||||
| // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. | |||||
| repeated int32 int_val = 7 [packed = true]; | |||||
| // DT_STRING | |||||
| repeated bytes string_val = 8; | |||||
| // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th single precision complex. | |||||
| repeated float scomplex_val = 9 [packed = true]; | |||||
| // DT_INT64 | |||||
| repeated int64 int64_val = 10 [packed = true]; | |||||
| // DT_BOOL | |||||
| repeated bool bool_val = 11 [packed = true]; | |||||
| // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real | |||||
| // and imaginary parts of i-th double precision complex. | |||||
| repeated double dcomplex_val = 12 [packed = true]; | |||||
| // DT_RESOURCE | |||||
| repeated ResourceHandleProto resource_handle_val = 14; | |||||
| // DT_VARIANT | |||||
| repeated VariantTensorDataProto variant_val = 15; | |||||
| // DT_UINT32 | |||||
| repeated uint32 uint32_val = 16 [packed = true]; | |||||
| // DT_UINT64 | |||||
| repeated uint64 uint64_val = 17 [packed = true]; | |||||
| }; | |||||
| // Protocol buffer representing the serialization format of DT_VARIANT tensors. | |||||
| message VariantTensorDataProto { | |||||
| // Name of the type of objects being serialized. | |||||
| string type_name = 1; | |||||
| // Portions of the object that are not Tensors. | |||||
| bytes metadata = 2; | |||||
| // Tensors contained within objects being serialized. | |||||
| repeated TensorProto tensors = 3; | |||||
| } | |||||