Browse Source

trt op factory

pull/14291/head
wilfChen 4 years ago
parent
commit
dc69dc7ddd
5 changed files with 150 additions and 4 deletions
  1. +1
    -1
      mindspore/ccsrc/CMakeLists.txt
  2. +9
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc
  3. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.h
  4. +61
    -0
      mindspore/ccsrc/backend/optimizer/trt_pass/layer_input.h
  5. +78
    -0
      mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h

+ 1
- 1
mindspore/ccsrc/CMakeLists.txt View File

@@ -123,7 +123,7 @@ if(ENABLE_GPU)
set(ENABLE_GPU_INFER TRUE)
add_compile_definitions(ENABLE_GPU_INFER)
include_directories($ENV{TENSORRT_HOME}/include)
file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "runtime/device/gpu/trt_loader.cc")
list(APPEND GPU_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/runtime/device/gpu/trt_loader.cc)
endif()

set(NVCC_TMP_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})


+ 9
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.cc View File

@@ -51,7 +51,7 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {

auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
if (!trt_loader.nvinfer_loaded()) {
MS_LOG(EXCEPTION) << "Install Tensor-RT and export LD_LIBRARY_PATH=${TENSORRT_HOME}/lib:$LD_LIBRARY_PATH."
MS_LOG(EXCEPTION) << "Install Tensor-RT and export LD_LIBRARY_PATH=${TENSORRT_HOME}/lib:$LD_LIBRARY_PATH.";
}
runtime_ = trt_loader.CreateInferRuntime(&Singleton<TrtLogger>::Instance());
MS_EXCEPTION_IF_NULL(runtime_);
@@ -68,6 +68,13 @@ bool TrtKernel::Init(const CNodePtr &kernel_node) {
return true;
}

TrtKernel::ReleaseResource() {
// Make sure destroy trt object before TrtLoader destruct.
context_.reset();
engine_.reset();
runtime_.reset();
}

bool TrtKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream) {
MS_EXCEPTION_IF_NULL(context_);
@@ -76,8 +83,7 @@ bool TrtKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<
[](const AddressPtr &input) -> void * { return input->addr; });
std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(device_buffer),
[](const AddressPtr &output) -> void * { return output->addr; });
context_->enqueue(1, device_buffer.data(), reinterpret_cast<cudaStream_t>(stream), nullptr);
return true;
return context_->enqueueV2(device_buffer.data(), reinterpret_cast<cudaStream_t>(stream), nullptr);
}
} // namespace kernel
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_kernel.h View File

@@ -38,6 +38,7 @@ class TrtKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
void InitSizeLists() override{};
void ReleaseResource() override;

private:
std::string serialize_;


+ 61
- 0
mindspore/ccsrc/backend/optimizer/trt_pass/layer_input.h View File

@@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_

#include <NvInfer.h>

namespace mindspore::opt {
// Tensor-RT layer inputs include weight or tensor.
// Tensor: Anf-graph inputs or feature map which values change during inference.
// Weight: Anf-graph inputs or value node which remain unchanged during inference.
class LayerInput {
public:
LayerInput() : type_(InputType::kUnknown), weight_(), tensor_(nullptr) {}
explicit LayerInput(nvinfer1::Weights &w) : type_(InputType::kWeight), weight_(w), tensor_(nullptr) {}
explicit LayerInput(nvinfer1::ITensor *t) : type_(InputType::kTensor), weight_(), tensor_(t) {}

bool IsTensor() const { return type_ == InputType::kTensor; }
bool IsWeight() const { return type_ == InputType::kWeight; }

nvinfer1::Weights *weight() {
if (!IsWeight()) {
MS_LOG(WARNING) << "weight not initialized.";
return nullptr;
}
return &weight_;
}

nvinfer1::ITensor *tensor() const {
if (!IsTensor()) {
MS_LOG(WARNING) << "tensor not initialized.";
return nullptr;
}
return tensor_;
}

private:
enum class InputType : char { kUnknown = 0, kTensor, kWeight };
InputType type_;
// Keep the copy rather than point cause Weights created as a local variable.
nvinfer1::Weights weight_;
// Keep the point as ITensor created/held by nvinfer1::INetworkDefinition.
nvinfer1::ITensor *tensor_;
};
} // namespace mindspore::opt

#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_LAYER_INPUT_H_

+ 78
- 0
mindspore/ccsrc/backend/optimizer/trt_pass/trt_op_factory.h View File

@@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_
#define MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_

#include <functional>
#include <unordered_map>
#include <vector>
#include <utility>
#include <string>
#include <memory>
#include "base/base.h"
#include "ir/anf.h"

namespace mindspore {
namespace opt {
class LayerInput;
class TrtConverterHelper;
using ConvertResult = std::pair<bool, std::vector<LayerInput>>;
using ConvertFunc = std::function<ConvertResult(AnfNodePtr, std::shared_ptr<TrtConverterHelper>)>;

class TrtOpFactory {
public:
static TrtOpFactory &GetInstance() {
static TrtOpFactory instance;
return instance;
}

void Register(const std::string &op_name, const ConvertFunc &func) {
if (op_convert_map_.count(op_name)) {
MS_LOG(EXCEPTION) << "Operator: " << op_name << " re-registered.";
}
op_convert_map_.insert(std::make_pair(op_name, func));
}

ConvertFunc GetConvertFunc(const std::string &op_name) const {
auto iter = op_convert_map_.find(op_name);
if (iter == op_convert_map_.end()) {
MS_LOG(EXCEPTION) << "Operator: " << op_name << " not support.";
}
return iter->second;
}

private:
TrtOpFactory() = default;
~TrtOpFactory() = default;
DISABLE_COPY_AND_ASSIGN(TrtOpFactory)

std::unordered_map<std::string, ConvertFunc> op_convert_map_;
};

class TrtOpRegister {
public:
TrtOpRegister(const std::string &op_name, ConvertFunc func) { TrtOpFactory::GetInstance().Register(op_name, func); }
};

// Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
#define MS_TRT_CONVERTER_FUNC_REG(OPNAME) \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context); \
static const TrtOpRegister(Gpu##OPNAME##ConverterRegister)(#OPNAME, Gpu##OPNAME##TrtConverter); \
ConvertResult Gpu##OPNAME##TrtConverter(AnfNodePtr node, std::shared_ptr<TrtConverterHelper> context)
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTITIMIZER_TRT_PASS_OP_FACTORY_H_

Loading…
Cancel
Save