diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 6eeb1f0cc7..95ccf58cb8 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -74,6 +74,13 @@ if(ENABLE_GPU)
         endif()
     endif()
 
+    if(DEFINED ENV{TENSORRT_HOME} AND NOT $ENV{TENSORRT_HOME} STREQUAL "")
+        message("Enable GPU inference. Tensor-RT dir: $ENV{TENSORRT_HOME}")
+        set(ENABLE_GPU_INFER TRUE)
+        add_compile_definitions(ENABLE_GPU_INFER)
+        include_directories($ENV{TENSORRT_HOME}/include)
+    endif()
+
     if(NOT CUPTI_INCLUDE_DIRS OR CUPTI_INCLUDE_DIRS STREQUAL "")
         set(CUPTI_INCLUDE_DIRS ${CUDA_PATH}/extras/CUPTI/include)
     endif()
@@ -292,6 +299,14 @@ else()
     target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf -Wl,--end-group)
 endif()
 
+if(ENABLE_GPU_INFER)
+    find_library(trt_plugin libnvinfer_plugin.so $ENV{TENSORRT_HOME}/lib)
+    find_library(trt_nvinfo libnvinfer.so $ENV{TENSORRT_HOME}/lib)
+    find_library(trt_parser libnvparsers.so $ENV{TENSORRT_HOME}/lib)
+    target_link_libraries(mindspore ${trt_plugin} ${trt_nvinfo} ${trt_parser})
+    set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:$ENV{TENSORRT_HOME}/lib)
+endif()
+
 # set c_expression building
 set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
 set_property(SOURCE "pipeline/jit/init.cc" PROPERTY
diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
index 3fc2c6bf8a..11173c16b7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
+++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
@@ -84,6 +84,7 @@ if(ENABLE_GPU)
     list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_collective_gpu_kernel.cc")
     list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_send_gpu_kernel.cc")
     list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_recv_gpu_kernel.cc")
+    list(REMOVE_ITEM GPU_SRC_LIST "gpu/trt/trt_kernel.cc")
 
     if(ENABLE_MPI)
         include(ExternalProject)
@@ -91,6 +92,11 @@ if(ENABLE_GPU)
         list(APPEND GPU_SRC_LIST ${GPU_NCCL_LIST})
     endif()
 
+    if(ENABLE_GPU_INFER)
+        file(GLOB_RECURSE GPU_TRT_KERNEL_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/trt/*.cc")
+        list(APPEND GPU_SRC_LIST ${GPU_TRT_KERNEL_LIST})
+    endif()
+
     # add_library(_mindspore_kernel_cuda_obj OBJECT ${CUDA_SRC_LIST})
 endif()
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc
new file mode 100644
index 0000000000..7415641442
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/gpu/trt/trt_kernel.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <string>
+#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
+
+namespace mindspore {
+namespace kernel {
+const std::vector<size_t> &TrtKernel::GetInputSizeList() const { return input_size_list_; }
+const std::vector<size_t> &TrtKernel::GetOutputSizeList() const { return output_size_list_; }
+const std::vector<size_t> &TrtKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+bool TrtKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num; i++) {
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
+    auto type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, i);
+    size_t unit_size = UnitSizeInBytes(type_id);
+    auto size_in_byte = std::accumulate(input_shape.begin(), input_shape.end(), unit_size, std::multiplies<size_t>());
+    input_size_list_.push_back(size_in_byte);
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t j = 0; j < output_num; j++) {
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, j);
+    auto type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, j);
+    size_t unit_size = UnitSizeInBytes(type_id);
+    auto size_in_byte = std::accumulate(output_shape.begin(), output_shape.end(), unit_size, std::multiplies<size_t>());
+    output_size_list_.push_back(size_in_byte);
+  }
+
+  runtime_ = TrtPtr(nvinfer1::createInferRuntime(Singleton<TrtLogger>::Instance()));
+  MS_EXCEPTION_IF_NULL(runtime_);
+  serialize_ = GetAttr<std::string>(kernel_node, "serialize_model");
+  engine_ = TrtPtr(runtime_->deserializeCudaEngine(serialize_.c_str(), serialize_.size(), nullptr));
+  MS_EXCEPTION_IF_NULL(engine_);
+  if (SizeToInt(input_num + output_num) != engine_->getNbBindings()) {
+    MS_LOG(EXCEPTION) << "Input and output number does not match engine bindings. Got: " << input_num + output_num
+                      << ", expect: " << engine_->getNbBindings();
+  }
+
+  context_ = TrtPtr(engine_->createExecutionContext());
+  MS_EXCEPTION_IF_NULL(context_);
+  return true;
+}
+
+bool TrtKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                       const std::vector<AddressPtr> &outputs, void *stream) {
+  MS_EXCEPTION_IF_NULL(context_);
+  // Bindings follow the engine convention: all inputs first, then all outputs.
+  std::vector<void *> device_buffer;
+  std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(device_buffer),
+                 [](const AddressPtr &input) -> void * { return input->addr; });
+  std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(device_buffer),
+                 [](const AddressPtr &output) -> void * { return output->addr; });
+  return context_->enqueue(1, device_buffer.data(), reinterpret_cast<cudaStream_t>(stream), nullptr);
+}
+}  // namespace kernel
+}  // namespace mindspore
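A note on TrtKernel::Init above: each device buffer size is computed by folding the tensor shape with std::accumulate, seeded with the element width in bytes. A minimal standalone sketch of that computation (the NCHW shape and the float32 width below are invented values for illustration):

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical NCHW shape and float32 element width (4 bytes).
  std::vector<size_t> shape = {32, 3, 224, 224};
  size_t unit_size = 4;
  // Same fold as in Init: seed with the element width, then multiply by every dimension.
  size_t size_in_byte = std::accumulate(shape.begin(), shape.end(), unit_size, std::multiplies<size_t>());
  std::cout << size_in_byte << " bytes" << std::endl;  // 32 * 3 * 224 * 224 * 4 = 19267584
  return 0;
}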
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.h
new file mode 100644
index 0000000000..f1e55935eb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_KERNEL_H_
+
+#include <NvInfer.h>
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class TrtKernel : public GpuKernel {
+ public:
+  TrtKernel() : serialize_(""), runtime_(nullptr), engine_(nullptr), context_(nullptr) {}
+  ~TrtKernel() = default;
+
+  bool Init(const CNodePtr &kernel_node) override;
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  void InitSizeLists() override {}
+
+ private:
+  std::string serialize_;
+  std::shared_ptr<nvinfer1::IRuntime> runtime_;
+  std::shared_ptr<nvinfer1::ICudaEngine> engine_;
+  std::shared_ptr<nvinfer1::IExecutionContext> context_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+
+MS_REG_GPU_KERNEL(TrtNode, TrtKernel)
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h
new file mode 100644
index 0000000000..2b0f585568
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h
@@ -0,0 +1,145 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
+
+#include <NvInfer.h>
+#include <algorithm>
+#include <cstdlib>
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+#include "utils/log_adapter.h"
+#include "utils/singleton.h"
+#include "utils/convert_utils_base.h"
+
+namespace mindspore {
+class TrtUtils {
+ public:
+  static TypeId TrtDtypeToMsDtype(const nvinfer1::DataType &trt_dtype) {
+    static std::map<nvinfer1::DataType, TypeId> type_list = {
+      {nvinfer1::DataType::kFLOAT, TypeId::kNumberTypeFloat32},
+      {nvinfer1::DataType::kHALF, TypeId::kNumberTypeFloat16},
+      {nvinfer1::DataType::kINT8, TypeId::kNumberTypeInt8},
+      {nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt}};
+
+    auto iter = type_list.find(trt_dtype);
+    if (iter == type_list.end()) {
+      MS_LOG(EXCEPTION) << "Invalid Tensor-RT dtype: " << static_cast<int>(trt_dtype);
+    }
+    return iter->second;
+  }
+
+  static nvinfer1::DataType MsDtypeToTrtDtype(const TypeId &ms_dtype) {
+    static std::map<TypeId, nvinfer1::DataType> type_list = {
+      {TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
+      {TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
+      {TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},
+      {TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32}};
+    auto iter = type_list.find(ms_dtype);
+    if (iter == type_list.end()) {
+      MS_LOG(EXCEPTION) << "Unsupported data type: " << ms_dtype;
+    }
+    return iter->second;
+  }
+
+  static nvinfer1::Dims MsDimsToTrtDims(const std::vector<size_t> &ms_shape, bool ignore_batch_dim = false) {
+    nvinfer1::Dims trt_dims;
+    size_t offset = ignore_batch_dim ? 1 : 0;
+    for (size_t i = offset; i < ms_shape.size(); ++i) {
+      trt_dims.d[i - offset] = SizeToInt(ms_shape[i]);
+    }
+    trt_dims.nbDims = ms_shape.size() - offset;
+    return trt_dims;
+  }
+
+  // Overload for ShapeVector (int64_t dimensions).
+  static nvinfer1::Dims MsDimsToTrtDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
+    nvinfer1::Dims trt_dims;
+    size_t offset = ignore_batch_dim ? 1 : 0;
+    for (size_t i = offset; i < ms_shape.size(); ++i) {
+      trt_dims.d[i - offset] = LongToInt(ms_shape[i]);
+    }
+    trt_dims.nbDims = ms_shape.size() - offset;
+    return trt_dims;
+  }
+
+  static ShapeVector TrtDimsToMsDims(const nvinfer1::Dims &trt_dims) {
+    ShapeVector shape;
+    std::transform(trt_dims.d, trt_dims.d + trt_dims.nbDims, std::back_inserter(shape),
+                   [](const int32_t &value) { return static_cast<int64_t>(value); });
+    return shape;
+  }
+};
+
+class TrtLogger : public nvinfer1::ILogger {
+ public:
+  TrtLogger() {
+    log_level_ = MsLogLevel::WARNING;  // set default log level to WARNING
+    const char *glog_config = std::getenv("GLOG_v");
+    if (glog_config == nullptr) {
+      return;
+    }
+
+    std::string str_level{glog_config};
+    if (str_level.size() == 1) {
+      int ch = str_level.c_str()[0];
+      ch = ch - '0';  // subtract ASCII code of '0', which is 48
+      if (ch >= mindspore::DEBUG && ch <= mindspore::ERROR) {
+        log_level_ = static_cast<MsLogLevel>(ch);
+      }
+    }
+  }
+
+  // Redirect Tensor-RT internal logging to GLOG.
+  void log(Severity severity, const char *msg) override {
+#ifdef USE_GLOG
+    static std::map<Severity, std::tuple<MsLogLevel, google::LogSeverity, std::string>> logger_map = {
+      {Severity::kVERBOSE, {MsLogLevel::DEBUG, google::GLOG_INFO, "VERBOSE"}},
+      {Severity::kINFO, {MsLogLevel::INFO, google::GLOG_INFO, "INFO"}},
+      {Severity::kWARNING, {MsLogLevel::WARNING, google::GLOG_WARNING, "WARNING"}},
+      {Severity::kERROR, {MsLogLevel::ERROR, google::GLOG_ERROR, "ERROR"}},
+      {Severity::kINTERNAL_ERROR, {MsLogLevel::ERROR, google::GLOG_ERROR, "INTERNAL ERROR"}}};
+
+    auto iter = logger_map.find(severity);
+    if (iter == logger_map.end()) {
+      google::LogMessage("", 0, google::GLOG_WARNING).stream() << "Unrecognized severity type: " << msg << std::endl;
+      return;
+    }
+
+    auto level = iter->second;
+    // Discard messages below the configured log level.
+    if (std::get<0>(level) < log_level_) {
+      return;
+    }
+
+    google::LogMessage("", 0, std::get<1>(level)).stream()
+      << "[TensorRT " << std::get<2>(level) << "] " << msg << std::endl;
+#endif  // USE_GLOG
+  }
+
+ private:
+  MsLogLevel log_level_;
+};
+
+// RAII wrapper to avoid TensorRT object leaks: the custom deleter calls destroy()
+// when the last shared_ptr reference is released.
+template <typename T>
+inline std::shared_ptr<T> TrtPtr(T *obj) {
+  return std::shared_ptr<T>(obj, [](T *obj) {
+    if (obj) obj->destroy();
+  });
+}
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
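The deserialize path in trt_kernel.cc composes these helpers: the process-wide TrtLogger comes from Singleton, and every TensorRT object is wrapped in TrtPtr so destroy() runs when the last reference drops. A hedged sketch of that composition, assuming a TensorRT 7-era API where IRuntime still exposes destroy() and assuming trt_utils.h is on the include path:

#include <NvInfer.h>
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"

void RuntimeSketch() {
  // One logger for the whole process, created on first use.
  auto &logger = mindspore::Singleton<mindspore::TrtLogger>::Instance();
  // TrtPtr deduces T = nvinfer1::IRuntime and installs a deleter that calls destroy().
  auto runtime = mindspore::TrtPtr(nvinfer1::createInferRuntime(logger));
  // No manual cleanup: the runtime is destroyed when `runtime` goes out of scope.
}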
diff --git a/mindspore/ccsrc/backend/optimizer/CMakeLists.txt b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt
index 22fd7a423b..d0067569c8 100644
--- a/mindspore/ccsrc/backend/optimizer/CMakeLists.txt
+++ b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt
@@ -21,9 +21,16 @@ if(ENABLE_GPU)
     list(APPEND _PREACTIVATE_SRC_LIST ${_GPU_SRC_LIST})
 endif()
 
+if(ENABLE_GPU_INFER)
+    file(GLOB_RECURSE GPU_SRC_TRT_PASS_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "trt_pass/*.cc")
+    list(APPEND _PREACTIVATE_SRC_LIST ${GPU_SRC_TRT_PASS_LIST})
+endif()
+
 if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
-    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -Wno-user-defined-warnings -Wno-inconsistent-missing-override -Wno-overloaded-virtual -Wno-unused-const-variable -Wno-pessimizing-move")
+    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -Wno-user-defined-warnings -Wno-inconsistent-missing-override
+        -Wno-overloaded-virtual -Wno-unused-const-variable -Wno-pessimizing-move")
 endif()
 
-set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
+set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS
+    SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT)
 add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST})
diff --git a/mindspore/ccsrc/cxx_api/CMakeLists.txt b/mindspore/ccsrc/cxx_api/CMakeLists.txt
index d5e4707593..5041dea655 100644
--- a/mindspore/ccsrc/cxx_api/CMakeLists.txt
+++ b/mindspore/ccsrc/cxx_api/CMakeLists.txt
@@ -45,7 +45,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
     target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
                           -Wl,-force_load mindspore -Wl,-noall_load proto_input mindspore_gvar mindspore::protobuf)
 else()
-    if(ENABLE_D OR ENABLE_ACL)
+    if(ENABLE_D OR ENABLE_ACL OR ENABLE_GPU)
         target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
                               -Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
     else()
diff --git a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
index ff7719a715..749a4bc08b 100644
--- a/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
+++ b/mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
@@ -245,6 +245,9 @@ std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
     void *data = nullptr;
     size_t data_size = tensor->Size();
     if (i < last_outputs_.size()) {
+      if (last_outputs_[i]->NeedSyncDeviceToHost()) {
+        last_outputs_[i]->data_sync(false);
+      }
       data = last_outputs_[i]->data_c();
       data_size = last_outputs_[i]->Size();
     }
diff --git a/mindspore/core/utils/singleton.h b/mindspore/core/utils/singleton.h
new file mode 100644
index 0000000000..79b9680aac
--- /dev/null
+++ b/mindspore/core/utils/singleton.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_UTILS_SINGLETON_H_
+#define MINDSPORE_CORE_UTILS_SINGLETON_H_
+
+namespace mindspore {
+template <typename T>
+class Singleton {
+ public:
+  explicit Singleton(T &&) = delete;
+  explicit Singleton(const T &) = delete;
+  void operator=(const T &) = delete;
+  // Thread-safe: C++11 guarantees the function-local static is initialized exactly once.
+  template <typename... _Args>
+  static T &Instance(_Args... args) {
+    static T instance(args...);
+    return instance;
+  }
+
+ protected:
+  Singleton() = default;
+  virtual ~Singleton() = default;
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_UTILS_SINGLETON_H_
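Instance() above is the classic Meyers singleton: C++11 guarantees the function-local static is constructed exactly once, even when several threads race on the first call, which is why the TrtLogger shared by all TrtKernel instances can be fetched without explicit locking. A small usage sketch with an invented Counter type:

#include <iostream>
#include "utils/singleton.h"

// Hypothetical type, used only to illustrate Singleton<T>::Instance.
class Counter {
 public:
  int Next() { return ++count_; }

 private:
  int count_ = 0;
};

int main() {
  // Both references point at the same function-local static instance.
  auto &a = mindspore::Singleton<Counter>::Instance();
  auto &b = mindspore::Singleton<Counter>::Instance();
  std::cout << a.Next() << " " << b.Next() << std::endl;  // prints "1 2"
  return 0;
}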