@@ -104,6 +104,7 @@ class MS_API Buffer {
 extern MS_API const char *kDeviceTypeAscend310;
 extern MS_API const char *kDeviceTypeAscend910;
+extern MS_API const char *kDeviceTypeGpu;

 constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
 constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
@@ -13,7 +13,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
 endif()
 if(ENABLE_GPU)
-    file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc")
+    file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc" "gpu_inference_session.cc")
     list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST})
 endif()
@@ -0,0 +1,217 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <algorithm>
+#include "backend/session/gpu_inference_session.h"
+#include "ir/tensor.h"
+#include "ir/anf.h"
+#include "ir/param_info.h"
+#include "runtime/device/kernel_runtime.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/ms_utils.h"
+#include "common/trans.h"
+#include "utils/config_manager.h"
+
+namespace mindspore {
+namespace session {
+void GpuInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                                        const std::vector<tensor::TensorPtr> &inputs_const) const {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  std::vector<tensor::TensorPtr> inputs(inputs_const);
+  auto input_nodes = kernel_graph->inputs();
+  size_t no_weight_input = 0;
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    tensor::TensorPtr tensor = nullptr;
+    if (!input_nodes[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have an AnfNode which is not a Parameter";
+      continue;
+    }
+    auto pk_node = input_nodes[i]->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(pk_node);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
+    MS_EXCEPTION_IF_NULL(device_address);
+    if (!AnfAlgo::IsParameterWeight(pk_node)) {
+      tensor = inputs[no_weight_input++];
+      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                            tensor->data_c())) {
+        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
+      }
+    }
+  }
+}
+
+GraphId GpuInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
+  auto graph_id = GPUSession::CompileGraphImpl(func_graph);
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  // load weight data to device
+  auto input_nodes = kernel_graph->inputs();
+  for (size_t i = 0; i < input_nodes.size(); ++i) {
+    if (!input_nodes[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have an AnfNode which is not a Parameter";
+      continue;
+    }
+    auto pk_node = input_nodes[i]->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(pk_node);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
+    MS_EXCEPTION_IF_NULL(device_address);
+    if (AnfAlgo::IsParameterWeight(pk_node)) {
+      const auto &param_value = pk_node->default_param();
+      MS_EXCEPTION_IF_NULL(param_value);
+      auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
+      MS_EXCEPTION_IF_NULL(tensor);
+      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                            tensor->data_c())) {
+        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
+      }
+    }
+  }
+  return graph_id;
+}
+
+bool GpuInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                                           std::string *error_msg) const {
+  MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id;
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto kernel_graph_inputs = kernel_graph->inputs();
+  size_t no_weight_input = 0;
+  vector<ParameterPtr> paras;
+  // find parameters of graph inputs
+  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
+    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
+      MS_LOG(ERROR) << "Kernel graph inputs have an AnfNode which is not a Parameter.";
+      continue;
+    }
+    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
+    if (!AnfAlgo::IsParameterWeight(parameter)) {
+      paras.push_back(parameter);
+    }
+  }
+  // check inputs
+  for (size_t i = 0; i < paras.size(); ++i) {
+    // compare input number
+    if (paras.size() != inputs.size()) {
+      MS_LOG(ERROR) << "Input number is inconsistent. The actual input number is [" << inputs.size()
+                    << "] but the graph input number is [" << paras.size() << "]";
+      MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
+      if (error_msg != nullptr) {
+        std::stringstream str_stream;
+        str_stream << "Input number is inconsistent. The given input number is [" << inputs.size()
+                   << "] but the graph input number is [" << paras.size() << "]\n";
+        str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
+        *error_msg = str_stream.str();
+      }
+      return false;
+    }
+    auto input = inputs[no_weight_input++];
+    if (!CompareInput(input, paras[i])) {
+      MS_LOG(ERROR) << "Please check the input information.";
+      MS_LOG(ERROR) << "InputsInfo --" << InputsInfo(paras, inputs);
+      if (error_msg != nullptr) {
+        std::stringstream str_stream;
+        str_stream << "Please check the input information.\n";
+        str_stream << "InputsInfo --" << InputsInfo(paras, inputs);
+        *error_msg = str_stream.str();
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+bool GpuInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const {
+  MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(parameter);
+  // compare dims
+  auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
+  // compare shape
+  auto input_shape = input->shape();
+  vector<size_t> trans_input;
+  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
+                       [](const int64_t dim) { return static_cast<size_t>(dim); });
+  auto is_scalar_shape = [](const vector<size_t> &shape) {
+    return shape.empty() || (shape.size() == 1 && shape[0] == 1);
+  };
+  if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) {
+    MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
+                  << ", but the parameter shape is " << PrintInputShape(parameter_shape)
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+  // compare data type
+  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
+  if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
+    MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
+                  << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+  return true;
+}
+
+template <typename T>
+std::string GpuInferenceSession::PrintInputShape(std::vector<T> shape) const {
+  string res = "[";
+  for (auto dim : shape) {
+    res += " " + std::to_string(dim);
+  }
+  return res + " ]";
+}
+
+std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &paras,
+                                            const std::vector<tensor::TensorPtr> &inputs) const {
+  const std::map<TypeId, std::string> dtype_name_map{
+    {TypeId::kNumberTypeBegin, "Unknown"},   {TypeId::kNumberTypeBool, "Bool"},
+    {TypeId::kNumberTypeFloat64, "Float64"}, {TypeId::kNumberTypeInt8, "Int8"},
+    {TypeId::kNumberTypeUInt8, "Uint8"},     {TypeId::kNumberTypeInt16, "Int16"},
+    {TypeId::kNumberTypeUInt16, "Uint16"},   {TypeId::kNumberTypeInt32, "Int32"},
+    {TypeId::kNumberTypeUInt32, "Uint32"},   {TypeId::kNumberTypeInt64, "Int64"},
+    {TypeId::kNumberTypeUInt64, "Uint64"},   {TypeId::kNumberTypeFloat16, "Float16"},
+    {TypeId::kNumberTypeFloat32, "Float32"},
+  };
+  auto data_type_to_string = [&dtype_name_map](TypeId type_id) {
+    auto it = dtype_name_map.find(type_id);
+    if (it == dtype_name_map.end()) {
+      return std::string("Unknown");
+    }
+    return it->second;
+  };
+
+  std::string graph = "graph inputs:{ ";
+  for (size_t i = 0; i < paras.size(); ++i) {
+    auto &para = paras[i];
+    graph += std::to_string(i) + ": dims " + std::to_string(AnfAlgo::GetOutputDeviceShape(para, 0).size()) +
+             ", shape " + PrintInputShape(AnfAlgo::GetOutputDeviceShape(para, 0)) + ", data type " +
+             data_type_to_string(AnfAlgo::GetSelectKernelBuildInfo(para)->GetOutputDeviceType(0)) + " }";
+  }
+  std::string actual = "given inputs:{ ";
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    actual += std::to_string(i) + ": dims " + std::to_string(inputs[i]->shape().size()) + ", shape " +
+              PrintInputShape(inputs[i]->shape()) + ", data type " + data_type_to_string(inputs[i]->data_type()) + " }";
+  }
+  return graph + " " + actual;
+}
+}  // namespace session
+}  // namespace mindspore
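The two loops in gpu_inference_session.cc split graph parameters on `AnfAlgo::IsParameterWeight`: `CompileGraphImpl` pushes weight tensors to the device once at compile time, while `LoadInputData` consumes the caller's input vector only for non-weight parameters, advancing a separate `no_weight_input` index. A minimal standalone sketch of that split (plain C++; `Param` and `sync_to_device` are illustrative stand-ins, not MindSpore's types):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph parameter; MindSpore's real
// ParameterPtr carries a default_param() only for weights.
struct Param {
  std::string name;
  bool is_weight;
};

// Stand-in for DeviceAddress::SyncHostToDevice.
void sync_to_device(const Param &p, const std::string &source) {
  std::cout << "sync " << p.name << " from " << source << "\n";
}

// Compile time: weights are copied once from their default values.
void load_weights(const std::vector<Param> &params) {
  for (const auto &p : params) {
    if (p.is_weight) sync_to_device(p, "default_param");
  }
}

// Run time: only non-weight inputs consume entries of the user input
// vector, advancing a separate index exactly like no_weight_input above.
void load_inputs(const std::vector<Param> &params, const std::vector<std::string> &inputs) {
  size_t no_weight_input = 0;
  for (const auto &p : params) {
    if (!p.is_weight) sync_to_device(p, inputs.at(no_weight_input++));
  }
}

int main() {
  std::vector<Param> params{{"x", false}, {"w", true}, {"y", false}};
  load_weights(params);                   // once, at compile time
  load_inputs(params, {"input0", "input1"});  // every run
  return 0;
}
```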
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
+#define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
+#include <unordered_map>
+#include <string>
+#include <memory>
+#include <vector>
+#include <utility>
+#include <stack>
+#include <map>
+#include <tuple>
+#include <set>
+#include "backend/session/gpu_session.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/session_factory.h"
+
+namespace mindspore {
+namespace session {
+class GpuInferenceSession : public gpu::GPUSession {
+ public:
+  GpuInferenceSession() = default;
+  ~GpuInferenceSession() = default;
+  void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                     const std::vector<tensor::TensorPtr> &inputs_const) const;
+  bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                        std::string *error_msg) const override;
+  bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const;
+  template <typename T>
+  std::string PrintInputShape(std::vector<T> shape) const;
+  std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const;
+
+ protected:
+  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
+};
+MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession);
+}  // namespace session
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H
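`MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession)` binds the class to a device-name key; gpu_graph_impl.cc below then creates it with `session::SessionFactory::Get().Create(kGpuInferenceDevice)`. A self-contained sketch of that self-registration pattern — the `SessionFactory`/`REG_SESSION` names here are illustrative stand-ins, not MindSpore's actual factory:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Session {
  virtual ~Session() = default;
  virtual void Init(uint32_t device_id) = 0;
};

// Name-keyed factory: creators are registered before main() runs.
class SessionFactory {
 public:
  static SessionFactory &Get() {
    static SessionFactory f;
    return f;
  }
  void Register(const std::string &name, std::function<std::shared_ptr<Session>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<Session> Create(const std::string &name) {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<Session>()>> creators_;
};

// What a MS_REG_SESSION-style macro expands to: a static registrar whose
// constructor adds the creator during static initialization.
#define REG_SESSION(NAME, CLASS)                                                        \
  static const struct CLASS##Registrar {                                                \
    CLASS##Registrar() {                                                                \
      SessionFactory::Get().Register(NAME, [] { return std::make_shared<CLASS>(); });   \
    }                                                                                   \
  } g_##CLASS##_registrar;

struct DemoSession : Session {
  void Init(uint32_t device_id) override { std::cout << "init device " << device_id << "\n"; }
};
REG_SESSION("GPUInference", DemoSession)

int main() {
  auto session = SessionFactory::Get().Create("GPUInference");
  if (session != nullptr) session->Init(0);
  return 0;
}
```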
@@ -15,9 +15,11 @@
  */
 #include "backend/session/gpu_session.h"
+#include <string>
 #include "backend/optimizer/common/helper.h"
 #include "backend/optimizer/common/optimizer.h"
 #include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/common/common_backend_optimization.h"
 #include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
 #include "backend/optimizer/gpu/adam_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
@@ -298,16 +300,31 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
 GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
   // Construct graph; if successful, graph_sum_ + 1
-  auto graph_id = graph_sum_;
   auto graph = ConstructKernelGraph(lst, outputs);
   MS_EXCEPTION_IF_NULL(graph);
+  return CompileGraphImpl(graph);
+}
+
+GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
+  std::vector<KernelGraphPtr> all_graphs;
+  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
+  MS_EXCEPTION_IF_NULL(root_graph);
+  if (all_graphs.size() != 1) {
+    MS_LOG(EXCEPTION) << "GPU backend does not support multi-graph schedule, graph num: " << all_graphs.size();
+  }
+  opt::BackendCommonOptimization(root_graph);
+  return CompileGraphImpl(root_graph);
+}
+
+GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) {
   // Prepare ms context info for dump .pb graph
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
   // Dump .pb graph before graph optimization
   if (save_graphs) {
-    DumpIRProto(graph, "before_opt_" + std::to_string(graph_id));
+    DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
   }
   // Graph optimization irrelevant to device data format
   Optimize(graph);
@@ -326,7 +343,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
   AssignStream(graph);
   // Dump .pb graph before remove nop nodes
   if (save_graphs) {
-    DumpIRProto(graph, "before_removeNop_" + std::to_string(graph_id));
+    DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id()));
   }
   // Update Graph Dynamic Shape Attr.
   UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
@@ -343,7 +360,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
   SetSummaryNodes(graph.get());
   // Dump .pb graph after graph optimization
   if (save_graphs) {
-    DumpIRProto(graph, "after_opt_" + std::to_string(graph_id));
+    DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
   }
   // Set graph manager.
   MS_EXCEPTION_IF_NULL(context_);
@@ -361,9 +378,8 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
     debugger_->LoadGraphs(graph);
   }
 #endif
-  MS_LOG(INFO) << "CompileGraph graph_id: " << graph_id;
-  return graph_id;
+  MS_LOG(INFO) << "CompileGraph graph_id: " << graph->graph_id();
+  return graph->graph_id();
 }

 void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
@@ -37,6 +37,7 @@ class GPUSession : public SessionBasic {
  protected:
   void UnifyMindIR(const KernelGraphPtr &graph) override { return; }
   GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
+  GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                    const std::vector<tensor::TensorPtr> &input_tensors,

@@ -81,6 +82,8 @@ class GPUSession : public SessionBasic {
   void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+
+  GraphId CompileGraphImpl(KernelGraphPtr kernel_graph);
 };
 using GPUSessionPtr = std::shared_ptr<GPUSession>;
 MS_REG_SESSION(kGPUDevice, GPUSession);
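The net effect of this header change: GPUSession now exposes three `CompileGraphImpl` entry points — the original node-list overload, a new `NotNull<FuncGraphPtr>` override used by the inference session, and a private `KernelGraphPtr` overload both funnel into. A toy sketch of that funnel shape (placeholder types, not the real `AnfNodePtrList`/`FuncGraphPtr`):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct KernelGraph {
  uint32_t graph_id = 0;
};
using KernelGraphPtr = std::shared_ptr<KernelGraph>;

class Session {
 public:
  // Entry point 1: build a kernel graph from a node list.
  uint32_t CompileGraph(const std::vector<int> &nodes) {
    auto graph = std::make_shared<KernelGraph>();
    graph->graph_id = 1;
    return CompileGraphImpl(graph);
  }
  // Entry point 2: build a kernel graph from a whole func graph.
  uint32_t CompileGraph(const std::string &func_graph) {
    auto graph = std::make_shared<KernelGraph>();
    graph->graph_id = 2;
    return CompileGraphImpl(graph);  // same shared backend path
  }

 private:
  // Single place for optimization, kernel selection, dumping, etc.
  uint32_t CompileGraphImpl(const KernelGraphPtr &graph) {
    std::cout << "compiling graph " << graph->graph_id << "\n";
    return graph->graph_id;
  }
};

int main() {
  Session s;
  s.CompileGraph(std::vector<int>{1, 2, 3});
  s.CompileGraph("func_graph");
  return 0;
}
```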
@@ -15,10 +15,15 @@ if(ENABLE_ACL)
         "model/model_converter_utils/*.cc"
         "graph/acl/*.cc"
         )
 endif()
 if(ENABLE_D)
-    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/ms/*.cc")
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}
+        "python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc")
+endif()
+if(ENABLE_GPU)
+    file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc")
 endif()

 set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
@@ -98,6 +103,15 @@ if(ENABLE_D)
     target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
 endif()
+
+if(ENABLE_GPU)
+    target_link_libraries(mindspore_shared_lib PRIVATE gpu_cuda_lib gpu_queue cublas
+                          ${CUDA_PATH}/lib64/libcurand.so
+                          ${CUDNN_LIBRARY_PATH}
+                          ${CUDA_PATH}/lib64/libcudart.so
+                          ${CUDA_PATH}/lib64/stubs/libcuda.so
+                          ${CUDA_PATH}/lib64/libcusolver.so)
+endif()
+
 if(CMAKE_SYSTEM_NAME MATCHES "Linux")
     set(MINDSPORE_RPATH $ORIGIN)
     if(ENABLE_D)
@@ -110,7 +124,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
-        set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
+        set(MINDSPORE_RPATH
+            ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
     elseif(ENABLE_ACL)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/atc/lib64)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/atc/lib64)

@@ -121,7 +136,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux")
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
         set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
-        set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
+        set(MINDSPORE_RPATH
+            ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling)
     endif()
     set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "cxx_api/graph/ms/ms_graph_impl.h"
+#include "cxx_api/graph/ascend/ascend_graph_impl.h"
 #include <algorithm>
 #include "include/api/context.h"
 #include "cxx_api/factory.h"
@@ -26,43 +26,9 @@
 #include "runtime/device/kernel_runtime_manager.h"

 namespace mindspore::api {
-API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl);
-
-static DataType TransTypeId2InferDataType(TypeId type_id) {
-  const std::map<TypeId, api::DataType> id2type_map{
-    {TypeId::kNumberTypeBegin, api::kMsUnknown},   {TypeId::kNumberTypeBool, api::kMsBool},
-    {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
-    {TypeId::kNumberTypeUInt8, api::kMsUint8},     {TypeId::kNumberTypeInt16, api::kMsInt16},
-    {TypeId::kNumberTypeUInt16, api::kMsUint16},   {TypeId::kNumberTypeInt32, api::kMsInt32},
-    {TypeId::kNumberTypeUInt32, api::kMsUint32},   {TypeId::kNumberTypeInt64, api::kMsInt64},
-    {TypeId::kNumberTypeUInt64, api::kMsUint64},   {TypeId::kNumberTypeFloat16, api::kMsFloat16},
-    {TypeId::kNumberTypeFloat32, api::kMsFloat32},
-  };
-  // cppcheck-suppress stlIfFind
-  if (auto it = id2type_map.find(type_id); it != id2type_map.end()) {
-    return it->second;
-  }
-  MS_LOG(WARNING) << "Unsupported data id " << type_id;
-  return api::kMsUnknown;
-}
-
-template <class T>
-inline static void ClearIfNotNull(T *vec) {
-  if (vec != nullptr) {
-    vec->clear();
-  }
-}
-
-template <class T, class U = std::vector<T>>
-inline static void PushbackIfNotNull(U *vec, T &&item) {
-  if (vec != nullptr) {
-    vec->emplace_back(item);
-  }
-}
+API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);

-MsGraphImpl::MsGraphImpl()
+AscendGraphImpl::AscendGraphImpl()
   : session_impl_(nullptr),
     graph_id_(0),
     device_type_("Ascend"),

@@ -75,9 +41,9 @@ MsGraphImpl::MsGraphImpl()
     init_flag_(false),
     load_flag_(false) {}

-MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); }
+AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); }

-Status MsGraphImpl::InitEnv() {
+Status AscendGraphImpl::InitEnv() {
   if (init_flag_) {
     return SUCCESS;
   }
@@ -108,7 +74,7 @@ Status MsGraphImpl::InitEnv() {
   return SUCCESS;
 }

-Status MsGraphImpl::FinalizeEnv() {
+Status AscendGraphImpl::FinalizeEnv() {
   if (!init_flag_) {
     return SUCCESS;
   }

@@ -136,7 +102,7 @@ Status MsGraphImpl::FinalizeEnv() {
   return SUCCESS;
 }

-Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
+Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
   MS_ASSERT(session_impl_ != nullptr);
   try {
     graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));

@@ -147,7 +113,7 @@ Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr)
   }
 }

-std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
+std::vector<tensor::TensorPtr> AscendGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
   try {
     VectorRef outputs;
     session_impl_->RunGraph(graph_id_, inputs, &outputs);

@@ -158,7 +124,7 @@ std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::T
   }
 }

-Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
+Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const {
   MS_ASSERT(session_impl_ != nullptr);
   std::string error_msg;
   if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) {

@@ -167,7 +133,7 @@ Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &input
   return SUCCESS;
 }

-Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
+Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
   MS_EXCEPTION_IF_NULL(reply);
   if (context_ == nullptr) {
     MS_LOG(ERROR) << "rtCtx is nullptr";

@@ -206,8 +172,8 @@ Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector
   return SUCCESS;
 }

-Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                  std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                      std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
   if (!load_flag_) {
     Status ret = Load();
     if (ret != SUCCESS) {

@@ -216,22 +182,22 @@ Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<s
     }
   }

-  ClearIfNotNull(names);
-  ClearIfNotNull(shapes);
-  ClearIfNotNull(data_types);
-  ClearIfNotNull(mem_sizes);
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
   for (size_t i = 0; i < inputs_.size(); i++) {
     auto &tensor = inputs_[i];
-    PushbackIfNotNull(names, input_names_[i]);
-    PushbackIfNotNull(shapes, tensor->shape());
-    PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
-    PushbackIfNotNull(mem_sizes, tensor->Size());
+    GraphUtils::PushbackIfNotNull(names, input_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
   }
   return SUCCESS;
 }

-Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
-                                   std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
   if (!load_flag_) {
     Status ret = Load();
     if (ret != SUCCESS) {

@@ -240,22 +206,22 @@ Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<
     }
   }

-  ClearIfNotNull(names);
-  ClearIfNotNull(shapes);
-  ClearIfNotNull(data_types);
-  ClearIfNotNull(mem_sizes);
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
   for (size_t i = 0; i < outputs_.size(); i++) {
     auto &tensor = outputs_[i];
-    PushbackIfNotNull(names, output_names_[i]);
-    PushbackIfNotNull(shapes, tensor->shape());
-    PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type()));
-    PushbackIfNotNull(mem_sizes, tensor->Size());
+    GraphUtils::PushbackIfNotNull(names, output_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
   }
   return SUCCESS;
 }

-Status MsGraphImpl::Load() {
+Status AscendGraphImpl::Load() {
   // check graph type
   if (graph_->ModelType() != ModelType::kMindIR) {
     MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();

@@ -311,7 +277,7 @@ Status MsGraphImpl::Load() {
   return SUCCESS;
 }

-Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
   MS_EXCEPTION_IF_NULL(outputs);
   if (!load_flag_) {
     Status ret = Load();
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
-#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
+#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
+#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
 #include <functional>
 #include <map>
 #include <string>

@@ -28,12 +28,13 @@
 #include "ir/anf.h"
 #include "cxx_api/model/model_impl.h"
 #include "runtime/context.h"
+#include "cxx_api/graph/graph_utils.h"

 namespace mindspore::api {
-class MsGraphImpl : public GraphCell::GraphImpl {
+class AscendGraphImpl : public GraphCell::GraphImpl {
  public:
-  MsGraphImpl();
-  ~MsGraphImpl() override;
+  AscendGraphImpl();
+  ~AscendGraphImpl() override;

   Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
   Status Load() override;

@@ -63,4 +64,4 @@ class MsGraphImpl : public GraphCell::GraphImpl {
   bool load_flag_;
 };
 }  // namespace mindspore::api
-#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
+#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H
@@ -20,10 +20,10 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() {
   return instance;
 }

-bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
-                                 Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
-                                 Status (*pRegProfReporterCallback)(MsprofReporterCallback),
-                                 Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
+bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback),
+                                  Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback),
+                                  Status (*pRegProfReporterCallback)(MsprofReporterCallback),
+                                  Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) {
   return false;
 }
@@ -0,0 +1,256 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "cxx_api/graph/gpu/gpu_graph_impl.h"
+#include <algorithm>
+#include "include/api/context.h"
+#include "cxx_api/factory.h"
+#include "utils/log_adapter.h"
+#include "mindspore/core/base/base_ref_utils.h"
+#include "backend/session/session_factory.h"
+#include "backend/session/executor_manager.h"
+#include "runtime/device/kernel_runtime_manager.h"
+
+namespace mindspore::api {
+API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
+
+GPUGraphImpl::GPUGraphImpl()
+    : session_impl_(nullptr),
+      graph_id_(0),
+      device_id_(Context::Instance().GetDeviceID()),
+      inputs_(),
+      outputs_(),
+      input_names_(),
+      output_names_(),
+      init_flag_(false),
+      load_flag_(false) {}
+Status GPUGraphImpl::InitEnv() {
+  if (init_flag_) {
+    MS_LOG(WARNING) << "Initialized again, return success.";
+    return SUCCESS;
+  }
+  auto ms_context = MsContext::GetInstance();
+  if (ms_context == nullptr) {
+    MS_LOG(ERROR) << "Get Context failed!";
+    return FAILED;
+  }
+  ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
+  ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
+  ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
+  ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
+
+  session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
+  if (session_impl_ == nullptr) {
+    MS_LOG(ERROR) << "Session create failed! Please make sure target device " << kGpuInferenceDevice
+                  << " is available.";
+    return FAILED;
+  }
+  session_impl_->Init(device_id_);
+  init_flag_ = true;
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::FinalizeEnv() {
+  if (!init_flag_) {
+    MS_LOG(WARNING) << "Never initialized before, return success.";
+    return SUCCESS;
+  }
+  MS_LOG(INFO) << "Start finalize env";
+  session::ExecutorManager::Instance().Clear();
+  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
+  init_flag_ = false;
+  MS_LOG(INFO) << "End finalize env";
+  return SUCCESS;
+}
+Status GPUGraphImpl::Load() {
+  // check graph type
+  if (graph_->ModelType() != ModelType::kMindIR) {
+    MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
+    return INVALID_INPUTS;
+  }
+
+  const auto &graph_data = GraphImpl::MutableGraphData();
+  MS_EXCEPTION_IF_NULL(graph_data);
+  auto func_graph = graph_data->GetFuncGraph();
+
+  // init
+  Status ret = InitEnv();
+  if (ret != SUCCESS) {
+    MS_LOG(ERROR) << "InitEnv failed.";
+    return FAILED;
+  }
+
+  ret = CompileGraph(func_graph);
+  if (ret != SUCCESS) {
+    MS_LOG(ERROR) << "Compile graph model failed.";
+    return FAILED;
+  }
+  session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_);
+  session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_);
+  if (inputs_.empty() || inputs_.size() != input_names_.size()) {
+    MS_LOG(ERROR) << "Get model inputs info failed";
+    return FAILED;
+  }
+  if (outputs_.empty() || outputs_.size() != output_names_.size()) {
+    MS_LOG(ERROR) << "Get model outputs info failed";
+    return FAILED;
+  }
+  load_flag_ = true;
+  return SUCCESS;
+}
+Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) {
+  MS_ASSERT(session_impl_ != nullptr);
+  try {
+    graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
+    return SUCCESS;
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "CompileGraph failed: " << e.what();
+    return FAILED;
+  }
+}
+
+std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) {
+  try {
+    VectorRef outputs;
+    session_impl_->RunGraph(graph_id_, inputs, &outputs);
+    return TransformVectorRefToMultiTensor(outputs);
+  } catch (std::exception &e) {
+    MS_LOG(ERROR) << "RunGraph failed: " << e.what();
+    return std::vector<tensor::TensorPtr>();
+  }
+}
+
+Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) {
+  MS_EXCEPTION_IF_NULL(reply);
+  vector<tensor::TensorPtr> inputs;
+  for (size_t i = 0; i < request.size(); i++) {
+    auto &item = request[i];
+    auto input = inputs_[i];
+    if (input->Size() != item.DataSize()) {
+      MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " does not match model input data size "
+                    << input->Size();
+      return FAILED;
+    }
+    auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize());
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "Tensor copy failed";
+      return FAILED;
+    }
+    inputs.push_back(input);
+  }
+  vector<tensor::TensorPtr> outputs = RunGraph(inputs);
+  if (outputs.empty()) {
+    MS_LOG(ERROR) << "Execute Model Failed";
+    return FAILED;
+  }
+  reply->clear();
+  std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply),
+                 [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); });
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
+  MS_EXCEPTION_IF_NULL(outputs);
+  if (!load_flag_) {
+    Status ret = Load();
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "PrepareModel failed.";
+      return ret;
+    }
+  }
+
+  if (inputs.size() != inputs_.size()) {
+    MS_LOG(ERROR) << "Input count does not match, required count " << inputs_.size() << ", given count "
+                  << inputs.size();
+    return INVALID_INPUTS;
+  }
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    if (inputs[i].DataSize() != inputs_[i]->Size()) {
+      MS_LOG(ERROR) << "Input " << i << " data size does not match, required size " << inputs_[i]->Size()
+                    << ", given size " << inputs[i].DataSize();
+      return INVALID_INPUTS;
+    }
+  }
+  if (ExecuteModel(inputs, outputs) != SUCCESS) {
+    MS_LOG(ERROR) << "Execute Model Failed";
+    return FAILED;
+  }
+  if (outputs_.size() != outputs->size()) {
+    MS_LOG(ERROR) << "Predict output size " << outputs->size() << " does not match output size obtained from model "
+                  << "info " << outputs_.size();
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                   std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+  if (!load_flag_) {
+    Status ret = Load();
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "PrepareModel failed.";
+      return ret;
+    }
+  }
+
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
+  for (size_t i = 0; i < inputs_.size(); i++) {
+    auto &tensor = inputs_[i];
+    GraphUtils::PushbackIfNotNull(names, input_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
+  }
+  return SUCCESS;
+}
+
+Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                                    std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) {
+  if (!load_flag_) {
+    Status ret = Load();
+    if (ret != SUCCESS) {
+      MS_LOG(ERROR) << "PrepareModel failed.";
+      return ret;
+    }
+  }
+
+  GraphUtils::ClearIfNotNull(names);
+  GraphUtils::ClearIfNotNull(shapes);
+  GraphUtils::ClearIfNotNull(data_types);
+  GraphUtils::ClearIfNotNull(mem_sizes);
+  for (size_t i = 0; i < outputs_.size(); i++) {
+    auto &tensor = outputs_[i];
+    GraphUtils::PushbackIfNotNull(names, output_names_[i]);
+    GraphUtils::PushbackIfNotNull(shapes, tensor->shape());
+    GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type()));
+    GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size());
+  }
+  return SUCCESS;
+}
+}  // namespace mindspore::api
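`ExecuteModel` above copies each request `Buffer` into the corresponding model input tensor only after an exact size check, using `memcpy_s` from the bundled securec library. A standalone sketch of the same defensive copy with plain `std::memcpy` behind an explicit size guard (illustrative; not MindSpore's actual helper):

```cpp
#include <cstring>
#include <iostream>
#include <vector>

// Size-checked copy: reject mismatched buffers before touching memory,
// mirroring the input->Size() != item.DataSize() guard in ExecuteModel.
bool CopyChecked(void *dst, size_t dst_size, const void *src, size_t src_size) {
  if (dst_size != src_size) {
    std::cerr << "data size " << src_size << " does not match input size " << dst_size << "\n";
    return false;
  }
  std::memcpy(dst, src, src_size);
  return true;
}

int main() {
  std::vector<float> model_input(4);       // stands in for the model input tensor
  std::vector<float> request{1, 2, 3, 4};  // stands in for an incoming Buffer
  bool ok = CopyChecked(model_input.data(), model_input.size() * sizeof(float),
                        request.data(), request.size() * sizeof(float));
  std::cout << (ok ? "copied" : "rejected") << "\n";
  return 0;
}
```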
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
+#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
+#include <string>
+#include <vector>
+#include <utility>
+#include <memory>
+#include "include/api/status.h"
+#include "include/api/graph.h"
+#include "cxx_api/graph/graph_impl.h"
+#include "backend/session/session_basic.h"
+#include "ir/anf.h"
+#include "cxx_api/model/model_impl.h"
+#include "cxx_api/graph/graph_utils.h"
+
+namespace mindspore::api {
+class GPUGraphImpl : public GraphCell::GraphImpl {
+ public:
+  GPUGraphImpl();
+  ~GPUGraphImpl() override = default;
+
+  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
+  Status Load() override;
+  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
+                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
+
+ private:
+  Status InitEnv();
+  Status FinalizeEnv();
+  Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
+  Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
+  std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
+  Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
+
+  std::shared_ptr<session::SessionBasic> session_impl_;
+  uint32_t graph_id_;
+  std::string device_type_;
+  uint32_t device_id_;
+  std::vector<tensor::TensorPtr> inputs_;
+  std::vector<tensor::TensorPtr> outputs_;
+  std::vector<std::string> input_names_;
+  std::vector<std::string> output_names_;
+  bool init_flag_;
+  bool load_flag_;
+
+  // tensor-rt
+  uint32_t batch_size_;
+  uint32_t workspace_size_;
+};
+}  // namespace mindspore::api
+#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H
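Every public entry point of `GPUGraphImpl` (`Run`, `GetInputsInfo`, `GetOutputsInfo`) guards on `load_flag_` and calls `Load()` on first use, so clients never need an explicit load step. The guard in isolation (a minimal sketch, not the real class):

```cpp
#include <iostream>

// Lazy-load guard: the graph is compiled on the first call into the
// object; subsequent calls skip straight to execution.
class LazyGraph {
 public:
  bool Run() {
    if (!load_flag_) {
      if (!Load()) {
        std::cerr << "PrepareModel failed.\n";
        return false;
      }
    }
    std::cout << "running\n";
    return true;
  }

 private:
  bool Load() {
    std::cout << "loading once\n";
    load_flag_ = true;
    return true;
  }
  bool load_flag_ = false;
};

int main() {
  LazyGraph g;
  g.Run();  // triggers Load()
  g.Run();  // already loaded
  return 0;
}
```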
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
+#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
+#include <map>
+#include <vector>
+#include "include/api/types.h"
+#include "ir/dtype/type_id.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore::api {
+class GraphUtils {
+ public:
+  static DataType TransTypeId2InferDataType(TypeId type_id) {
+    const std::map<TypeId, api::DataType> id2type_map{
+      {TypeId::kNumberTypeBegin, api::kMsUnknown},   {TypeId::kNumberTypeBool, api::kMsBool},
+      {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8},
+      {TypeId::kNumberTypeUInt8, api::kMsUint8},     {TypeId::kNumberTypeInt16, api::kMsInt16},
+      {TypeId::kNumberTypeUInt16, api::kMsUint16},   {TypeId::kNumberTypeInt32, api::kMsInt32},
+      {TypeId::kNumberTypeUInt32, api::kMsUint32},   {TypeId::kNumberTypeInt64, api::kMsInt64},
+      {TypeId::kNumberTypeUInt64, api::kMsUint64},   {TypeId::kNumberTypeFloat16, api::kMsFloat16},
+      {TypeId::kNumberTypeFloat32, api::kMsFloat32},
+    };
+    auto it = id2type_map.find(type_id);
+    if (it != id2type_map.end()) {
+      return it->second;
+    }
+    MS_LOG(WARNING) << "Unsupported data id " << type_id;
+    return api::kMsUnknown;
+  }
+
+  template <class T>
+  inline static void ClearIfNotNull(T *vec) {
+    if (vec != nullptr) {
+      vec->clear();
+    }
+  }
+
+  template <class T, class U>
+  inline static void PushbackIfNotNull(U *vec, T &&item) {
+    if (vec != nullptr) {
+      vec->emplace_back(item);
+    }
+  }
+};
+}  // namespace mindspore::api
+#endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H
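The `nullptr`-tolerant helpers let a caller request only the output vectors it cares about. A self-contained usage example, with the two templates lifted verbatim from the class above (the `GetInfo` caller is an illustrative stand-in for `GetInputsInfo`):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Lifted verbatim from GraphUtils: both helpers are no-ops for nullptr.
template <class T>
void ClearIfNotNull(T *vec) {
  if (vec != nullptr) {
    vec->clear();
  }
}

template <class T, class U>
void PushbackIfNotNull(U *vec, T &&item) {
  if (vec != nullptr) {
    vec->emplace_back(item);
  }
}

// A GetInputsInfo-style producer: every output is optional.
void GetInfo(std::vector<std::string> *names, std::vector<size_t> *sizes) {
  ClearIfNotNull(names);
  ClearIfNotNull(sizes);
  PushbackIfNotNull(names, std::string("input0"));
  PushbackIfNotNull(sizes, size_t{4096});
}

int main() {
  std::vector<std::string> names;
  GetInfo(&names, nullptr);  // caller only wants the names
  std::cout << names[0] << "\n";
  return 0;
}
```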
@@ -22,6 +22,7 @@
 namespace mindspore {
 namespace api {
 API_FACTORY_REG(ModelImpl, Ascend910, MsModel);
+API_FACTORY_REG(ModelImpl, GPU, MsModel);

 Status MsModel::Build(const std::map<std::string, std::string> &) {
   MS_LOG(INFO) << "Start build model.";
@@ -21,6 +21,7 @@
 namespace mindspore::api {
 const char *kDeviceTypeAscend310 = "Ascend310";
 const char *kDeviceTypeAscend910 = "Ascend910";
+const char *kDeviceTypeGpu = "GPU";

 class DataImpl {
  public:
@@ -31,7 +31,7 @@ constexpr Status PROF_FAILED = 0xFFFFFFFF;
 }  // namespace

 Status RegProfCtrlCallback(MsprofCtrlCallback func) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoRegProfCtrlCallback(func);
   } else {
     return PROF_SUCCESS;

@@ -39,7 +39,7 @@ Status RegProfCtrlCallback(MsprofCtrlCallback func) {
 }

 Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoRegProfSetDeviceCallback(func);
   } else {
     return PROF_SUCCESS;

@@ -47,7 +47,7 @@ Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) {
 }

 Status RegProfReporterCallback(MsprofReporterCallback func) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoRegProfReporterCallback(func);
   } else {
     return PROF_SUCCESS;

@@ -55,7 +55,7 @@ Status RegProfReporterCallback(MsprofReporterCallback func) {
 }

 Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) {
-  if (VMCallbackRegister::GetInstance().registed()) {
+  if (VMCallbackRegister::GetInstance().registered()) {
     return VMCallbackRegister::GetInstance().DoProfCommandHandle(type, data, len);
   } else {
     return PROF_SUCCESS;
| @@ -69,16 +69,16 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() { | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) { | |||||
| if (!registed_) { | |||||
| bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) { | |||||
| if (!registered_) { | |||||
| pRegProfCtrlCallback_ = pRegProfCtrlCallback; | pRegProfCtrlCallback_ = pRegProfCtrlCallback; | ||||
| pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback; | pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback; | ||||
| pRegProfReporterCallback_ = pRegProfReporterCallback; | pRegProfReporterCallback_ = pRegProfReporterCallback; | ||||
| pProfCommandHandle_ = pProfCommandHandle; | pProfCommandHandle_ = pProfCommandHandle; | ||||
| registed_ = true; | |||||
| registered_ = true; | |||||
| ForceMsprofilerInit(); | ForceMsprofilerInit(); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -49,12 +49,12 @@ class __attribute__((visibility("default"))) VMCallbackRegister { | |||||
| static VMCallbackRegister &GetInstance(); | static VMCallbackRegister &GetInstance(); | ||||
| VMCallbackRegister(const VMCallbackRegister &) = delete; | VMCallbackRegister(const VMCallbackRegister &) = delete; | ||||
| VMCallbackRegister &operator=(const VMCallbackRegister &) = delete; | VMCallbackRegister &operator=(const VMCallbackRegister &) = delete; | ||||
| bool Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)); | |||||
| bool Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)); | |||||
| void ForceMsprofilerInit(); | void ForceMsprofilerInit(); | ||||
| bool registed() { return registed_; } | |||||
| bool registered() { return registered_; } | |||||
| Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); } | Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); } | ||||
| Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); } | Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); } | ||||
| Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); } | Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); } | ||||
| @@ -64,7 +64,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister { | |||||
| private: | private: | ||||
| VMCallbackRegister() | VMCallbackRegister() | ||||
| : registed_(false), | |||||
| : registered_(false), | |||||
| ms_profile_inited_(false), | ms_profile_inited_(false), | ||||
| pRegProfCtrlCallback_(nullptr), | pRegProfCtrlCallback_(nullptr), | ||||
| pRegProfSetDeviceCallback_(nullptr), | pRegProfSetDeviceCallback_(nullptr), | ||||
| @@ -72,7 +72,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister { | |||||
| pProfCommandHandle_(nullptr) {} | pProfCommandHandle_(nullptr) {} | ||||
| ~VMCallbackRegister() = default; | ~VMCallbackRegister() = default; | ||||
| bool registed_; | |||||
| bool registered_; | |||||
| bool ms_profile_inited_; | bool ms_profile_inited_; | ||||
| Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback); | Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback); | ||||
| Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback); | Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback); | ||||
| @@ -299,8 +299,8 @@ Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) { | |||||
| bool DoRegiste() { | bool DoRegiste() { | ||||
| MS_LOG(INFO) << "VM profiling register start"; | MS_LOG(INFO) << "VM profiling register start"; | ||||
| return VMCallbackRegister::GetInstance().Registe(RegProfCtrlCallback, RegProfSetDeviceCallback, | |||||
| RegProfReporterCallback, ProfCommandHandle); | |||||
| return VMCallbackRegister::GetInstance().Register(RegProfCtrlCallback, RegProfSetDeviceCallback, | |||||
| RegProfReporterCallback, ProfCommandHandle); | |||||
| } | } | ||||
| static bool doRegiste = DoRegiste(); | static bool doRegiste = DoRegiste(); | ||||
| } // namespace ascend | } // namespace ascend | ||||
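Beyond the spelling fix (Registe/registed to Register/registered), these hunks show the register-once pattern the profiling glue relies on: a Meyers singleton holds the four callback setters, Register() accepts only the first caller, every public entry point checks registered() before forwarding, and the file-scope static bool doRegiste = DoRegiste(); performs registration as a side effect of loading the Ascend backend. A condensed sketch of that mechanism, reduced to a single callback and with illustrative names:

#include <cstdint>

using Status = uint32_t;
constexpr Status PROF_SUCCESS = 0;
using MsprofCtrlCallback = Status (*)(uint32_t type, void *data, uint32_t len);

class VMRegister {
 public:
  static VMRegister &GetInstance() {
    static VMRegister instance;  // Meyers singleton: constructed on first use
    return instance;
  }
  bool Register(Status (*reg_ctrl_cb)(MsprofCtrlCallback)) {
    if (registered_) return false;  // only the first registration sticks
    reg_ctrl_cb_ = reg_ctrl_cb;
    registered_ = true;
    return true;
  }
  bool registered() const { return registered_; }
  Status DoRegCtrlCallback(MsprofCtrlCallback func) { return reg_ctrl_cb_(func); }

 private:
  bool registered_ = false;
  Status (*reg_ctrl_cb_)(MsprofCtrlCallback) = nullptr;
};

// Entry point exposed to the profiler: a successful no-op until a backend registers.
Status RegProfCtrlCallback(MsprofCtrlCallback func) {
  auto &reg = VMRegister::GetInstance();
  return reg.registered() ? reg.DoRegCtrlCallback(func) : PROF_SUCCESS;
}

// In the backend's translation unit, registration runs during static init:
// static bool g_registered = VMRegister::GetInstance().Register(BackendRegProfCtrl);

Returning PROF_SUCCESS when nothing is registered keeps profiling optional: builds without the backend treat every profiling call as a successful no-op.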
| @@ -41,6 +41,7 @@ const char kCPUDevice[] = "CPU"; | |||||
| const char kGPUDevice[] = "GPU"; | const char kGPUDevice[] = "GPU"; | ||||
| const char kAscendDevice[] = "Ascend"; | const char kAscendDevice[] = "Ascend"; | ||||
| const char kDavinciInferenceDevice[] = "AscendInference"; | const char kDavinciInferenceDevice[] = "AscendInference"; | ||||
| const char kGpuInferenceDevice[] = "GpuInference"; | |||||
| const char kDavinciDevice[] = "Davinci"; | const char kDavinciDevice[] = "Davinci"; | ||||
| const char KNpuLog[] = "_npu_log"; | const char KNpuLog[] = "_npu_log"; | ||||
| const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; | const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; | ||||
| @@ -51,7 +52,7 @@ const float kDefaultMaxDeviceMemory = 1024; | |||||
| // enum definition for MindSpore Context Parameter | // enum definition for MindSpore Context Parameter | ||||
| enum MsCtxParam : unsigned { | enum MsCtxParam : unsigned { | ||||
| // paramater of type bool | |||||
| // parameter of type bool | |||||
| MS_CTX_TYPE_BOOL_BEGIN, | MS_CTX_TYPE_BOOL_BEGIN, | ||||
| MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN, | MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN, | ||||
| MS_CTX_CHECK_BPROP_FLAG, | MS_CTX_CHECK_BPROP_FLAG, | ||||
| @@ -74,14 +75,15 @@ enum MsCtxParam : unsigned { | |||||
| MS_CTX_ENABLE_PROFILING, | MS_CTX_ENABLE_PROFILING, | ||||
| MS_CTX_SAVE_GRAPHS_FLAG, | MS_CTX_SAVE_GRAPHS_FLAG, | ||||
| MS_CTX_ENABLE_PARALLEL_SPLIT, | MS_CTX_ENABLE_PARALLEL_SPLIT, | ||||
| MS_CTX_ENABLE_INFER_OPT, | |||||
| MS_CTX_TYPE_BOOL_END, | MS_CTX_TYPE_BOOL_END, | ||||
| // paramater of type int | |||||
| // parameter of type int | |||||
| MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, | MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, | ||||
| MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, | MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, | ||||
| MS_CTX_TYPE_INT_END, | MS_CTX_TYPE_INT_END, | ||||
| // paramater of type uint32 | |||||
| // parameter of type uint32 | |||||
| MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, | MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, | ||||
| MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, | MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, | ||||
| MS_CTX_GE_REF, | MS_CTX_GE_REF, | ||||
| @@ -89,12 +91,12 @@ enum MsCtxParam : unsigned { | |||||
| MS_CTX_TSD_REF, | MS_CTX_TSD_REF, | ||||
| MS_CTX_TYPE_UINT32_END, | MS_CTX_TYPE_UINT32_END, | ||||
| // paramater of type float | |||||
| // parameter of type float | |||||
| MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, | MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, | ||||
| MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, | MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, | ||||
| MS_CTX_TYPE_FLOAT_END, | MS_CTX_TYPE_FLOAT_END, | ||||
| // paramater of type string | |||||
| // parameter of type string | |||||
| MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, | MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, | ||||
| MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, | MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, | ||||
| MS_CTX_GRAPH_MEMORY_MAX_SIZE, | MS_CTX_GRAPH_MEMORY_MAX_SIZE, | ||||
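The comment fixes above (paramater to parameter) sit inside MsCtxParam, which uses a typed-range layout: each *_BEGIN aliases the previous *_END, so the bool, int, uint32, float, and string parameters occupy contiguous, non-overlapping runs of a single enum, and the new MS_CTX_ENABLE_INFER_OPT only has to be placed inside the bool run. A parameter's storage slot is then its offset from its range's begin. A minimal sketch of the technique, with parameter names invented for illustration:

#include <array>

enum CtxParam : unsigned {
  // bool range
  TYPE_BOOL_BEGIN,
  ENABLE_FOO = TYPE_BOOL_BEGIN,
  ENABLE_BAR,
  TYPE_BOOL_END,
  // int range starts where the bool range ends, so the ranges tile the enum
  TYPE_INT_BEGIN = TYPE_BOOL_END,
  EXEC_MODE = TYPE_INT_BEGIN,
  TYPE_INT_END,
};

class Context {
 public:
  void set_bool(CtxParam p, bool v) { bools_[p - TYPE_BOOL_BEGIN] = v; }
  bool get_bool(CtxParam p) const { return bools_[p - TYPE_BOOL_BEGIN]; }

 private:
  // One slot per bool parameter, indexed by offset from the range begin.
  std::array<bool, TYPE_BOOL_END - TYPE_BOOL_BEGIN> bools_{};
};

One set of typed getters and setters can then bounds-check p against the range markers instead of switching over every parameter individually.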