Merge pull request !2421 from yankai10/merge_master_new_fearturetags/v0.5.0-beta
| @@ -63,7 +63,7 @@ class MS_API MSTensor { | |||
| // return A pointer points to data in MSTensor. | |||
| virtual void *MutableData() const = 0; | |||
| }; | |||
| using MultiTensor = std::vector<std::vector<std::shared_ptr<inference::MSTensor>>>; | |||
| using MultiTensor = std::vector<std::shared_ptr<inference::MSTensor>>; | |||
| } // namespace inference | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_INCLUDE_MS_TENSOR_H_ | |||
| @@ -217,6 +217,7 @@ class CNode : public AnfNode { | |||
| void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } | |||
| std::string fullname_with_scope() override; | |||
| void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; } | |||
| std::string DebugString(int recursive_level = 1) const override; | |||
| std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } | |||
| @@ -23,6 +23,7 @@ if (ENABLE_D) | |||
| file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "ascend_session.cc" | |||
| "ascend_control_parser.cc" | |||
| "ascend_inference_session.cc" | |||
| ) | |||
| list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) | |||
| endif () | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "session/ascend_inference_session.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "device/kernel_runtime.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "common/utils.h" | |||
| #include "common/trans.h" | |||
| #include "kernel/tbe/tbe_python_funcs.h" | |||
| #include "utils/config_manager.h" | |||
| #include "utils/base_ref_extends.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs_const) const { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| std::vector<tensor::TensorPtr> inputs(inputs_const); | |||
| auto input_nodes = kernel_graph->inputs(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| size_t no_weight_input = 0; | |||
| for (size_t i = 0; i < input_nodes.size(); ++i) { | |||
| tensor::TensorPtr tensor = nullptr; | |||
| if (!input_nodes[i]->isa<Parameter>()) { | |||
| MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; | |||
| continue; | |||
| } | |||
| auto pk_node = input_nodes[i]->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(pk_node); | |||
| if (AnfAlgo::IsParameterWeight(pk_node)) { | |||
| auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param()); | |||
| MS_EXCEPTION_IF_NULL(param_value); | |||
| auto py_param = param_value->value(); | |||
| MS_EXCEPTION_IF_NULL(py_param); | |||
| py::array py_array = py_param.cast<py::array>(); | |||
| tensor = std::make_shared<tensor::Tensor>(py_array); | |||
| } else { | |||
| tensor = inputs[no_weight_input++]; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (AnfAlgo::OutputAddrExist(pk_node, 0)) { | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
| bool need_sync = false; | |||
| if (ms_context->enable_pynative_infer()) { | |||
| if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { | |||
| need_sync = true; | |||
| } | |||
| } else { | |||
| if (tensor->is_dirty()) { | |||
| need_sync = true; | |||
| } else if (tensor->device_address() != device_address) { | |||
| (void)tensor->data_sync(); | |||
| need_sync = true; | |||
| } | |||
| } | |||
| if (need_sync) { | |||
| if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { | |||
| tensor->set_device_address(device_address); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
| LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
| tensor->data_c(false))) { | |||
| MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
| } | |||
| } | |||
| } | |||
| tensor->set_dirty(false); | |||
| } | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H | |||
| #define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H | |||
| #include <unordered_map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <stack> | |||
| #include <map> | |||
| #include <tuple> | |||
| #include <set> | |||
| #include "session/ascend_session.h" | |||
| #include "session/kernel_graph.h" | |||
| #include "kernel/kernel.h" | |||
| #include "session/session_factory.h" | |||
| #include "session/ascend_control_parser.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| class AscendInferenceSession : public AscendSession { | |||
| public: | |||
| AscendInferenceSession() = default; | |||
| ~AscendInferenceSession() = default; | |||
| void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs_const) const; | |||
| }; | |||
| MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H | |||
| @@ -124,7 +124,7 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p | |||
| }); | |||
| if (has_error) { | |||
| MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; | |||
| std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> multiTensor; | |||
| std::vector<std::shared_ptr<inference::MSTensor>> multiTensor; | |||
| return multiTensor; | |||
| } | |||
| VectorRef outputs; | |||
| @@ -135,6 +135,9 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector<std::shared_p | |||
| int Session::Init(const std::string &device, uint32_t device_id) { | |||
| RegAllOp(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| ms_context->set_execution_mode(kGraphMode); | |||
| ms_context->set_device_target(kAscendDevice); | |||
| session_impl_ = session::SessionFactory::Get().Create(device); | |||
| if (session_impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; | |||
| @@ -619,6 +619,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| auto new_cnode = CreateNewCNode(cnode, graph.get()); | |||
| MS_EXCEPTION_IF_NULL(new_cnode); | |||
| new_cnode->set_abstract(cnode->abstract()); | |||
| new_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); | |||
| new_cnode->set_scope(cnode->scope()); | |||
| graph->FrontBackendlMapAdd(node, new_cnode); | |||
| if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { | |||
| @@ -21,20 +21,27 @@ | |||
| #include "ir/tensor.h" | |||
| namespace mindspore { | |||
| std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref) { | |||
| void IterateFindTensor(std::vector<std::shared_ptr<inference::MSTensor>> *msTensors, const VectorRef &ref_list) { | |||
| for (size_t i = 0; i < ref_list.size(); ++i) { | |||
| if (utils::isa<tensor::TensorPtr>(ref_list[i])) { | |||
| auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| auto tensor = new inference::Tensor(tensor_ptr); | |||
| msTensors->emplace_back(std::shared_ptr<inference::MSTensor>(tensor)); | |||
| } else if (utils::isa<VectorRef>(ref_list[i])) { | |||
| auto ref_iter = utils::cast<VectorRef>(ref_list[i]); | |||
| IterateFindTensor(msTensors, ref_iter); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "The output is not a tensor"; | |||
| } | |||
| } | |||
| } | |||
| std::vector<std::shared_ptr<inference::MSTensor>> TransformVectorRefToMultiTensor(const VectorRef &base_ref) { | |||
| std::vector<std::shared_ptr<inference::MSTensor>> msTensors; | |||
| if (utils::isa<VectorRef>(base_ref)) { | |||
| auto ref_list = utils::cast<VectorRef>(base_ref); | |||
| for (size_t i = 0; i < ref_list.size(); ++i) { | |||
| if (utils::isa<tensor::Tensor>(ref_list[i])) { | |||
| auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(ref_list[i]); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| auto tensor = new inference::Tensor(tensor_ptr); | |||
| msTensors.emplace_back(std::shared_ptr<inference::MSTensor>(tensor)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "The output is not a tensor!"; | |||
| } | |||
| } | |||
| IterateFindTensor(&msTensors, ref_list); | |||
| } else if (utils::isa<tensor::Tensor>(base_ref)) { | |||
| auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref); | |||
| MS_EXCEPTION_IF_NULL(tensor_ptr); | |||
| @@ -45,14 +52,4 @@ std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(con | |||
| } | |||
| return msTensors; | |||
| } | |||
| std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> TransformVectorRefToMultiTensor( | |||
| const VectorRef &vector_ref) { | |||
| std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> multiTensor; | |||
| for (size_t i = 0; i < vector_ref.size(); ++i) { | |||
| auto tensors = TransformBaseRefToMSTensor(vector_ref[i]); | |||
| multiTensor.emplace_back(tensors); | |||
| } | |||
| return multiTensor; | |||
| } | |||
| } // namespace mindspore | |||
| @@ -22,9 +22,6 @@ | |||
| #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H | |||
| #define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H | |||
| namespace mindspore { | |||
| std::vector<std::shared_ptr<inference::MSTensor>> TransformBaseRefToMSTensor(const BaseRef &base_ref); | |||
| std::vector<std::vector<std::shared_ptr<inference::MSTensor>>> TransformVectorRefToMultiTensor( | |||
| const VectorRef &vector_ref); | |||
| std::vector<std::shared_ptr<inference::MSTensor>> TransformVectorRefToMultiTensor(const VectorRef &base_ref); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H | |||
| @@ -41,6 +41,7 @@ const int kPynativeMode = 1; | |||
| const char kCPUDevice[] = "CPU"; | |||
| const char kGPUDevice[] = "GPU"; | |||
| const char kAscendDevice[] = "Ascend"; | |||
| const char kDavinciInferenceDevice[] = "AscendInference"; | |||
| const char kDavinciDevice[] = "Davinci"; | |||
| const char KNpuLog[] = "_npu_log"; | |||
| const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; | |||
| @@ -96,8 +96,6 @@ std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const std::string &file | |||
| ReadOnnxFromBinary(modelFile, &model_); | |||
| MSANFModelParser model_parser; | |||
| FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); | |||
| MS_EXCEPTION_IF_NULL(dstgraph_ptr); | |||
| TestFuncGraphBuild(dstgraph_ptr); | |||
| return dstgraph_ptr; | |||
| } | |||
| @@ -111,33 +109,7 @@ std::shared_ptr<FuncGraph> AnfConverter::RunAnfConverter(const char *buf, const | |||
| } | |||
| MSANFModelParser model_parser; | |||
| FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); | |||
| MS_EXCEPTION_IF_NULL(dstgraph_ptr); | |||
| TestFuncGraphBuild(dstgraph_ptr); | |||
| return dstgraph_ptr; | |||
| } | |||
| int AnfConverter::TestFuncGraphBuild(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto node_return = graph->get_return(); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(node_return); | |||
| MS_LOG(INFO) << "node_list size is : " << node_list.size(); | |||
| for (auto &node : node_list) { | |||
| if (node->isa<CNode>()) { | |||
| auto node_CN = node->cast<CNodePtr>(); | |||
| MS_LOG(INFO) << "CN node: " << node_CN->input(0)->ToString() << ", input size :" << node_CN->size(); | |||
| } else if (node->isa<Parameter>()) { | |||
| auto node_Para = node->cast<ParameterPtr>(); | |||
| if (node_Para->has_default()) { | |||
| MS_LOG(INFO) << "Parameter node: " << node_Para->name() << "has default value!"; | |||
| } else { | |||
| MS_LOG(INFO) << "Parameter node: " << node_Para->name(); | |||
| } | |||
| } else if (node->isa<ValueNode>()) { | |||
| auto node_Value = node->cast<ValueNodePtr>(); | |||
| MS_LOG(INFO) << "Value node: " << node_Value->ToString(); | |||
| } | |||
| } | |||
| return 0; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -26,7 +26,6 @@ namespace mindspore { | |||
| namespace lite { | |||
| class AnfConverter { | |||
| public: | |||
| static int TestFuncGraphBuild(const FuncGraphPtr &graph); | |||
| static std::shared_ptr<FuncGraph> RunAnfConverter(const std::string &file_path); | |||
| static std::shared_ptr<FuncGraph> RunAnfConverter(const char *buf, const size_t buf_size); | |||
| @@ -14,16 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/load_onnx/anf_model_parser.h" | |||
| #include <functional> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "utils/load_onnx/anf_model_parser.h" | |||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/param_value_py.h" | |||
| #include "operator/ops.h" | |||
| #include "pipeline/static_analysis/abstract_value.h" | |||
| #include "proto/onnx.pb.h" | |||
| #include "utils/log_adapter.h" | |||
| @@ -33,6 +34,8 @@ namespace mindspore { | |||
| namespace lite { | |||
| static constexpr char kConstantValueNode[] = "Constant"; | |||
| static constexpr char kCNodeShapeAttr[] = "shape"; | |||
| static constexpr char kCNodeShape1Attr[] = "shape1"; | |||
| static constexpr char kCNodeShape2Attr[] = "shape2"; | |||
| enum ParseForm : int { | |||
| FORM_PARSE_TYPE = 0, | |||
| FORM_PARSE_SCALAR = 1, | |||
| @@ -56,14 +59,15 @@ static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | |||
| void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ | |||
| const onnx::TensorProto &attr_tensor) { \ | |||
| MS_EXCEPTION_IF_NULL(prim); \ | |||
| std::vector<valuetype> attr_value_vec; \ | |||
| std::vector<ValuePtr> attr_value_vec; \ | |||
| for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ | |||
| attr_value_vec.push_back(static_cast<valuetype>(attr_tensor.type##_data(i))); \ | |||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ | |||
| attr_value_vec.push_back(MakeValue<valuetype>(value)); \ | |||
| } \ | |||
| if (attr_value_vec.size() == 1) { \ | |||
| prim->AddAttr(attr_name, MakeValue<valuetype>(attr_value_vec[0])); \ | |||
| prim->AddAttr(attr_name, attr_value_vec[0]); \ | |||
| } else { \ | |||
| prim->AddAttr(attr_name, MakeValue<std::vector<valuetype>>(attr_value_vec)); \ | |||
| prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ | |||
| } \ | |||
| } | |||
| @@ -247,17 +251,12 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node | |||
| const std::string &tensor_buf = attr_tensor.raw_data(); | |||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c(true)); | |||
| memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); | |||
| if (attr_tensor_type == onnx::TensorProto_DataType_FLOAT) { | |||
| auto *data_valuennode = reinterpret_cast<float *>(tensor_info->data_c()); | |||
| MS_EXCEPTION_IF_NULL(data_valuennode); | |||
| auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data_valuennode)); | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| } else { | |||
| auto *data_valuenode = reinterpret_cast<int32 *>(tensor_info->data_c()); | |||
| MS_EXCEPTION_IF_NULL(data_valuenode); | |||
| auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data_valuenode)); | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| } | |||
| auto new_value_node = NewValueNode(MakeValue(tensor_info)); | |||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||
| auto tensor_abstract = tensor_info->ToAbstract(); | |||
| MS_EXCEPTION_IF_NULL(tensor_abstract); | |||
| new_value_node->set_abstract(tensor_abstract); | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| return true; | |||
| } | |||
| @@ -315,7 +314,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n | |||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||
| return false; | |||
| } | |||
| auto new_value_node = std::make_shared<ValueNode>(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||
| auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||
| abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||
| new_value_node->set_abstract(abs_type); | |||
| anfnode_build_map_[value_node_name] = new_value_node; | |||
| return true; | |||
| } | |||
| @@ -361,31 +362,45 @@ AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto | |||
| tensor::TensorPtr tensor_info = | |||
| std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); | |||
| MS_EXCEPTION_IF_NULL(tensor_info); | |||
| return tensor_info->ToAbstract(); | |||
| auto abstract = tensor_info->ToAbstract(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| return abstract; | |||
| } | |||
| bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||
| const onnx::GraphProto &importProto, const bool &ret_flag) { | |||
| CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||
| const onnx::NodeProto &node_proto) { | |||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||
| if (!node_proto.has_op_type()) { | |||
| MS_LOG(ERROR) << "Get CNode op_type failed!"; | |||
| return false; | |||
| return nullptr; | |||
| } | |||
| const std::string &node_name = node_proto.output(0); | |||
| const std::string &fullname_with_scope = node_proto.domain(); | |||
| const std::string &node_type = node_proto.op_type(); | |||
| PrimitivePtr prim = std::make_shared<Primitive>(node_type); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| prim->set_instance_name(node_type); | |||
| AbstractBasePtr abstract; | |||
| AbstractBasePtr abstract = nullptr; | |||
| AbstractBasePtr abstract_first = nullptr; | |||
| AbstractBasePtr abstract_second = nullptr; | |||
| for (int i = 0; i < node_proto.attribute_size(); ++i) { | |||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(i); | |||
| if (attr_proto.name() == kCNodeShapeAttr) { | |||
| abstract = GetAbstractForCNode(attr_proto); | |||
| continue; | |||
| } | |||
| if (attr_proto.name() == kCNodeShape1Attr) { | |||
| abstract_first = GetAbstractForCNode(attr_proto); | |||
| continue; | |||
| } | |||
| if (attr_proto.name() == kCNodeShape2Attr) { | |||
| abstract_second = GetAbstractForCNode(attr_proto); | |||
| continue; | |||
| } | |||
| if (!GetAttrValueForCNode(prim, attr_proto)) { | |||
| MS_LOG(ERROR) << "Get CNode attr failed!"; | |||
| return false; | |||
| return nullptr; | |||
| } | |||
| } | |||
| @@ -396,16 +411,64 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap | |||
| const std::string &input_name = node_proto.input(i); | |||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | |||
| MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; | |||
| return false; | |||
| return nullptr; | |||
| } | |||
| inputs.push_back(anfnode_build_map_[input_name]); | |||
| } | |||
| CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||
| cnode_ptr->set_abstract(abstract); | |||
| if (ret_flag) { | |||
| if (node_type == "LayerNorm") { | |||
| AbstractBasePtrList elem; | |||
| elem.push_back(abstract); | |||
| elem.push_back(abstract_first); | |||
| elem.push_back(abstract_second); | |||
| cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||
| } else if (node_type == "ArgMaxWithValue") { | |||
| AbstractBasePtrList elem; | |||
| elem.push_back(abstract); | |||
| elem.push_back(abstract_first); | |||
| cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||
| } else if (nullptr == abstract) { | |||
| AbstractBasePtrList elem; | |||
| for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { | |||
| elem.push_back(cnode_ptr->input(index)->abstract()); | |||
| } | |||
| cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||
| } else { | |||
| cnode_ptr->set_abstract(abstract); | |||
| } | |||
| cnode_ptr->set_fullname_with_scope(fullname_with_scope); | |||
| anfnode_build_map_[node_name] = cnode_ptr; | |||
| return cnode_ptr; | |||
| } | |||
| bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const CNodePtr &cnode_ptr) { | |||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||
| std::vector<AnfNodePtr> inputs; | |||
| if (importProto.output_size() > 1) { | |||
| inputs.clear(); | |||
| inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| AbstractBasePtrList elem; | |||
| for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { | |||
| const onnx::ValueInfoProto &output_node = importProto.output(out_size); | |||
| const std::string &out_tuple = output_node.name(); | |||
| inputs.push_back(anfnode_build_map_[out_tuple]); | |||
| elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | |||
| } | |||
| auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | |||
| maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||
| inputs.clear(); | |||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | |||
| inputs.push_back(maketuple_ptr); | |||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| outputFuncGraph->set_return(return_node); | |||
| MS_LOG(INFO) << "Construct funcgraph finined, all success."; | |||
| } else { | |||
| const onnx::ValueInfoProto &output_node = importProto.output(0); | |||
| const ::onnx::TypeProto &output_typeproto = output_node.type(); | |||
| const onnx::TypeProto &output_typeproto = output_node.type(); | |||
| int output_type = output_typeproto.tensor_type().elem_type(); | |||
| std::vector<int> output_shape; | |||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { | |||
| @@ -417,20 +480,19 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap | |||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | |||
| inputs.push_back(cnode_ptr); | |||
| auto return_node = outputFuncGraph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| return_node->set_abstract(tensor_return->ToAbstract()); | |||
| outputFuncGraph->set_return(return_node); | |||
| MS_LOG(INFO) << "Construct funcgraph finined, all success!"; | |||
| } | |||
| anfnode_build_map_[node_name] = cnode_ptr; | |||
| return true; | |||
| } | |||
| bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { | |||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | |||
| bool return_flag = false; | |||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | |||
| CNodePtr cnode_ptr = nullptr; | |||
| for (int i = 0; i < importProto.node_size(); ++i) { | |||
| return_flag = (i == importProto.node_size() - 1) ? true : return_flag; | |||
| const onnx::NodeProto &node_proto = importProto.node(i); | |||
| const std::string &node_type = node_proto.op_type(); | |||
| if (node_type == kConstantValueNode) { | |||
| @@ -440,11 +502,14 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||
| } | |||
| continue; | |||
| } | |||
| if (!BuildCNodeForFuncGraph(outputFuncGraph, node_proto, importProto, return_flag)) { | |||
| cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); | |||
| if (cnode_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; | |||
| return false; | |||
| } | |||
| } | |||
| BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); | |||
| return true; | |||
| } | |||
| @@ -472,12 +537,12 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode | |||
| producer_name_ = model_proto.producer_name(); | |||
| MS_LOG(INFO) << "producer_name :" << producer_name_; | |||
| if (!model_proto.has_producer_version()) { | |||
| if (!model_proto.has_model_version()) { | |||
| MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; | |||
| return false; | |||
| } | |||
| producer_version_ = model_proto.producer_version(); | |||
| MS_LOG(INFO) << "producer_version : " << producer_version_; | |||
| model_version_ = model_proto.model_version(); | |||
| MS_LOG(INFO) << "producer_version : " << model_version_; | |||
| if (!model_proto.has_ir_version()) { | |||
| MS_LOG(ERROR) << "Parse model version from pb file failed!"; | |||
| @@ -485,14 +550,6 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode | |||
| } | |||
| ir_version_ = model_proto.ir_version(); | |||
| MS_LOG(INFO) << "ir_version :" << ir_version_; | |||
| const onnx::OperatorSetIdProto &opset_proto = model_proto.opset_import(0); | |||
| if (!opset_proto.has_version()) { | |||
| MS_LOG(ERROR) << "Parse opset version from pb file failed!"; | |||
| return false; | |||
| } | |||
| opset_version_ = opset_proto.version(); | |||
| MS_LOG(INFO) << "opset_version : " << opset_version_; | |||
| return true; | |||
| } | |||
| @@ -501,7 +558,6 @@ FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) { | |||
| MS_EXCEPTION_IF_NULL(dstGraph); | |||
| if (!MSANFParseModelConfigureInfo(model_proto)) { | |||
| MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; | |||
| return nullptr; | |||
| } | |||
| const onnx::GraphProto &graphBuild = model_proto.graph(); | |||
| if (!BuildFuncGraph(dstGraph, graphBuild)) { | |||
| @@ -29,6 +29,7 @@ namespace lite { | |||
| using int32 = int32_t; | |||
| using int64 = int64_t; | |||
| using uint64 = uint64_t; | |||
| using float16 = Eigen::half; | |||
| class MSANFModelParser { | |||
| public: | |||
| MSANFModelParser() = default; | |||
| @@ -38,17 +39,17 @@ class MSANFModelParser { | |||
| bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto); | |||
| std::string GetProducerName() { return producer_name_; } | |||
| std::string GetProducerVersion() { return producer_version_; } | |||
| int GetProducerVersion() { return model_version_; } | |||
| int GetIrVersion() { return ir_version_; } | |||
| int GetOpsetVersion() { return opset_version_; } | |||
| private: | |||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||
| bool BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||
| const onnx::GraphProto &importProto, const bool &ret_flag); | |||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); | |||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||
| const CNodePtr &cnode_ptr); | |||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); | |||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||
| const onnx::TensorProto &attr_tensor); | |||
| @@ -63,15 +64,13 @@ class MSANFModelParser { | |||
| bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, | |||
| const onnx::TensorProto &attr_tensor); | |||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||
| AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | |||
| std::string producer_name_; | |||
| std::string producer_version_; | |||
| int ir_version_{}; | |||
| int opset_version_{}; | |||
| int model_version_; | |||
| int ir_version_; | |||
| std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; | |||
| std::map<std::string, onnx::TensorProto> default_para_map_; | |||
| AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||