| @@ -104,6 +104,7 @@ class MS_API Buffer { | |||
| extern MS_API const char *kDeviceTypeAscend310; | |||
| extern MS_API const char *kDeviceTypeAscend910; | |||
| extern MS_API const char *kDeviceTypeGpu; | |||
| constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path"; | |||
| constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file | |||
| @@ -13,7 +13,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| endif() | |||
| if(ENABLE_GPU) | |||
| file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc") | |||
| file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu_session.cc" "gpu_inference_session.cc") | |||
| list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) | |||
| endif() | |||
| @@ -0,0 +1,217 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include "backend/session/gpu_inference_session.h" | |||
| #include "ir/tensor.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/param_info.h" | |||
| #include "runtime/device/kernel_runtime.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "common/trans.h" | |||
| #include "utils/config_manager.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| // Copies the caller-provided input tensors to the device addresses of the graph's non-weight parameters. | |||
| // Tensors are matched to non-weight parameters in graph-input order; weight parameters are skipped. | |||
| // Graph inputs that are not Parameters are logged and ignored. | |||
| // Throws (MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL) when fewer tensors are supplied than the graph needs, | |||
| // when a tensor or device address is null, or when the host-to-device copy fails. | |||
| void GpuInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
|                                         const std::vector<tensor::TensorPtr> &inputs_const) const { | |||
|   MS_EXCEPTION_IF_NULL(kernel_graph); | |||
|   const auto &input_nodes = kernel_graph->inputs(); | |||
|   // Index of the next caller tensor to consume; only advanced for non-weight parameters. | |||
|   size_t no_weight_input = 0; | |||
|   for (const auto &input_node : input_nodes) { | |||
|     MS_EXCEPTION_IF_NULL(input_node); | |||
|     if (!input_node->isa<Parameter>()) { | |||
|       MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; | |||
|       continue; | |||
|     } | |||
|     auto pk_node = input_node->cast<ParameterPtr>(); | |||
|     MS_EXCEPTION_IF_NULL(pk_node); | |||
|     auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
|     MS_EXCEPTION_IF_NULL(device_address); | |||
|     if (AnfAlgo::IsParameterWeight(pk_node)) { | |||
|       continue; | |||
|     } | |||
|     // Guard the index: the caller may have passed fewer tensors than the graph has non-weight inputs. | |||
|     if (no_weight_input >= inputs_const.size()) { | |||
|       MS_LOG(EXCEPTION) << "Input number is inconsistent. The given input number [" << inputs_const.size() | |||
|                         << "] is less than the number of non-weight graph inputs."; | |||
|     } | |||
|     const auto &tensor = inputs_const[no_weight_input++]; | |||
|     MS_EXCEPTION_IF_NULL(tensor); | |||
|     if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
|                                           LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
|                                           tensor->data_c())) { | |||
|       MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
|     } | |||
|   } | |||
| } | |||
| // Compiles the function graph through GPUSession::CompileGraphImpl, then copies the default tensor of every | |||
| // weight parameter among the graph inputs to that parameter's device address. | |||
| // Returns the id of the compiled graph. Throws (MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL) when the compiled | |||
| // graph, a device address or a weight's default tensor is missing, or when the host-to-device copy fails. | |||
| GraphId GpuInferenceSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
|   auto graph_id = GPUSession::CompileGraphImpl(func_graph); | |||
|   auto kernel_graph = GetGraph(graph_id); | |||
|   MS_EXCEPTION_IF_NULL(kernel_graph); | |||
|   // load weight data to device | |||
|   const auto &input_nodes = kernel_graph->inputs(); | |||
|   for (const auto &input_node : input_nodes) { | |||
|     MS_EXCEPTION_IF_NULL(input_node); | |||
|     if (!input_node->isa<Parameter>()) { | |||
|       MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; | |||
|       continue; | |||
|     } | |||
|     auto pk_node = input_node->cast<ParameterPtr>(); | |||
|     MS_EXCEPTION_IF_NULL(pk_node); | |||
|     auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
|     MS_EXCEPTION_IF_NULL(device_address); | |||
|     if (!AnfAlgo::IsParameterWeight(pk_node)) { | |||
|       continue; | |||
|     } | |||
|     // A weight carries its value as the parameter's default; it must be a tensor to be uploaded. | |||
|     const auto &param_value = pk_node->default_param(); | |||
|     MS_EXCEPTION_IF_NULL(param_value); | |||
|     auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value); | |||
|     MS_EXCEPTION_IF_NULL(tensor); | |||
|     if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), | |||
|                                           LongToSize(tensor->data().nbytes()), tensor->data_type(), | |||
|                                           tensor->data_c())) { | |||
|       MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; | |||
|     } | |||
|   } | |||
|   return graph_id; | |||
| } | |||
| // Validates the caller-provided inputs against the non-weight parameters of the compiled graph. | |||
| // The number of inputs is compared once, before any per-input check, so a mismatch is also reported when the | |||
| // graph has no non-weight inputs. Each input is then compared with its parameter via CompareInput. | |||
| // Returns true when everything matches; otherwise logs the problem, fills *error_msg (if non-null) with a | |||
| // description and returns false. | |||
| bool GpuInferenceSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
|                                            std::string *error_msg) const { | |||
|   MS_LOG(INFO) << "Start check client inputs, graph id : " << graph_id; | |||
|   auto kernel_graph = GetGraph(graph_id); | |||
|   MS_EXCEPTION_IF_NULL(kernel_graph); | |||
|   const auto &kernel_graph_inputs = kernel_graph->inputs(); | |||
|   std::vector<ParameterPtr> paras; | |||
|   // find parameters of graph inputs | |||
|   for (const auto &graph_input : kernel_graph_inputs) { | |||
|     MS_EXCEPTION_IF_NULL(graph_input); | |||
|     if (!graph_input->isa<Parameter>()) { | |||
|       MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; | |||
|       continue; | |||
|     } | |||
|     auto parameter = graph_input->cast<ParameterPtr>(); | |||
|     MS_EXCEPTION_IF_NULL(parameter); | |||
|     if (!AnfAlgo::IsParameterWeight(parameter)) { | |||
|       paras.push_back(parameter); | |||
|     } | |||
|   } | |||
|   // compare input number | |||
|   if (paras.size() != inputs.size()) { | |||
|     const std::string inputs_info = InputsInfo(paras, inputs); | |||
|     MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size() | |||
|                   << "] but the graph input number is [" << paras.size() << "]"; | |||
|     MS_LOG(ERROR) << "InputsInfo --" << inputs_info; | |||
|     if (error_msg != nullptr) { | |||
|       std::stringstream str_stream; | |||
|       str_stream << "Input number is inconsistent. The given input number [" << inputs.size() | |||
|                  << "] but the graph input number is [" << paras.size() << "]\n"; | |||
|       str_stream << "InputsInfo --" << inputs_info; | |||
|       *error_msg = str_stream.str(); | |||
|     } | |||
|     return false; | |||
|   } | |||
|   // check inputs; sizes are equal here, so inputs[i] pairs with paras[i] | |||
|   for (size_t i = 0; i < paras.size(); ++i) { | |||
|     if (CompareInput(inputs[i], paras[i])) { | |||
|       continue; | |||
|     } | |||
|     const std::string inputs_info = InputsInfo(paras, inputs); | |||
|     MS_LOG(ERROR) << "Please check the input information."; | |||
|     MS_LOG(ERROR) << "InputsInfo --" << inputs_info; | |||
|     if (error_msg != nullptr) { | |||
|       std::stringstream str_stream; | |||
|       str_stream << "Please check the input information.\n"; | |||
|       str_stream << "InputsInfo --" << inputs_info; | |||
|       *error_msg = str_stream.str(); | |||
|     } | |||
|     return false; | |||
|   } | |||
|   return true; | |||
| } | |||
| // Checks that one input tensor is compatible with one graph parameter. | |||
| // The shapes must be equal unless both are scalar-like (empty shape or a single dimension of size 1), and the | |||
| // tensor's data type must equal the parameter's selected output device type. | |||
| // Returns true on a match; logs the mismatch and returns false otherwise. Throws (MS_EXCEPTION_IF_NULL) when | |||
| // the input, the parameter or the parameter's kernel build info is null. | |||
| bool GpuInferenceSession::CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const { | |||
|   MS_EXCEPTION_IF_NULL(input); | |||
|   MS_EXCEPTION_IF_NULL(parameter); | |||
|   // compare dims | |||
|   auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); | |||
|   // compare shape: convert the tensor's int64 dims to size_t so both shapes have the same element type | |||
|   const auto &input_shape = input->shape(); | |||
|   std::vector<size_t> trans_input; | |||
|   trans_input.reserve(input_shape.size()); | |||
|   (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input), | |||
|                        [](const int64_t dim) { return static_cast<size_t>(dim); }); | |||
|   auto is_scalar_shape = [](const std::vector<size_t> &shape) { | |||
|     return shape.empty() || (shape.size() == 1 && shape[0] == 1); | |||
|   }; | |||
|   if ((!is_scalar_shape(trans_input) || !is_scalar_shape(parameter_shape)) && (trans_input != parameter_shape)) { | |||
|     MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input) | |||
|                   << ", but the parameter shape is " << PrintInputShape(parameter_shape) | |||
|                   << ". parameter : " << parameter->DebugString(); | |||
|     return false; | |||
|   } | |||
|   // compare data type | |||
|   auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); | |||
|   MS_EXCEPTION_IF_NULL(kernel_build_info); | |||
|   const auto device_type = kernel_build_info->GetOutputDeviceType(0); | |||
|   if (input->data_type() != device_type) { | |||
|     MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type() | |||
|                   << ", but the parameter data type is " << device_type | |||
|                   << ". parameter : " << parameter->DebugString(); | |||
|     return false; | |||
|   } | |||
|   return true; | |||
| } | |||
| // Renders a shape as text: "[" followed by " <dim>" for every dimension and a closing " ]". | |||
| template <typename T> | |||
| std::string GpuInferenceSession::PrintInputShape(std::vector<T> shape) const { | |||
|   std::string text("["); | |||
|   for (const auto &extent : shape) { | |||
|     text += " " + std::to_string(extent); | |||
|   } | |||
|   text += " ]"; | |||
|   return text; | |||
| } | |||
| // Builds a human-readable description of the graph's expected inputs and of the inputs actually given. | |||
| // For every entry it lists the index, the number of dimensions, the shape and the data type name; data types | |||
| // missing from the local name table are printed as "Unknown". | |||
| // Throws (MS_EXCEPTION_IF_NULL) when a given input tensor or a parameter's kernel build info is null. | |||
| std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &paras, | |||
|                                             const std::vector<tensor::TensorPtr> &inputs) const { | |||
|   const std::map<TypeId, std::string> dtype_name_map{ | |||
|     {TypeId::kNumberTypeBegin, "Unknown"},   {TypeId::kNumberTypeBool, "Bool"}, | |||
|     {TypeId::kNumberTypeFloat64, "Float64"}, {TypeId::kNumberTypeInt8, "Int8"}, | |||
|     {TypeId::kNumberTypeUInt8, "Uint8"},     {TypeId::kNumberTypeInt16, "Int16"}, | |||
|     {TypeId::kNumberTypeUInt16, "Uint16"},   {TypeId::kNumberTypeInt32, "Int32"}, | |||
|     {TypeId::kNumberTypeUInt32, "Uint32"},   {TypeId::kNumberTypeInt64, "Int64"}, | |||
|     {TypeId::kNumberTypeUInt64, "Uint64"},   {TypeId::kNumberTypeFloat16, "Float16"}, | |||
|     {TypeId::kNumberTypeFloat32, "Float32"}, | |||
|   }; | |||
|   auto data_type_to_string = [&dtype_name_map](TypeId type_id) { | |||
|     auto it = dtype_name_map.find(type_id); | |||
|     if (it == dtype_name_map.end()) { | |||
|       return std::string("Unknown"); | |||
|     } | |||
|     return it->second; | |||
|   }; | |||
|   std::string graph = "graph inputs:{ "; | |||
|   for (size_t i = 0; i < paras.size(); ++i) { | |||
|     const auto &para = paras[i]; | |||
|     // Query the device shape once; it is used for both the rank and the printed shape. | |||
|     const auto device_shape = AnfAlgo::GetOutputDeviceShape(para, 0); | |||
|     auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(para); | |||
|     MS_EXCEPTION_IF_NULL(kernel_build_info); | |||
|     graph += std::to_string(i) + ": dims " + std::to_string(device_shape.size()) + ", shape " + | |||
|              PrintInputShape(device_shape) + ", data type " + | |||
|              data_type_to_string(kernel_build_info->GetOutputDeviceType(0)) + " }"; | |||
|   } | |||
|   std::string actual = "given inputs:{ "; | |||
|   for (size_t i = 0; i < inputs.size(); ++i) { | |||
|     const auto &input = inputs[i]; | |||
|     MS_EXCEPTION_IF_NULL(input); | |||
|     actual += std::to_string(i) + ": dims " + std::to_string(input->shape().size()) + ", shape " + | |||
|               PrintInputShape(input->shape()) + ", data type " + data_type_to_string(input->data_type()) + " }"; | |||
|   } | |||
|   return graph + " " + actual; | |||
| } | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H | |||
| #define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H | |||
| #include <unordered_map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <stack> | |||
| #include <map> | |||
| #include <tuple> | |||
| #include <set> | |||
| #include "backend/session/gpu_session.h" | |||
| #include "backend/session/kernel_graph.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/session/session_factory.h" | |||
| namespace mindspore { | |||
| namespace session { | |||
| // GPU session specialised for inference. It derives from gpu::GPUSession, adds validation of caller-provided | |||
| // inputs and overrides graph compilation. It is registered with the session factory under kGpuInferenceDevice. | |||
| class GpuInferenceSession : public gpu::GPUSession { | |||
|  public: | |||
|   GpuInferenceSession() = default; | |||
|   ~GpuInferenceSession() = default; | |||
|   // Copies the given input tensors to the device for the graph's inputs. | |||
|   // NOTE(review): presumably meant to replace the base-class LoadInputData; confirm and mark it override. | |||
|   void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
|                      const std::vector<tensor::TensorPtr> &inputs_const) const; | |||
|   // Validates the inputs against the graph identified by graph_id; on failure returns false and, when | |||
|   // error_msg is non-null, stores a description there. | |||
|   bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
|                         std::string *error_msg) const override; | |||
|   // Returns true when the input tensor matches the parameter. | |||
|   bool CompareInput(const tensor::TensorPtr &input, const ParameterPtr &parameter) const; | |||
|   // Formats a shape as text. Only declared here; the template body is not part of this header. | |||
|   template <typename T> | |||
|   std::string PrintInputShape(std::vector<T> shape) const; | |||
|   // Describes the graph's parameters and the given inputs as text. | |||
|   std::string InputsInfo(const std::vector<ParameterPtr> &paras, const std::vector<tensor::TensorPtr> &inputs) const; | |||
|  | |||
|  protected: | |||
|   // Compiles the function graph and returns the id of the resulting graph. | |||
|   GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | |||
| }; | |||
| MS_REG_SESSION(kGpuInferenceDevice, GpuInferenceSession); | |||
| } // namespace session | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_INFERENCE_SESSION_H | |||
| @@ -15,9 +15,11 @@ | |||
| */ | |||
| #include "backend/session/gpu_session.h" | |||
| #include <string> | |||
| #include "backend/optimizer/common/helper.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/common/pass_manager.h" | |||
| #include "backend/optimizer/common/common_backend_optimization.h" | |||
| #include "backend/optimizer/gpu/adam_weight_decay_fusion.h" | |||
| #include "backend/optimizer/gpu/adam_fusion.h" | |||
| #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h" | |||
| @@ -298,16 +300,31 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const | |||
| GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { | |||
| // Construct graph, if successfully, graph_sum_ + 1 | |||
| auto graph_id = graph_sum_; | |||
| auto graph = ConstructKernelGraph(lst, outputs); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| return CompileGraphImpl(graph); | |||
| } | |||
| // Constructs the kernel graph(s) for a function graph, runs the common backend optimization on the root graph | |||
| // and forwards it to CompileGraphImpl for the remaining compile steps. | |||
| // Throws (MS_LOG(EXCEPTION)) unless exactly one kernel graph was constructed. | |||
| GraphId GPUSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
|   std::vector<KernelGraphPtr> kernel_graphs; | |||
|   auto root = ConstructKernelGraph(func_graph, &kernel_graphs); | |||
|   MS_EXCEPTION_IF_NULL(root); | |||
|   const auto graph_count = kernel_graphs.size(); | |||
|   if (graph_count != 1) { | |||
|     MS_LOG(EXCEPTION) << "Gpu backend does not support multi-graph schedule. graph num" << graph_count; | |||
|   } | |||
|   opt::BackendCommonOptimization(root); | |||
|   return CompileGraphImpl(root); | |||
| } | |||
| GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { | |||
| // Prepare ms context info for dump .pb graph | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| // Dump .pb graph before graph optimization | |||
| if (save_graphs) { | |||
| DumpIRProto(graph, "before_opt_" + std::to_string(graph_id)); | |||
| DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id())); | |||
| } | |||
| // Graph optimization irrelevant to device data format | |||
| Optimize(graph); | |||
| @@ -326,7 +343,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| AssignStream(graph); | |||
| // Dump .pb graph before remove nop nodes | |||
| if (save_graphs) { | |||
| DumpIRProto(graph, "before_removeNop_" + std::to_string(graph_id)); | |||
| DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id())); | |||
| } | |||
| // Update Graph Dynamic Shape Attr. | |||
| UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); | |||
| @@ -343,7 +360,7 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| SetSummaryNodes(graph.get()); | |||
| // Dump .pb graph after graph optimization | |||
| if (save_graphs) { | |||
| DumpIRProto(graph, "after_opt_" + std::to_string(graph_id)); | |||
| DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id())); | |||
| } | |||
| // Set graph manager. | |||
| MS_EXCEPTION_IF_NULL(context_); | |||
| @@ -361,9 +378,8 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| debugger_->LoadGraphs(graph); | |||
| } | |||
| #endif | |||
| MS_LOG(INFO) << "CompileGraph graph_id: " << graph_id; | |||
| return graph_id; | |||
| MS_LOG(INFO) << "CompileGraph graph_id: " << graph->graph_id(); | |||
| return graph->graph_id(); | |||
| } | |||
| void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| @@ -37,6 +37,7 @@ class GPUSession : public SessionBasic { | |||
| protected: | |||
| void UnifyMindIR(const KernelGraphPtr &graph) override { return; } | |||
| GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; | |||
| GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; | |||
| void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| @@ -81,6 +82,8 @@ class GPUSession : public SessionBasic { | |||
| void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| GraphId CompileGraphImpl(KernelGraphPtr kernel_graph); | |||
| }; | |||
| using GPUSessionPtr = std::shared_ptr<GPUSession>; | |||
| MS_REG_SESSION(kGPUDevice, GPUSession); | |||
| @@ -15,10 +15,15 @@ if(ENABLE_ACL) | |||
| "model/model_converter_utils/*.cc" | |||
| "graph/acl/*.cc" | |||
| ) | |||
| endif() | |||
| if(ENABLE_D) | |||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/ms/*.cc") | |||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} | |||
| "python_utils.cc" "model/ms/*.cc" "graph/ascend/*.cc") | |||
| endif() | |||
| if(ENABLE_GPU) | |||
| file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "python_utils.cc" "model/ms/*.cc" "graph/gpu/*.cc") | |||
| endif() | |||
| set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc | |||
| @@ -98,6 +103,15 @@ if(ENABLE_D) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server}) | |||
| endif() | |||
| if(ENABLE_GPU) | |||
| target_link_libraries(mindspore_shared_lib PRIVATE gpu_cuda_lib gpu_queue cublas | |||
| ${CUDA_PATH}/lib64/libcurand.so | |||
| ${CUDNN_LIBRARY_PATH} | |||
| ${CUDA_PATH}/lib64/libcudart.so | |||
| ${CUDA_PATH}/lib64/stubs/libcuda.so | |||
| ${CUDA_PATH}/lib64/libcusolver.so) | |||
| endif() | |||
| if(CMAKE_SYSTEM_NAME MATCHES "Linux") | |||
| set(MINDSPORE_RPATH $ORIGIN) | |||
| if(ENABLE_D) | |||
| @@ -110,7 +124,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux") | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| set(MINDSPORE_RPATH | |||
| ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| elseif(ENABLE_ACL) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/atc/lib64) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/atc/lib64) | |||
| @@ -121,7 +136,8 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux") | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| set(MINDSPORE_RPATH | |||
| ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe/op_tiling) | |||
| endif() | |||
| set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) | |||
| @@ -13,7 +13,7 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/graph/ms/ms_graph_impl.h" | |||
| #include "cxx_api/graph/ascend/ascend_graph_impl.h" | |||
| #include <algorithm> | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| @@ -26,43 +26,9 @@ | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| namespace mindspore::api { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, MsGraphImpl); | |||
| static DataType TransTypeId2InferDataType(TypeId type_id) { | |||
| const std::map<TypeId, api::DataType> id2type_map{ | |||
| {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, | |||
| {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8}, | |||
| {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, | |||
| {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, | |||
| {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, | |||
| {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, | |||
| {TypeId::kNumberTypeFloat32, api::kMsFloat32}, | |||
| }; | |||
| // cppcheck-suppress stlIfFind | |||
| if (auto it = id2type_map.find(type_id); it != id2type_map.end()) { | |||
| return it->second; | |||
| } | |||
| MS_LOG(WARNING) << "Unsupported data id " << type_id; | |||
| return api::kMsUnknown; | |||
| } | |||
| template <class T> | |||
| inline static void ClearIfNotNull(T *vec) { | |||
| if (vec != nullptr) { | |||
| vec->clear(); | |||
| } | |||
| } | |||
| template <class T, class U = std::vector<T>> | |||
| inline static void PushbackIfNotNull(U *vec, T &&item) { | |||
| if (vec != nullptr) { | |||
| vec->emplace_back(item); | |||
| } | |||
| } | |||
| API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl); | |||
| MsGraphImpl::MsGraphImpl() | |||
| AscendGraphImpl::AscendGraphImpl() | |||
| : session_impl_(nullptr), | |||
| graph_id_(0), | |||
| device_type_("Ascend"), | |||
| @@ -75,9 +41,9 @@ MsGraphImpl::MsGraphImpl() | |||
| init_flag_(false), | |||
| load_flag_(false) {} | |||
| MsGraphImpl::~MsGraphImpl() { (void)FinalizeEnv(); } | |||
| AscendGraphImpl::~AscendGraphImpl() { (void)FinalizeEnv(); } | |||
| Status MsGraphImpl::InitEnv() { | |||
| Status AscendGraphImpl::InitEnv() { | |||
| if (init_flag_) { | |||
| return SUCCESS; | |||
| } | |||
| @@ -108,7 +74,7 @@ Status MsGraphImpl::InitEnv() { | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::FinalizeEnv() { | |||
| Status AscendGraphImpl::FinalizeEnv() { | |||
| if (!init_flag_) { | |||
| return SUCCESS; | |||
| } | |||
| @@ -136,7 +102,7 @@ Status MsGraphImpl::FinalizeEnv() { | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) { | |||
| Status AscendGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| try { | |||
| graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| @@ -147,7 +113,7 @@ Status MsGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) { | |||
| std::vector<tensor::TensorPtr> AscendGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) { | |||
| try { | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id_, inputs, &outputs); | |||
| @@ -158,7 +124,7 @@ std::vector<tensor::TensorPtr> MsGraphImpl::RunGraph(const std::vector<tensor::T | |||
| } | |||
| } | |||
| Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const { | |||
| Status AscendGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| std::string error_msg; | |||
| if (!session_impl_->CheckModelInputs(graph_id_, inputs, &error_msg)) { | |||
| @@ -167,7 +133,7 @@ Status MsGraphImpl::CheckModelInputs(const std::vector<tensor::TensorPtr> &input | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { | |||
| Status AscendGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { | |||
| MS_EXCEPTION_IF_NULL(reply); | |||
| if (context_ == nullptr) { | |||
| MS_LOG(ERROR) << "rtCtx is nullptr"; | |||
| @@ -206,8 +172,8 @@ Status MsGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| Status AscendGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| @@ -216,22 +182,22 @@ Status MsGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<s | |||
| } | |||
| } | |||
| ClearIfNotNull(names); | |||
| ClearIfNotNull(shapes); | |||
| ClearIfNotNull(data_types); | |||
| ClearIfNotNull(mem_sizes); | |||
| GraphUtils::ClearIfNotNull(names); | |||
| GraphUtils::ClearIfNotNull(shapes); | |||
| GraphUtils::ClearIfNotNull(data_types); | |||
| GraphUtils::ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto &tensor = inputs_[i]; | |||
| PushbackIfNotNull(names, input_names_[i]); | |||
| PushbackIfNotNull(shapes, tensor->shape()); | |||
| PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type())); | |||
| PushbackIfNotNull(mem_sizes, tensor->Size()); | |||
| GraphUtils::PushbackIfNotNull(names, input_names_[i]); | |||
| GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); | |||
| GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); | |||
| GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| Status AscendGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| @@ -240,22 +206,22 @@ Status MsGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector< | |||
| } | |||
| } | |||
| ClearIfNotNull(names); | |||
| ClearIfNotNull(shapes); | |||
| ClearIfNotNull(data_types); | |||
| ClearIfNotNull(mem_sizes); | |||
| GraphUtils::ClearIfNotNull(names); | |||
| GraphUtils::ClearIfNotNull(shapes); | |||
| GraphUtils::ClearIfNotNull(data_types); | |||
| GraphUtils::ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < outputs_.size(); i++) { | |||
| auto &tensor = outputs_[i]; | |||
| PushbackIfNotNull(names, output_names_[i]); | |||
| PushbackIfNotNull(shapes, tensor->shape()); | |||
| PushbackIfNotNull(data_types, TransTypeId2InferDataType(tensor->data_type())); | |||
| PushbackIfNotNull(mem_sizes, tensor->Size()); | |||
| GraphUtils::PushbackIfNotNull(names, output_names_[i]); | |||
| GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); | |||
| GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); | |||
| GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::Load() { | |||
| Status AscendGraphImpl::Load() { | |||
| // check graph type | |||
| if (graph_->ModelType() != ModelType::kMindIR) { | |||
| MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); | |||
| @@ -311,7 +277,7 @@ Status MsGraphImpl::Load() { | |||
| return SUCCESS; | |||
| } | |||
| Status MsGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| Status AscendGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| @@ -13,8 +13,8 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | |||
| #include <functional> | |||
| #include <map> | |||
| #include <string> | |||
| @@ -28,12 +28,13 @@ | |||
| #include "ir/anf.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #include "runtime/context.h" | |||
| #include "cxx_api/graph/graph_utils.h" | |||
| namespace mindspore::api { | |||
| class MsGraphImpl : public GraphCell::GraphImpl { | |||
| class AscendGraphImpl : public GraphCell::GraphImpl { | |||
| public: | |||
| MsGraphImpl(); | |||
| ~MsGraphImpl() override; | |||
| AscendGraphImpl(); | |||
| ~AscendGraphImpl() override; | |||
| Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status Load() override; | |||
| @@ -63,4 +64,4 @@ class MsGraphImpl : public GraphCell::GraphImpl { | |||
| bool load_flag_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H | |||
| @@ -20,10 +20,10 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() { | |||
| return instance; | |||
| } | |||
| bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) { | |||
| bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) { | |||
| return false; | |||
| } | |||
| @@ -0,0 +1,256 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "cxx_api/graph/gpu/gpu_graph_impl.h" | |||
| #include <algorithm> | |||
| #include "include/api/context.h" | |||
| #include "cxx_api/factory.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "mindspore/core/base/base_ref_utils.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "backend/session/executor_manager.h" | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| namespace mindspore::api { | |||
| API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl); | |||
| GPUGraphImpl::GPUGraphImpl() | |||
| : session_impl_(nullptr), | |||
| graph_id_(0), | |||
| device_id_(Context::Instance().GetDeviceID()), | |||
| inputs_(), | |||
| outputs_(), | |||
| input_names_(), | |||
| output_names_(), | |||
| init_flag_(false), | |||
| load_flag_(false) {} | |||
| Status GPUGraphImpl::InitEnv() { | |||
| if (init_flag_) { | |||
| MS_LOG(WARNING) << "Initialized again, return success."; | |||
| return SUCCESS; | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| if (ms_context == nullptr) { | |||
| MS_LOG(ERROR) << "Get Context failed!"; | |||
| return FAILED; | |||
| } | |||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | |||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | |||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true); | |||
| session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); | |||
| if (session_impl_ == nullptr) { | |||
| MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << kGpuInferenceDevice | |||
| << " is available."; | |||
| return FAILED; | |||
| } | |||
| session_impl_->Init(device_id_); | |||
| init_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status GPUGraphImpl::FinalizeEnv() { | |||
| if (!init_flag_) { | |||
| MS_LOG(WARNING) << "Never initialize before, return success"; | |||
| return SUCCESS; | |||
| } | |||
| MS_LOG_INFO << "Start finalize env"; | |||
| session::ExecutorManager::Instance().Clear(); | |||
| device::KernelRuntimeManager::Instance().ClearRuntimeResource(); | |||
| init_flag_ = false; | |||
| MS_LOG(INFO) << "End finalize env"; | |||
| return SUCCESS; | |||
| } | |||
| Status GPUGraphImpl::Load() { | |||
| // check graph type | |||
| if (graph_->ModelType() != ModelType::kMindIR) { | |||
| MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| const auto &graph_data = GraphImpl::MutableGraphData(); | |||
| MS_EXCEPTION_IF_NULL(graph_data); | |||
| auto func_graph = graph_data->GetFuncGraph(); | |||
| // init | |||
| Status ret = InitEnv(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "InitEnv failed."; | |||
| return FAILED; | |||
| } | |||
| ret = CompileGraph(func_graph); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Compile graph model failed"; | |||
| return FAILED; | |||
| } | |||
| session_impl_->GetModelInputsInfo(graph_id_, &inputs_, &input_names_); | |||
| session_impl_->GetModelOutputsInfo(graph_id_, &outputs_, &output_names_); | |||
| if (inputs_.empty() || inputs_.size() != input_names_.size()) { | |||
| MS_LOG_ERROR << "Get model inputs info failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.empty() || outputs_.size() != output_names_.size()) { | |||
| MS_LOG_ERROR << "Get model outputs info failed"; | |||
| return FAILED; | |||
| } | |||
| load_flag_ = true; | |||
| return SUCCESS; | |||
| } | |||
| Status GPUGraphImpl::CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr) { | |||
| MS_ASSERT(session_impl_ != nullptr); | |||
| try { | |||
| graph_id_ = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); | |||
| return SUCCESS; | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); | |||
| return FAILED; | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> GPUGraphImpl::RunGraph(const std::vector<tensor::TensorPtr> &inputs) { | |||
| try { | |||
| VectorRef outputs; | |||
| session_impl_->RunGraph(graph_id_, inputs, &outputs); | |||
| return TransformVectorRefToMultiTensor(outputs); | |||
| } catch (std::exception &e) { | |||
| MS_LOG(ERROR) << "RunGraph failed: " << e.what(); | |||
| return std::vector<tensor::TensorPtr>(); | |||
| } | |||
| } | |||
| Status GPUGraphImpl::ExecuteModel(const std::vector<Buffer> &request, std::vector<Buffer> *reply) { | |||
| MS_EXCEPTION_IF_NULL(reply); | |||
| vector<tensor::TensorPtr> inputs; | |||
| for (size_t i = 0; i < request.size(); i++) { | |||
| auto &item = request[i]; | |||
| auto input = inputs_[i]; | |||
| if (input->Size() != item.DataSize()) { | |||
| MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " | |||
| << input->Size(); | |||
| return FAILED; | |||
| } | |||
| auto ret = memcpy_s(input->data_c(), input->Size(), item.Data(), item.DataSize()); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "Tensor copy failed"; | |||
| return FAILED; | |||
| } | |||
| inputs.push_back(input); | |||
| } | |||
| vector<tensor::TensorPtr> outputs = RunGraph(inputs); | |||
| if (outputs.empty()) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| reply->clear(); | |||
| std::transform(outputs.begin(), outputs.end(), std::back_inserter(*reply), | |||
| [](const tensor::TensorPtr &tensor) { return Buffer(tensor->data_c(), tensor->Size()); }); | |||
| return SUCCESS; | |||
| } | |||
| Status GPUGraphImpl::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { | |||
| MS_EXCEPTION_IF_NULL(outputs); | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "PrepareModel failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| if (inputs.size() != inputs_.size()) { | |||
| MS_LOG(ERROR) << "inputs count not match, required count " << inputs_.size() << ", given count " << inputs.size(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| for (size_t i = 0; i < inputs_.size(); ++i) { | |||
| if (inputs[i].DataSize() != inputs_[i]->Size()) { | |||
| MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_[i]->Size() << ", given count " | |||
| << inputs[i].DataSize(); | |||
| return INVALID_INPUTS; | |||
| } | |||
| } | |||
| if (ExecuteModel(inputs, outputs) != SUCCESS) { | |||
| MS_LOG(ERROR) << "Execute Model Failed"; | |||
| return FAILED; | |||
| } | |||
| if (outputs_.size() != outputs->size()) { | |||
| MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " | |||
| << outputs_.size(); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GPUGraphImpl::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "PrepareModel failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| GraphUtils::ClearIfNotNull(names); | |||
| GraphUtils::ClearIfNotNull(shapes); | |||
| GraphUtils::ClearIfNotNull(data_types); | |||
| GraphUtils::ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < inputs_.size(); i++) { | |||
| auto &tensor = inputs_[i]; | |||
| GraphUtils::PushbackIfNotNull(names, input_names_[i]); | |||
| GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); | |||
| GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); | |||
| GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GPUGraphImpl::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) { | |||
| if (!load_flag_) { | |||
| Status ret = Load(); | |||
| if (ret != SUCCESS) { | |||
| MS_LOG(ERROR) << "PrepareModel failed."; | |||
| return ret; | |||
| } | |||
| } | |||
| GraphUtils::ClearIfNotNull(names); | |||
| GraphUtils::ClearIfNotNull(shapes); | |||
| GraphUtils::ClearIfNotNull(data_types); | |||
| GraphUtils::ClearIfNotNull(mem_sizes); | |||
| for (size_t i = 0; i < outputs_.size(); i++) { | |||
| auto &tensor = outputs_[i]; | |||
| GraphUtils::PushbackIfNotNull(names, output_names_[i]); | |||
| GraphUtils::PushbackIfNotNull(shapes, tensor->shape()); | |||
| GraphUtils::PushbackIfNotNull(data_types, GraphUtils::TransTypeId2InferDataType(tensor->data_type())); | |||
| GraphUtils::PushbackIfNotNull(mem_sizes, tensor->Size()); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace mindspore::api | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "include/api/status.h" | |||
| #include "include/api/graph.h" | |||
| #include "cxx_api/graph/graph_impl.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "ir/anf.h" | |||
| #include "cxx_api/model/model_impl.h" | |||
| #include "cxx_api/graph/graph_utils.h" | |||
| namespace mindspore::api { | |||
| class GPUGraphImpl : public GraphCell::GraphImpl { | |||
| public: | |||
| GPUGraphImpl(); | |||
| ~GPUGraphImpl() override = default; | |||
| Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override; | |||
| Status Load() override; | |||
| Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; | |||
| Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes, | |||
| std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override; | |||
| private: | |||
| Status InitEnv(); | |||
| Status FinalizeEnv(); | |||
| Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr); | |||
| Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const; | |||
| std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs); | |||
| Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs); | |||
| std::shared_ptr<session::SessionBasic> session_impl_; | |||
| uint32_t graph_id_; | |||
| std::string device_type_; | |||
| uint32_t device_id_; | |||
| std::vector<tensor::TensorPtr> inputs_; | |||
| std::vector<tensor::TensorPtr> outputs_; | |||
| std::vector<std::string> input_names_; | |||
| std::vector<std::string> output_names_; | |||
| bool init_flag_; | |||
| bool load_flag_; | |||
| // tensor-rt | |||
| uint32_t batch_size_; | |||
| uint32_t workspace_size_; | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H | |||
| #define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H | |||
#include <map>
#include <utility>
#include <vector>
#include "include/api/types.h"
#include "ir/dtype/type_id.h"
#include "utils/log_adapter.h"
| namespace mindspore::api { | |||
| class GraphUtils { | |||
| public: | |||
| static DataType TransTypeId2InferDataType(TypeId type_id) { | |||
| const std::map<TypeId, api::DataType> id2type_map{ | |||
| {TypeId::kNumberTypeBegin, api::kMsUnknown}, {TypeId::kNumberTypeBool, api::kMsBool}, | |||
| {TypeId::kNumberTypeFloat64, api::kMsFloat64}, {TypeId::kNumberTypeInt8, api::kMsInt8}, | |||
| {TypeId::kNumberTypeUInt8, api::kMsUint8}, {TypeId::kNumberTypeInt16, api::kMsInt16}, | |||
| {TypeId::kNumberTypeUInt16, api::kMsUint16}, {TypeId::kNumberTypeInt32, api::kMsInt32}, | |||
| {TypeId::kNumberTypeUInt32, api::kMsUint32}, {TypeId::kNumberTypeInt64, api::kMsInt64}, | |||
| {TypeId::kNumberTypeUInt64, api::kMsUint64}, {TypeId::kNumberTypeFloat16, api::kMsFloat16}, | |||
| {TypeId::kNumberTypeFloat32, api::kMsFloat32}, | |||
| }; | |||
| auto it = id2type_map.find(type_id); | |||
| if (it != id2type_map.end()) { | |||
| return it->second; | |||
| } | |||
| MS_LOG(WARNING) << "Unsupported data id " << type_id; | |||
| return api::kMsUnknown; | |||
| } | |||
| template <class T> | |||
| inline static void ClearIfNotNull(T *vec) { | |||
| if (vec != nullptr) { | |||
| vec->clear(); | |||
| } | |||
| } | |||
| template <class T, class U> | |||
| inline static void PushbackIfNotNull(U *vec, T &&item) { | |||
| if (vec != nullptr) { | |||
| vec->emplace_back(item); | |||
| } | |||
| } | |||
| }; | |||
| } // namespace mindspore::api | |||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_UTILS_H | |||
| @@ -22,6 +22,7 @@ | |||
| namespace mindspore { | |||
| namespace api { | |||
| API_FACTORY_REG(ModelImpl, Ascend910, MsModel); | |||
| API_FACTORY_REG(ModelImpl, GPU, MsModel); | |||
| Status MsModel::Build(const std::map<std::string, std::string> &) { | |||
| MS_LOG(INFO) << "Start build model."; | |||
| @@ -21,6 +21,7 @@ | |||
| namespace mindspore::api { | |||
| const char *kDeviceTypeAscend310 = "Ascend310"; | |||
| const char *kDeviceTypeAscend910 = "Ascend910"; | |||
| const char *kDeviceTypeGpu = "GPU"; | |||
| class DataImpl { | |||
| public: | |||
| @@ -31,7 +31,7 @@ constexpr Status PROF_FAILED = 0xFFFFFFFF; | |||
| } // namespace | |||
| Status RegProfCtrlCallback(MsprofCtrlCallback func) { | |||
| if (VMCallbackRegister::GetInstance().registed()) { | |||
| if (VMCallbackRegister::GetInstance().registered()) { | |||
| return VMCallbackRegister::GetInstance().DoRegProfCtrlCallback(func); | |||
| } else { | |||
| return PROF_SUCCESS; | |||
| @@ -39,7 +39,7 @@ Status RegProfCtrlCallback(MsprofCtrlCallback func) { | |||
| } | |||
| Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) { | |||
| if (VMCallbackRegister::GetInstance().registed()) { | |||
| if (VMCallbackRegister::GetInstance().registered()) { | |||
| return VMCallbackRegister::GetInstance().DoRegProfSetDeviceCallback(func); | |||
| } else { | |||
| return PROF_SUCCESS; | |||
| @@ -47,7 +47,7 @@ Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) { | |||
| } | |||
| Status RegProfReporterCallback(MsprofReporterCallback func) { | |||
| if (VMCallbackRegister::GetInstance().registed()) { | |||
| if (VMCallbackRegister::GetInstance().registered()) { | |||
| return VMCallbackRegister::GetInstance().DoRegProfReporterCallback(func); | |||
| } else { | |||
| return PROF_SUCCESS; | |||
| @@ -55,7 +55,7 @@ Status RegProfReporterCallback(MsprofReporterCallback func) { | |||
| } | |||
| Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) { | |||
| if (VMCallbackRegister::GetInstance().registed()) { | |||
| if (VMCallbackRegister::GetInstance().registered()) { | |||
| return VMCallbackRegister::GetInstance().DoProfCommandHandle(type, data, len); | |||
| } else { | |||
| return PROF_SUCCESS; | |||
| @@ -69,16 +69,16 @@ VMCallbackRegister &VMCallbackRegister::GetInstance() { | |||
| return instance; | |||
| } | |||
| bool VMCallbackRegister::Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) { | |||
| if (!registed_) { | |||
| bool VMCallbackRegister::Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)) { | |||
| if (!registered_) { | |||
| pRegProfCtrlCallback_ = pRegProfCtrlCallback; | |||
| pRegProfSetDeviceCallback_ = pRegProfSetDeviceCallback; | |||
| pRegProfReporterCallback_ = pRegProfReporterCallback; | |||
| pProfCommandHandle_ = pProfCommandHandle; | |||
| registed_ = true; | |||
| registered_ = true; | |||
| ForceMsprofilerInit(); | |||
| return true; | |||
| } | |||
| @@ -49,12 +49,12 @@ class __attribute__((visibility("default"))) VMCallbackRegister { | |||
| static VMCallbackRegister &GetInstance(); | |||
| VMCallbackRegister(const VMCallbackRegister &) = delete; | |||
| VMCallbackRegister &operator=(const VMCallbackRegister &) = delete; | |||
| bool Registe(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)); | |||
| bool Register(Status (*pRegProfCtrlCallback)(MsprofCtrlCallback), | |||
| Status (*pRegProfSetDeviceCallback)(MsprofSetDeviceCallback), | |||
| Status (*pRegProfReporterCallback)(MsprofReporterCallback), | |||
| Status (*pProfCommandHandle)(ProfCommandHandleType, void *, uint32_t)); | |||
| void ForceMsprofilerInit(); | |||
| bool registed() { return registed_; } | |||
| bool registered() { return registered_; } | |||
| Status DoRegProfCtrlCallback(MsprofCtrlCallback func) { return pRegProfCtrlCallback_(func); } | |||
| Status DoRegProfSetDeviceCallback(MsprofSetDeviceCallback func) { return pRegProfSetDeviceCallback_(func); } | |||
| Status DoRegProfReporterCallback(MsprofReporterCallback func) { return pRegProfReporterCallback_(func); } | |||
| @@ -64,7 +64,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister { | |||
| private: | |||
| VMCallbackRegister() | |||
| : registed_(false), | |||
| : registered_(false), | |||
| ms_profile_inited_(false), | |||
| pRegProfCtrlCallback_(nullptr), | |||
| pRegProfSetDeviceCallback_(nullptr), | |||
| @@ -72,7 +72,7 @@ class __attribute__((visibility("default"))) VMCallbackRegister { | |||
| pProfCommandHandle_(nullptr) {} | |||
| ~VMCallbackRegister() = default; | |||
| bool registed_; | |||
| bool registered_; | |||
| bool ms_profile_inited_; | |||
| Status (*pRegProfCtrlCallback_)(MsprofCtrlCallback); | |||
| Status (*pRegProfSetDeviceCallback_)(MsprofSetDeviceCallback); | |||
| @@ -299,8 +299,8 @@ Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) { | |||
| bool DoRegiste() { | |||
| MS_LOG(INFO) << "VM profiling register start"; | |||
| return VMCallbackRegister::GetInstance().Registe(RegProfCtrlCallback, RegProfSetDeviceCallback, | |||
| RegProfReporterCallback, ProfCommandHandle); | |||
| return VMCallbackRegister::GetInstance().Register(RegProfCtrlCallback, RegProfSetDeviceCallback, | |||
| RegProfReporterCallback, ProfCommandHandle); | |||
| } | |||
| static bool doRegiste = DoRegiste(); | |||
| } // namespace ascend | |||
| @@ -41,6 +41,7 @@ const char kCPUDevice[] = "CPU"; | |||
| const char kGPUDevice[] = "GPU"; | |||
| const char kAscendDevice[] = "Ascend"; | |||
| const char kDavinciInferenceDevice[] = "AscendInference"; | |||
| const char kGpuInferenceDevice[] = "GpuInference"; | |||
| const char kDavinciDevice[] = "Davinci"; | |||
| const char KNpuLog[] = "_npu_log"; | |||
| const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; | |||
| @@ -51,7 +52,7 @@ const float kDefaultMaxDeviceMemory = 1024; | |||
| // enum definition for MindSpore Context Parameter | |||
| enum MsCtxParam : unsigned { | |||
| // paramater of type bool | |||
| // parameter of type bool | |||
| MS_CTX_TYPE_BOOL_BEGIN, | |||
| MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN, | |||
| MS_CTX_CHECK_BPROP_FLAG, | |||
| @@ -74,14 +75,15 @@ enum MsCtxParam : unsigned { | |||
| MS_CTX_ENABLE_PROFILING, | |||
| MS_CTX_SAVE_GRAPHS_FLAG, | |||
| MS_CTX_ENABLE_PARALLEL_SPLIT, | |||
| MS_CTX_ENABLE_INFER_OPT, | |||
| MS_CTX_TYPE_BOOL_END, | |||
| // paramater of type int | |||
| // parameter of type int | |||
| MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, | |||
| MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, | |||
| MS_CTX_TYPE_INT_END, | |||
| // paramater of type uint32 | |||
| // parameter of type uint32 | |||
| MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, | |||
| MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, | |||
| MS_CTX_GE_REF, | |||
| @@ -89,12 +91,12 @@ enum MsCtxParam : unsigned { | |||
| MS_CTX_TSD_REF, | |||
| MS_CTX_TYPE_UINT32_END, | |||
| // paramater of type float | |||
| // parameter of type float | |||
| MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, | |||
| MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, | |||
| MS_CTX_TYPE_FLOAT_END, | |||
| // paramater of type string | |||
| // parameter of type string | |||
| MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, | |||
| MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, | |||
| MS_CTX_GRAPH_MEMORY_MAX_SIZE, | |||