From: @wilfchen Reviewed-by: @limingqi107,@cristoval Signed-off-by:pull/14584/MERGE
| @@ -68,7 +68,7 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) { | |||
| return true; | |||
| } | |||
| TrtKernel::ReleaseResource() { | |||
| void TrtKernel::ReleaseResource() { | |||
| // Make sure destroy trt object before TrtLoader destruct. | |||
| context_.reset(); | |||
| engine_.reset(); | |||
| @@ -50,6 +50,33 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| builder.SetOutputsFormat(outputs_format); | |||
| return builder.Build(); | |||
| } | |||
| AnfNodePtr RelpaceOutputEdge(const AnfNodePtr &node, CNodePtr adam, AnfNodePtr u_input) { | |||
| // Replace the parameters of the last UpdateState to maintain | |||
| // the execution order of FusedAdam and the following operators. | |||
| // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} | |||
| const auto &n = node->cast<CNodePtr>()->input(2); | |||
| MS_EXCEPTION_IF_NULL(n); | |||
| const auto &fg = n->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mgr = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mgr); | |||
| auto &node_users = mgr->node_users(); | |||
| auto iter = node_users.find(n); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||
| (user.first)->cast<CNodePtr>()->set_input(1, u_input); | |||
| (user.first)->cast<CNodePtr>()->set_input(2, adam); | |||
| break; | |||
| } | |||
| } | |||
| return adam; | |||
| } | |||
| } // namespace | |||
| const BaseRef AdamFusion::DefinePattern() const { | |||
| @@ -118,51 +145,19 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr | |||
| // Fused into a FusedAdam operator. | |||
| auto prim = std::make_shared<Primitive>(kFusedAdamName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), | |||
| beta1_input, | |||
| one_sub_beta1_input, | |||
| beta2_input, | |||
| one_sub_beta2_input, | |||
| eps_input, | |||
| lr_input, | |||
| param, | |||
| m_input, | |||
| v_input, | |||
| gradient_input}; | |||
| auto prim_value = NewValueNode(prim); | |||
| std::vector<AnfNodePtr> inputs = { | |||
| prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param, | |||
| m_input, v_input, gradient_input}; | |||
| auto adam = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(adam); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; | |||
| AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); | |||
| adam->set_scope(node->scope()); | |||
| auto build_info = GenerateKernelBuildInfo(adam); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); | |||
| // Replace the parameters of the last UpdateState to maintain | |||
| // the execution order of FusedAdam and the following operators. | |||
| // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} | |||
| auto n = node->cast<CNodePtr>()->input(2); | |||
| auto fg = n->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mgr = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mgr); | |||
| auto &node_users = mgr->node_users(); | |||
| auto iter = node_users.find(n); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||
| (user.first)->cast<CNodePtr>()->set_input(1, u_input); | |||
| (user.first)->cast<CNodePtr>()->set_input(2, adam); | |||
| break; | |||
| } | |||
| } | |||
| return adam; | |||
| return RelpaceOutputEdge(node, adam, u_input); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -50,6 +50,34 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { | |||
| builder.SetOutputsFormat(outputs_format); | |||
| return builder.Build(); | |||
| } | |||
| AnfNodePtr ReplaceOutputEdge(const AnfNodePtr &node, CNodePtr adam_weight_decay, AnfNodePtr u_input) { | |||
| // Replace the parameters of the last UpdateState to maintain | |||
| // the execution order of FusedAdamWeightDecay and the following operators. | |||
| // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} | |||
| const auto &n = node->cast<CNodePtr>()->input(2); | |||
| MS_EXCEPTION_IF_NULL(n); | |||
| const auto &fg = n->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mgr = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mgr); | |||
| auto &node_users = mgr->node_users(); | |||
| auto iter = node_users.find(n); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||
| (user.first)->cast<CNodePtr>()->set_input(1, u_input); | |||
| (user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay); | |||
| break; | |||
| } | |||
| } | |||
| return adam_weight_decay; | |||
| } | |||
| } // namespace | |||
| const BaseRef AdamWeightDecayFusion::DefinePattern() const { | |||
| @@ -122,18 +150,10 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const | |||
| // Fused into a FusedAdamWeightDecay operator. | |||
| auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), | |||
| beta1_input, | |||
| one_sub_beta1_input, | |||
| beta2_input, | |||
| one_sub_beta2_input, | |||
| eps_input, | |||
| lr_input, | |||
| param, | |||
| m_input, | |||
| v_input, | |||
| gradient_input, | |||
| weight_decay_input}; | |||
| auto prim_value = NewValueNode(prim); | |||
| std::vector<AnfNodePtr> inputs = { | |||
| prim_value, beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, eps_input, lr_input, param, | |||
| m_input, v_input, gradient_input, weight_decay_input}; | |||
| auto adam_weight_decay = graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(adam_weight_decay); | |||
| auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; | |||
| @@ -143,31 +163,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const | |||
| auto build_info = GenerateKernelBuildInfo(adam_weight_decay); | |||
| AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); | |||
| // Replace the parameters of the last UpdateState to maintain | |||
| // the execution order of FusedAdamWeightDecay and the following operators. | |||
| // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} | |||
| auto n = node->cast<CNodePtr>()->input(2); | |||
| auto fg = n->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto mgr = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mgr); | |||
| auto &node_users = mgr->node_users(); | |||
| auto iter = node_users.find(n); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find node : " << n->DebugString(); | |||
| } | |||
| auto &users = iter->second; | |||
| for (auto &user : users) { | |||
| if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { | |||
| (user.first)->cast<CNodePtr>()->set_input(1, u_input); | |||
| (user.first)->cast<CNodePtr>()->set_input(2, adam_weight_decay); | |||
| break; | |||
| } | |||
| } | |||
| return adam_weight_decay; | |||
| return ReplaceOutputEdge(node, adam_weight_decay, u_input); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,339 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/trt_pass/trt_converter_context.h" | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <sstream> | |||
| #include <algorithm> | |||
| #include "runtime/device/gpu/trt_loader.h" | |||
| #include "backend/optimizer/trt_pass/trt_op_factory.h" | |||
| #include "backend/kernel_compiler/gpu/trt/trt_utils.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/utils.h" | |||
| #include "utils/singleton.h" | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, | |||
| std::vector<session::KernelWithIndex> *inputs) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (node->isa<ValueNode>() || node->isa<Parameter>()) { | |||
| return inputs->push_back(std::make_pair(node, 0)); | |||
| } | |||
| // Skip control node | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad) || | |||
| AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { | |||
| return GetRealOutputRecursive(node->cast<CNodePtr>()->input(kRealInputIndexInDepend), 0, inputs); | |||
| } | |||
| // Bypass TupleGetItem | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| auto tuple_get_item = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_get_item); | |||
| auto input = AnfAlgo::GetTupleGetItemRealInput(tuple_get_item); | |||
| auto index = AnfAlgo::GetTupleGetItemOutIndex(tuple_get_item); | |||
| // Conceal MakeTuple + TupleGetItem pair. | |||
| if (AnfAlgo::CheckPrimitiveType(input, prim::kPrimMakeTuple)) { | |||
| auto make_tuple = input->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| auto real_input = AnfAlgo::GetInputNode(make_tuple, index); | |||
| return GetRealOutputRecursive(real_input, 0, inputs); | |||
| } | |||
| // Skip TupleGetItem. | |||
| return GetRealOutputRecursive(input, index, inputs); | |||
| } | |||
| // Flatten MakeTuple inputs. | |||
| if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { | |||
| auto make_tuple = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(make_tuple); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| auto input_node = AnfAlgo::GetInputNode(make_tuple, input_index); | |||
| GetRealOutputRecursive(input_node, 0, inputs); | |||
| } | |||
| return; | |||
| } | |||
| return inputs->push_back(std::make_pair(node, output_index)); | |||
| } | |||
| /* Get node real inputs bypass control nodes. | |||
| * Examples: | |||
| * Case 1: | |||
| * c = Conv2D(a, b) | |||
| * d = ReLU(c) | |||
| * result: d--> (c) | |||
| * | |||
| * Case 2: | |||
| * c = Conv2D(a, b) | |||
| * d = Depend(c, v) | |||
| * e = ReLU(d) | |||
| * result: d -> (c) | |||
| * | |||
| * Case 3: | |||
| * (f, g, h, i, j) = BatchNorm(a, b, c, d, e) | |||
| * k = TupleGetItem((f, g, h, i, j), 0) | |||
| * l = ReLU(k) | |||
| * result: l -> (f) | |||
| * | |||
| * Case 4: | |||
| * c = Conv2D(a, b) | |||
| * e = MakeTuple(c, d) | |||
| * f = TupleGetItem(e, 0) | |||
| * g = ReLU(k) | |||
| * result: g -> (c) | |||
| * | |||
| * Case 5: | |||
| * b = MakeTuple(a1, a2, a3) | |||
| * c = MakeTuple(b, a4) | |||
| * d = return(c) | |||
| * result d -> (a1, a2, a3, a4) | |||
| */ | |||
| void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> *inputs) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(node); | |||
| for (size_t input_index = 0; input_index < input_num; ++input_index) { | |||
| auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), input_index); | |||
| GetRealOutputRecursively(input_node, 0, inputs); | |||
| } | |||
| } | |||
| } // namespace | |||
| bool TrtConverterContext::Init() { | |||
| auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance(); | |||
| builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance()); | |||
| MS_EXCEPTION_IF_NULL(builder_); | |||
| auto batch_type = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); | |||
| network_ = TrtPtr(builder_->createNetworkV2(batch_type)); | |||
| MS_EXCEPTION_IF_NULL(network_); | |||
| config_ = TrtPtr(builder_->createBuilderConfig()); | |||
| MS_EXCEPTION_IF_NULL(config_); | |||
| return true; | |||
| } | |||
| bool TrtConverterContext::Parser() { | |||
| InitInputTable(); | |||
| InitValueNodeTable(); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph_->get_return()); | |||
| const auto &converter_factory = TrtOpFactory::GetInstance(); | |||
| for (auto node : node_list) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| // Mark graph outputs | |||
| std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name(); | |||
| if (op_name == kReturnOpName) { | |||
| std::vector<LayerInput> inputs; | |||
| (void)LoadLayerInput(node, &inputs); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| const auto &input = inputs[i].tensor(); | |||
| std::string name = "return_output_" + std::to_string(i); | |||
| input->setName(name.c_str()); | |||
| network_->markOutput(*input); | |||
| } | |||
| return true; | |||
| } | |||
| // Transform AnfNode To Trt layer. | |||
| // Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple. | |||
| if (!AnfAlgo::IsRealKernel(node)) { | |||
| continue; | |||
| } | |||
| ConvertFunc convert_func = converter_factory.GetConvertFunc(op_name); | |||
| auto result = convert_func(node, this->shared_from_this()); | |||
| if (!result.first) { | |||
| MS_LOG(ERROR) << op_name << " converter failed."; | |||
| return false; | |||
| } | |||
| auto ret = StoreLayerOutput(node, result.second); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << op_name << " converter failed."; | |||
| return false; | |||
| } | |||
| } | |||
| MS_LOG(ERROR) << "Graph ended without return node."; | |||
| return false; | |||
| } | |||
| bool TrtConverterContext::Serialize(std::string *model) { | |||
| MS_EXCEPTION_IF_NULL(model); | |||
| builder_->setMaxBatchSize(batch_size_); | |||
| config_->setMaxWorkspaceSize(workspace_size_); | |||
| engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_)); | |||
| MS_EXCEPTION_IF_NULL(engine_); | |||
| std::shared_ptr<nvinfer1::IHostMemory> model_data = TrtPtr(engine_->serialize()); | |||
| *model = string(static_cast<const char *>(model_data->data()), model_data->size()); | |||
| return true; | |||
| } | |||
| bool TrtConverterContext::InitInputTable() { | |||
| const std::vector<AnfNodePtr> graph_inputs = func_graph_->parameters(); | |||
| for (auto input_node : graph_inputs) { | |||
| if (!input_node->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| auto input = input_node->cast<ParameterPtr>(); | |||
| if (AnfAlgo::IsParameterWeight(input)) { | |||
| const auto ¶m_value = input->default_param(); | |||
| MS_EXCEPTION_IF_NULL(param_value); | |||
| auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| nvinfer1::Weights weight; | |||
| weight.values = tensor->data_c(); | |||
| weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type()); | |||
| weight.count = tensor->DataSize(); | |||
| output_map_[input_node][0] = LayerInput(weight); | |||
| } else { | |||
| nvinfer1::DataType trt_dtype = TrtUtils::MsDtypeToTrtDtype(AnfAlgo::GetOutputInferDataType(input_node, 0)); | |||
| nvinfer1::Dims trt_dims = TrtUtils::MsDimsToTrtDims(AnfAlgo::GetOutputInferShape(input_node, 0), false); | |||
| nvinfer1::ITensor *tensor = network_->addInput(input->name().c_str(), trt_dtype, trt_dims); | |||
| output_map_[input_node][0] = LayerInput(tensor); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool TrtConverterContext::InitValueNodeTable() { | |||
| auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(func_graph_); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| for (auto &value_node : kernel_graph->graph_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto &node_value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(node_value); | |||
| if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) { | |||
| std::vector<tensor::TensorPtr> tensors; | |||
| TensorValueToTensor(node_value, &tensors); | |||
| for (size_t i = 0; i < tensors.size(); i++) { | |||
| const auto &tensor = tensors[i]; | |||
| nvinfer1::Weights weight; | |||
| weight.values = tensor->data_c(); | |||
| weight.type = TrtUtils::MsDtypeToTrtDtype(tensor->data_type()); | |||
| weight.count = tensor->DataSize(); | |||
| output_map_[value_node][i] = LayerInput(weight); | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool TrtConverterContext::StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &nv_tensors) { | |||
| if (nv_tensors.size() != AnfAlgo::GetOutputTensorNum(node)) { | |||
| MS_LOG(INFO) << node->DebugString() << " output num not match. expect: " << AnfAlgo::GetOutputTensorNum(node) | |||
| << ", while got: " << nv_tensors.size(); | |||
| } | |||
| for (size_t tensor_index = 0; tensor_index < nv_tensors.size(); ++tensor_index) { | |||
| if (nv_tensors[tensor_index].tensor() != nullptr) { | |||
| output_map_[node][tensor_index] = nv_tensors[tensor_index]; | |||
| std::ostringstream oss; | |||
| nvinfer1::Dims dim = nv_tensors[tensor_index].tensor()->getDimensions(); | |||
| oss << node->fullname_with_scope() << ", output: " << tensor_index << ": [ "; | |||
| for (int32_t dim_index = 0; dim_index < dim.nbDims; dim_index++) { | |||
| oss << dim.d[dim_index] << " "; | |||
| } | |||
| oss << "]"; | |||
| MS_LOG(INFO) << oss.str(); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool TrtConverterContext::LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs) { | |||
| std::vector<session::KernelWithIndex> real_inputs; | |||
| GetRealInputs(node, &real_inputs); | |||
| for (auto item : real_inputs) { | |||
| auto node_iter = output_map_.find(item.first); | |||
| if (node_iter == output_map_.end()) { | |||
| MS_LOG(ERROR) << "node: " << node->DebugString() << " not found."; | |||
| return false; | |||
| } | |||
| auto out_iter = node_iter->second.find(item.second); | |||
| if (out_iter == node_iter->second.end()) { | |||
| MS_LOG(ERROR) << "node: " << node->DebugString() << "output index: " << item.second << " not found."; | |||
| return false; | |||
| } | |||
| inputs->push_back(out_iter->second); | |||
| } | |||
| return true; | |||
| } | |||
| std::vector<AnfNodePtr> TrtConverterContext::GetGraphInputs() { | |||
| // Get Anf-graph inputs without weights. All weights were binded to Trt-graph. | |||
| std::unordered_map<std::string, AnfNodePtr> graph_inputs; | |||
| for (const auto &input_node : func_graph_->parameters()) { | |||
| if (!input_node->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| auto input = input_node->cast<ParameterPtr>(); | |||
| if (!AnfAlgo::IsParameterWeight(input)) { | |||
| graph_inputs.insert(std::make_pair(input->name(), input_node)); | |||
| } | |||
| } | |||
| // Keep the graph inputs in order of the binding name. | |||
| std::vector<AnfNodePtr> trt_inputs; | |||
| for (int32_t i = 0; i < engine_->getNbBindings(); ++i) { | |||
| if (!engine_->bindingIsInput(i)) { | |||
| continue; | |||
| } | |||
| auto iter = graph_inputs.find(engine_->getBindingName(i)); | |||
| if (iter == graph_inputs.end()) { | |||
| MS_LOG(EXCEPTION) << "Get graph inputs failed. input name" << engine_->getBindingName(i); | |||
| } | |||
| trt_inputs.push_back(iter->second); | |||
| } | |||
| return trt_inputs; | |||
| } | |||
| std::vector<session::KernelWithIndex> TrtConverterContext::GetGraphOutputs() { | |||
| std::vector<session::KernelWithIndex> graph_outputs; | |||
| GetRealInputs(func_graph_->get_return(), &graph_outputs); | |||
| return graph_outputs; | |||
| } | |||
| std::shared_ptr<tensor::Tensor> TrtConverterContext::CreateTempWeight(const TypeId &type, | |||
| const std::vector<size_t> &shape) { | |||
| ShapeVector shape_int; | |||
| std::transform(shape.begin(), shape.end(), std::back_inserter(shape_int), SizeToLong); | |||
| auto tensor = std::make_shared<tensor::Tensor>(type, shape_int); | |||
| temp_weights_.push_back(tensor); | |||
| return tensor; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <NvInfer.h> | |||
| #include "base/base.h" | |||
| #include "ir/anf.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/optimizer/trt_pass/layer_input.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TrtConverterContext : public std::enable_shared_from_this<TrtConverterContext> { | |||
| public: | |||
| explicit TrtConverterContext(FuncGraphPtr fg) | |||
| : func_graph_(fg), | |||
| batch_size_(1), | |||
| workspace_size_(4UL << 30), | |||
| builder_(nullptr), | |||
| network_(nullptr), | |||
| config_(nullptr), | |||
| engine_(nullptr) {} | |||
| ~TrtConverterContext() = default; | |||
| bool Init(); | |||
| // Parser KernelGraph to trt graph | |||
| bool Parser(); | |||
| // Serialize trt models. | |||
| bool Serialize(std::string *model); | |||
| // Get trt graph inputs without weights. The inputs keep same order as binding name. | |||
| std::vector<AnfNodePtr> GetGraphInputs(); | |||
| // Get trt graph outputs. All outputs are flatten to vector with concret shape. | |||
| std::vector<session::KernelWithIndex> GetGraphOutputs(); | |||
| // Store trt layer outputs to the cache. | |||
| bool StoreLayerOutput(const AnfNodePtr &node, const std::vector<LayerInput> &inputs); | |||
| // Get trt layer inputs from the cache. | |||
| bool LoadLayerInput(const AnfNodePtr &node, std::vector<LayerInput> *inputs); | |||
| // Create and keep temporary weight, as constant folding demanding new weight excluded in graph, | |||
| // which should release until building finish. | |||
| std::shared_ptr<tensor::Tensor> CreateTempWeight(const TypeId &type, const std::vector<size_t> &shape); | |||
| std::shared_ptr<nvinfer1::INetworkDefinition> network() const { return network_; } | |||
| private: | |||
| bool InitInputTable(); | |||
| bool InitValueNodeTable(); | |||
| FuncGraphPtr func_graph_; | |||
| uint32_t batch_size_; | |||
| size_t workspace_size_; | |||
| std::shared_ptr<nvinfer1::IBuilder> builder_; | |||
| std::shared_ptr<nvinfer1::INetworkDefinition> network_; | |||
| std::shared_ptr<nvinfer1::IBuilderConfig> config_; | |||
| std::shared_ptr<nvinfer1::ICudaEngine> engine_; | |||
| // Cache (AnfNode + output_index : ILayer output). | |||
| std::unordered_map<AnfNodePtr, std::unordered_map<size_t, LayerInput>> output_map_; | |||
| std::vector<std::shared_ptr<tensor::Tensor>> temp_weights_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_CONVERTER_CONTEXT_H_ | |||
| @@ -29,9 +29,9 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class LayerInput; | |||
| class TrtConverterHelper; | |||
| class TrtConverterContext; | |||
| using ConvertResult = std::pair<bool, std::vector<LayerInput>>; | |||
| using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterHelper>)>; | |||
| using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterContext>)>; | |||
| class TrtOpFactory { | |||
| public: | |||
| @@ -69,10 +69,10 @@ class TrtOpRegister { | |||
| }; | |||
| // Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition. | |||
| #define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \ | |||
| ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context); \ | |||
| static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \ | |||
| ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context) | |||
| #define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \ | |||
| ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context); \ | |||
| static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \ | |||
| ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context) | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_ | |||