From d22a5976892297ded0a4d8e34bae11cb0f4225ed Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Tue, 14 Jul 2020 14:54:10 +0800
Subject: [PATCH] gpu fix addn bug and supported list bug

---
 .../ccsrc/device/gpu/kernel_info_setter.cc  |  3 ++-
 .../ccsrc/kernel/gpu/math/addn_gpu_kernel.h | 23 +++++++++++++------
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc
index 42e76e2483..f4367e4714 100644
--- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc
+++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc
@@ -88,10 +88,11 @@ std::string SupportedTypeList(const CNodePtr &kernel_node) {
       supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
     }
     supported_type_lists = supported_type_lists + supported_akg_type_list + "], out[";
+    supported_akg_type_list.clear();
     for (auto type : supported_akg_type_out) {
       supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
     }
-    supported_type_lists += "]; ";
+    supported_type_lists = supported_type_lists + supported_akg_type_list + "]; ";
   }
   return supported_type_lists;
 }
diff --git a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h
index 1498da777f..41930d3d7b 100644
--- a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h
@@ -21,6 +21,8 @@
 #include <vector>
 #include "kernel/gpu/gpu_kernel.h"
 #include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/math/broadcast_gpu_kernel.h"
+#include "kernel/gpu/cuda_impl/slice_impl.cuh"
 #include "kernel/gpu/kernel_constants.h"
 
 namespace mindspore {
@@ -43,18 +45,26 @@ class AddNGpuFwdKernel : public GpuKernel {
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *) override {
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     if (is_null_input_) {
       return true;
     }
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    if (cudnn_data_type_ == CUDNN_DATA_INT32) {
+      FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     const float alpha = 1;
     const float beta = 0;
     for (size_t i = 0; i < IntToSize(num_input_); i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
-                                                 &(i > 0 ? alpha : beta), input_descriptor_, output_addr),
-                                  "cudnnAddTensor failed");
+      if (cudnn_data_type_ == CUDNN_DATA_INT32) {
+        NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
+                                                   &(i > 0 ? alpha : beta), input_descriptor_, output_addr),
+                                    "cudnnAddTensor failed");
+      }
     }
     return true;
   }
@@ -100,9 +110,8 @@ class AddNGpuFwdKernel : public GpuKernel {
   }
   void InitSizeLists() override {
     if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast<size_t *>(&input_size_)),
-        "cudnnGetTensorSizeInBytes failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
+                                  "cudnnGetTensorSizeInBytes failed");
     }
     for (int i = 0; i < num_input_; i++) {
       input_size_list_.push_back(input_size_);
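
The int32 branch added above works around the fact that cudnnAddTensor does not handle this case (which is why the patch bypasses cuDNN whenever cudnn_data_type_ == CUDNN_DATA_INT32): the output buffer is first zero-filled (FillDeviceArray) and each input is then accumulated into it with an element-wise add (NoBroadcast with BROADCAST_TYPE_ADD). The following standalone C++ sketch is not part of the patch; it uses illustrative names and host-side std::vector in place of device buffers to show the same accumulation pattern:

// CPU analogue of the int32 AddN path introduced in this patch (illustrative only):
// zero-fill the output, then accumulate every input tensor into it element-wise.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Two "input tensors" of the same shape, flattened to 1-D.
  std::vector<std::vector<int32_t>> inputs = {{1, 2, 3}, {10, 20, 30}};
  std::vector<int32_t> output(3, 0);  // corresponds to FillDeviceArray(..., 0.0f, ...)

  for (const auto &input : inputs) {
    for (size_t i = 0; i < output.size(); ++i) {
      output[i] += input[i];  // corresponds to NoBroadcast(..., BROADCAST_TYPE_ADD, input, output, output, ...)
    }
  }

  for (int32_t v : output) {
    std::cout << v << " ";  // prints: 11 22 33
  }
  std::cout << std::endl;
  return 0;
}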