From f679568d86599b520ea7eb6bccceeb98760b821d Mon Sep 17 00:00:00 2001 From: linqingke Date: Tue, 21 Jul 2020 19:05:42 +0800 Subject: [PATCH] gpu ops code and test case. --- .../gpu/arrays/gathernd_gpu_kernel.cc | 33 ++++ .../gpu/arrays/gathernd_gpu_kernel.h | 162 ++++++++++++++++ .../gpu/arrays/scatter_nd_gpu_kernel.cc | 33 ++++ .../gpu/arrays/scatter_nd_gpu_kernel.h | 175 ++++++++++++++++++ .../gpu/cuda_impl/boundingbox_decode_impl.cu | 81 ++++++++ .../gpu/cuda_impl/boundingbox_decode_impl.cuh | 27 +++ .../gpu/cuda_impl/boundingbox_encode_impl.cu | 62 +++++++ .../gpu/cuda_impl/boundingbox_encode_impl.cuh | 26 +++ .../gpu/cuda_impl/broadcast_impl.cu | 40 +++- .../gpu/cuda_impl/broadcast_impl.cuh | 1 + .../kernel_compiler/gpu/cuda_impl/gathernd.cu | 65 +++++++ .../gpu/cuda_impl/gathernd.cuh | 26 +++ .../gpu/cuda_impl/scatter_nd.cu | 68 +++++++ .../gpu/cuda_impl/scatter_nd.cuh | 26 +++ .../kernel_compiler/gpu/cuda_impl/sgd_impl.cu | 57 ++++++ .../gpu/cuda_impl/sgd_impl.cuh | 25 +++ .../gpu/math/broadcast_gpu_kernel.cc | 8 + .../gpu/math/broadcast_gpu_kernel.h | 8 +- .../kernel_compiler/gpu/nn/sgd_gpu_kernel.cc | 32 ++++ .../kernel_compiler/gpu/nn/sgd_gpu_kernel.h | 88 +++++++++ .../other/boundingbox_decode_gpu_kernel.cc | 26 +++ .../gpu/other/boundingbox_decode_gpu_kernel.h | 152 +++++++++++++++ .../other/boundingbox_encode_gpu_kernel.cc | 26 +++ .../gpu/other/boundingbox_encode_gpu_kernel.h | 143 ++++++++++++++ .../st/ops/gpu/test_boundingbox_decode_op.py | 60 ++++++ .../st/ops/gpu/test_boundingbox_encode_op.py | 80 ++++++++ tests/st/ops/gpu/test_floordiv_op.py | 116 ++++++++++++ tests/st/ops/gpu/test_gathernd_op.py | 151 +++++++++++++++ tests/st/ops/gpu/test_scatter_nd.py | 50 +++++ tests/st/ops/gpu/test_sgd_op.py | 73 ++++++++ 30 files changed, 1908 insertions(+), 12 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h create mode 100644 
mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_boundingbox_decode_op.py create mode 100644 tests/st/ops/gpu/test_boundingbox_encode_op.py create mode 100644 tests/st/ops/gpu/test_floordiv_op.py create mode 100644 tests/st/ops/gpu/test_gathernd_op.py create mode 100644 tests/st/ops/gpu/test_scatter_nd.py create mode 100644 tests/st/ops/gpu/test_sgd_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc new file mode 100644 index 0000000000..38f168a9b7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherNdGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + GatherNdGpuFwdKernel, half, int) +MS_REG_GPU_KERNEL_TWO( + GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherNdGpuFwdKernel, int, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h new file mode 100644 index 0000000000..af1efb84f6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h @@ -0,0 +1,162 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_GATHERND_GPU_KERNEL_H
+#define MINDSPORE_GATHERND_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class GatherNdGpuFwdKernel : public GpuKernel {
+ public:
+  GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr) {}
+  ~GatherNdGpuFwdKernel() {
+    if (dev_batch_strides_ != nullptr) {
+      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_strides_));
+    }
+    if (dev_batch_indices_ != nullptr) {
+      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_indices_));
+    }
+  }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    // note: dev_batch_strides_ holds the per-dimension bounds and dev_batch_indices_
+    // the per-dimension strides; the CUDA entry point binds them to its
+    // batch_indices/batch_strides parameters respectively.
+    GatherNd(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], dev_batch_strides_,
+             dev_batch_indices_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but GatherNd needs 2 inputs.";
+    }
+    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    Reshape();
+
+    size_t dim_indices_last = dims_[dims_.size() - 1];
+    batch_strides_.resize(dim_indices_last, 0);
+    batch_indices_.resize(dim_indices_last, 0);
+
+    if (dim_indices_last > 0) {
+      batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1];
+      batch_indices_[dim_indices_last - 1] = dims_[1];
+      // keep the descending loop inside the guard: dim_indices_last - 1 would
+      // wrap around for an empty index dimension since the counter is unsigned
+      for (size_t i = dim_indices_last - 1; i > 0; --i) {
+        batch_strides_[i - 1] = input_shapes_[i - 1];
+        batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
+      }
+    }
+
+    size_t strides_len = sizeof(S) * batch_strides_.size();
+    void *dev_batch_strides_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(strides_len);
+    if (dev_batch_strides_work == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_strides_work, size: " << strides_len;
+    }
+    dev_batch_strides_ = static_cast<S *>(dev_batch_strides_work);
+
+    size_t indices_len = sizeof(S) * batch_indices_.size();
+    void *dev_batch_indices_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
+    if (dev_batch_indices_work == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_indices_work, size: " << indices_len;
+    }
+    dev_batch_indices_ = static_cast<S *>(dev_batch_indices_work);
+
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_strides_, &batch_strides_[0], strides_len, cudaMemcpyHostToDevice),
+                               "cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_indices_, &batch_indices_[0], indices_len, cudaMemcpyHostToDevice),
+                               "cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    size_t size = GetSize(input_shapes_);
+    input_size_list_.push_back(size);
+
+    size = GetSize(indices_shapes_);
+    input_size_list_.push_back(size);
+
+    size = GetSize(output_shapes_);
+    output_size_list_.push_back(size);
+  }
+
+ private:
+  void Reshape() {
+    size_t dim_of_indices = 1;
+    for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); i++) {
+      dim_of_indices *= indices_shapes_[i];
+    }
+
+    size_t dim_after_indices = 1;
+    size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)];
+    for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) {
+      dim_after_indices *= input_shapes_[i];
+    }
+    dims_.emplace_back(dim_of_indices);
+    dims_.emplace_back(dim_after_indices);
+    dims_.emplace_back(dim_indices_last);
+    return;
+  }
+  size_t GetSize(const std::vector<size_t> &shape) const {
+    if (shape.size() == 0) {
+      return 0;
+    }
+    size_t result = sizeof(T);
+    for (size_t i = 0; i < shape.size(); i++) {
+      result *= shape[i];
+    }
+    return result;
+  }
+
+  std::vector<size_t> input_shapes_;
+  std::vector<size_t> indices_shapes_;
+  std::vector<size_t> output_shapes_;
+
+  std::vector<size_t> dims_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  std::vector<S> batch_strides_;
+  std::vector<S> batch_indices_;
+
+  S *dev_batch_strides_;
+  S *dev_batch_indices_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_GATHERND_GPU_KERNEL_H
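For reference, a minimal Python sketch of driving the new kernel from the frontend, in the spirit of the added tests/st/ops/gpu/test_gathernd_op.py (whose body is not part of this hunk); the shapes and expected values below are illustrative assumptions, not taken from the test file:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class GatherNdNet(nn.Cell):
        def __init__(self):
            super(GatherNdNet, self).__init__()
            self.gathernd = P.GatherNd()

        def construct(self, x, indices):
            return self.gathernd(x, indices)

    x = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))  # [[0 1 2] [3 4 5]]
    indices = Tensor(np.array([[0, 1], [1, 2]], np.int32))     # gathers x[0][1] and x[1][2]
    print(GatherNdNet()(x, indices))                           # expected: [1. 5.]

The last dimension of indices indexes into x; the float32/int32 combination exercises the first registration above.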
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc
new file mode 100644
index 0000000000..3a9aa6e075
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ScatterNdGpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd,
+  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ScatterNdGpuFwdKernel, half, int)
+MS_REG_GPU_KERNEL_TWO(
+  ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ScatterNdGpuFwdKernel, int, int)
+}  // namespace kernel
+}  // namespace mindspore
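A matching frontend sketch for ScatterNd (hypothetical values; the real test lives in the added tests/st/ops/gpu/test_scatter_nd.py, not part of this hunk). The indices-first input order mirrors the registrations above, and the output shape is supplied as a tuple that reaches the kernel as its "shape" attribute:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class ScatterNdNet(nn.Cell):
        def __init__(self):
            super(ScatterNdNet, self).__init__()
            self.scatternd = P.ScatterNd()

        def construct(self, indices, update):
            # scatter the updates into a zero-initialized (3, 3) output
            return self.scatternd(indices, update, (3, 3))

    indices = Tensor(np.array([[0, 1], [1, 1]], np.int32))
    update = Tensor(np.array([3.2, 1.1], np.float32))
    print(ScatterNdNet()(indices, update))
    # expected:
    # [[0.  3.2 0. ]
    #  [0.  1.1 0. ]
    #  [0.  0.  0. ]]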
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
new file mode 100644
index 0000000000..29c229fbac
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
@@ -0,0 +1,175 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class ScatterNdGpuFwdKernel : public GpuKernel {
+ public:
+  ScatterNdGpuFwdKernel()
+      : input_size_(1),
+        indices_size_(1),
+        output_size_(1),
+        block_size_(1),
+        indices_stride_(nullptr),
+        work_shape_(nullptr),
+        indices_dim_0_(0),
+        indices_dim_1_(0) {}
+  ~ScatterNdGpuFwdKernel() {
+    if (indices_stride_ != nullptr) {
+      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
+    }
+    if (work_shape_ != nullptr) {
+      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
+    }
+  }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    S *indices = GetDeviceAddress<S>(inputs, 0);
+    T *update = GetDeviceAddress<T>(inputs, 1);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+
+    // ScatterNd semantics leave positions not named by indices at zero, so
+    // clear the output buffer before scattering the updates into it.
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemsetAsync(output, 0, output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemsetAsync failed in ScatterNdGpuFwdKernel::Launch.");
+
+    ScatterNd(indices, update, output, block_size_, input_size_, output_size_, indices_dim_0_, indices_dim_1_,
+              indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but ScatterNd needs 2 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but ScatterNd needs 1 output.";
+      return false;
+    }
+
+    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    vec_work_shape_ = GetAttr<std::vector<S>>(kernel_node, "shape");
+
+    GetSize();
+
+    size_t indices_len = sizeof(S) * vec_indices_stride_.size();
+    void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
+    if (indices_stride_work == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
+    }
+    indices_stride_ = static_cast<S *>(indices_stride_work);
+
+    size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
+    void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
+    if (work_shape_work == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
+    }
+    work_shape_ = static_cast<S *>(work_shape_work);
+
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemcpy(indices_stride_, &vec_indices_stride_[0], indices_len, cudaMemcpyHostToDevice),
+      "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice),
+                               "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
+    InitSizeLists();
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(indices_size_);
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    return;
+  }
+
+  void GetSize() {
+    indices_size_ = sizeof(S);
+    for (size_t i = 0; i < indices_shapes_.size(); i++) {
+      indices_size_ *= indices_shapes_[i];
+    }
+    input_size_ = sizeof(T);
+    for (size_t i = 0; i < input_shapes_.size(); i++) {
+      input_size_ *= input_shapes_[i];
+    }
+    output_size_ = sizeof(T);
+    for (size_t i = 0; i < output_shapes_.size(); i++) {
+      output_size_ *= output_shapes_[i];
+    }
+
+    // calculate indices dim 0/1
+    indices_dim_0_ = indices_shapes_[0];
+    indices_dim_1_ = indices_shapes_[1];
+
+    // calculate block_size
+    for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
+      block_size_ *= output_shapes_[i];
+    }
+
+    // calculate indices_stride
+    for (size_t i = 0; i < indices_dim_1_; i++) {
+      vec_indices_stride_.push_back(0);
+    }
+
+    vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
+
+    for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
+      vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
+    }
+  }
+
+ private:
+  std::vector<size_t> input_shapes_;
+  std::vector<size_t> indices_shapes_;
+  std::vector<size_t> output_shapes_;
+  std::vector<S> vec_indices_stride_;
+  std::vector<S> vec_work_shape_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  size_t input_size_;
+  size_t indices_size_;
+  size_t output_size_;
+  size_t block_size_;
+
+  S *indices_stride_;
+  S *work_shape_;
+  size_t indices_dim_0_;
+  size_t indices_dim_1_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu
new file mode 100644
index 
0000000000..38f84dc618 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cu @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh" + +template +__global__ void BoundingBoxDecodeKernel(const size_t size, const T *rois, const T *deltas, T *bboxes, const float m1, + const float m2, const float m3, const float m4, const float s1, const float s2, + const float s3, const float s4, const int max_height, const int max_width, + const float ratio_clip) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const size_t left_x = i * 4; + const size_t left_y = i * 4 + 1; + const size_t right_x = i * 4 + 2; + const size_t right_y = i * 4 + 3; + + T dx = deltas[left_x] * s1 + m1; + T dy = deltas[left_y] * s2 + m2; + T dw = deltas[right_x] * s3 + m3; + T dh = deltas[right_y] * s4 + m4; + + T max_ratio = abs(log(ratio_clip)); + + dw = dw > max_ratio ? max_ratio : (dw < (-max_ratio) ? (-max_ratio) : dw); + dh = dh > max_ratio ? max_ratio : (dh < (-max_ratio) ? (-max_ratio) : dh); + + T px = (rois[left_x] + rois[right_x]) * 0.5f; + T py = (rois[left_y] + rois[right_y]) * 0.5f; + T pw = rois[right_x] - rois[left_x] + 1.0f; + T ph = rois[right_y] - rois[left_y] + 1.0f; + + T gx = px + pw * dx; + T gy = py + ph * dy; + T gw = pw * exp(dw); + T gh = ph * exp(dh); + + T x1 = gx - gw * 0.5f + 0.5f; + T y1 = gy - gh * 0.5f + 0.5f; + T x2 = gx + gw * 0.5f - 0.5f; + T y2 = gy + gh * 0.5f - 0.5f; + + x1 = x1 > max_width ? max_width : (x1 < 0 ? 0 : x1); + y1 = y1 > max_height ? max_height : (y1 < 0 ? 0 : y1); + x2 = x2 > max_width ? max_width : (x2 < 0 ? 0 : x2); + y2 = y2 > max_height ? max_height : (y2 < 0 ? 
0 : y2); + + bboxes[left_x] = x1; + bboxes[left_y] = y1; + bboxes[right_x] = x2; + bboxes[right_y] = y2; + } +} + +template +void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2, + const float &m3, const float &m4, const float &s1, const float &s2, const float &s3, + const float &s4, const int &max_height, const int &max_width, const float &ratio_clip, + cudaStream_t cuda_stream) { + BoundingBoxDecodeKernel<<>>(size, rois, deltas, bboxes, m1, m2, m3, m4, + s1, s2, s3, s4, max_height, max_width, + ratio_clip); +} + +template void BoundingBoxDecode(const size_t size, const float *rois, const float *deltas, float *bboxes, + const float &m1, const float &m2, const float &m3, const float &m4, + const float &s1, const float &s2, const float &s3, const float &s4, + const int &max_height, const int &max_width, const float &ratio_clip, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh new file mode 100644 index 0000000000..ccd3914a1c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2, + const float &m3, const float &m4, const float &s1, const float &s2, const float &s3, + const float &s4, const int &max_height, const int &max_width, const float &ratio_clip, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu new file mode 100644 index 0000000000..cf0ee68ae0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cu @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh" + +template +__global__ void BoundingBoxEncodeKernel(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, + const float m1, const float m2, const float m3, const float m4, const float s1, + const float s2, const float s3, const float s4) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const size_t left_x = i * 4; + const size_t left_y = i * 4 + 1; + const size_t right_x = i * 4 + 2; + const size_t right_y = i * 4 + 3; + + T px = (anchor_box[left_x] + anchor_box[right_x]) * 0.5f; + T py = (anchor_box[left_y] + anchor_box[right_y]) * 0.5f; + T pw = anchor_box[right_x] - anchor_box[left_x] + 1.0f; + T ph = anchor_box[right_y] - anchor_box[left_y] + 1.0f; + + T gx = (groundtruth_box[left_x] + groundtruth_box[right_x]) * 0.5f; + T gy = (groundtruth_box[left_y] + groundtruth_box[right_y]) * 0.5f; + T gw = groundtruth_box[right_x] - groundtruth_box[left_x] + 1.0f; + T gh = groundtruth_box[right_y] - groundtruth_box[left_y] + 1.0f; + + T dx = (gx - px) / pw; + T dy = (gy - py) / ph; + T dw = log(gw / pw); + T dh = log(gh / ph); + + deltas[left_x] = (dx - m1) / s1; + deltas[left_y] = (dy - m2) / s2; + deltas[right_x] = (dw - m3) / s3; + deltas[right_y] = (dh - m4) / s4; + } +} + +template +void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1, + const float &m2, const float &m3, const float &m4, const float &s1, const float &s2, + const float &s3, const float &s4, cudaStream_t cuda_stream) { + BoundingBoxEncodeKernel<<>>(size, anchor_box, groundtruth_box, deltas, + m1, m2, m3, m4, s1, s2, s3, s4); +} + +template void BoundingBoxEncode(const size_t size, const float *anchor_box, const float *groundtruth_box, + float *deltas, const float &m1, const float &m2, const float &m3, + const float &m4, const float &s1, const float &s2, const float &s3, + const float &s4, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh new file mode 100644 index 0000000000..8ab810d7b9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1, + const float &m2, const float &m3, const float &m4, const float &s1, const float &s2, + const float &s3, const float &s4, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu index f5c88e7ebf..2c0e6f7905 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -69,6 +69,25 @@ struct AddFunc { __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); } }; +template +struct FloorDivFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return floor(static_cast(lhs / rhs)); } +}; + +template <> +struct FloorDivFunc { + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { + return __float2half(floor(__half2float(lhs)/ __half2float(rhs))); + } +}; + +template <> +struct FloorDivFunc { + // invalid branch + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } +}; + + template <> struct PowerFunc { // invalid branch @@ -77,6 +96,7 @@ struct PowerFunc { __device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + template __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, const int &r3, @@ -116,16 +136,19 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const output); case BROADCAST_TYPE_REALDIV: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); case BROADCAST_TYPE_MUL: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); case BROADCAST_TYPE_SUB: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); case BROADCAST_TYPE_ADD: return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); + output); + case BROADCAST_TYPE_FLOORDIV: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); } } @@ -167,6 +190,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const return NoBroadcastOperator>(nums, input0, input1, output); case BROADCAST_TYPE_ADD: return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_FLOORDIV: + return NoBroadcastOperator>(nums, input0, input1, output); } } @@ -195,7 +220,7 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) { int nums = o0 * o1 * o2 * o3; BroadcastToKernel<<>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr, - output_addr); + output_addr); } template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, @@ -226,9 +251,8 @@ template void NoBroadcast(const int &nums, 
enum BroadcastOpType op, const half * bool *output, cudaStream_t stream); template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, half *output, cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output, - cudaStream_t stream); - +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, + int *output, cudaStream_t stream); template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, const int &o2, const int &o3, const float *input_addr, float *output_addr, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh index 62a3baad0e..e81cc16e33 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -29,6 +29,7 @@ enum BroadcastOpType { BROADCAST_TYPE_MUL = 6, BROADCAST_TYPE_SUB = 7, BROADCAST_TYPE_ADD = 8, + BROADCAST_TYPE_FLOORDIV = 9, BROADCAST_TYPE_INVALID = 0xffffffff, }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu new file mode 100644 index 0000000000..3d02723218 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cu @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void GatherNdKernel(T *input, S *indices, T *output, const size_t output_dim0, const size_t output_dim1, + const size_t indices_dim1, S *batch_indices, S *batch_strides) { + int num = output_dim0 * output_dim1; + int i, j; + for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / output_dim1 % output_dim0; + j = write_index % output_dim1; + + bool out_of_bound = false; + int read_index = 0; + int indices_i = 0; + for (size_t k = 0; k < indices_dim1; k++) { + size_t ind = indices_dim1 * i + k; + indices_i = indices[ind]; + out_of_bound |= !(indices_i < batch_indices[k]); + read_index += indices_i * batch_strides[k]; + } + read_index += j; + + if (!out_of_bound) { + output[write_index] = input[read_index]; + } else { + output[write_index] = 0; + } + } + return; +} +template +void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, + const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream) { + int size = output_dim0 * output_dim1; + GatherNdKernel<<>>(input, indices, output, output_dim0, output_dim1, + indices_dim1, batch_indices, batch_strides); + return; +} + +template void GatherNd(float *input, int *indices, float *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); +template void GatherNd(half *input, int *indices, half *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); +template void GatherNd(int *input, int *indices, int *output, const size_t &output_dim0, + const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, + int *batch_strides, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh new file mode 100644 index 0000000000..c6cbbf7603 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_GATHERND_GPU_CU_H +#define MINDSPORE_GATHERND_GPU_CU_H + +#include "runtime/device/gpu/cuda_common.h" + +template +void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, + const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream); + +#endif // MINDSPORE_GATHERND_GPU_CU_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu new file mode 100644 index 0000000000..5f9672c41f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size, + const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, + S *indices_stride, S *work_shape) { + int i, j; + for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size; + read_index += blockDim.x * gridDim.x) { + int write_index = 0; + bool out_bound = false; + + i = read_index / block_size; + j = read_index % block_size; + + for (size_t k = 0; k < indices_dim_1; k++) { + S indices_i = indices[i * indices_dim_1 + k]; + out_bound |= indices_i >= work_shape[k]; + write_index += indices_i * indices_stride[k]; + } + + write_index += j; + out_bound |= write_index >= output_size; + + if (!out_bound) { + output[write_index] = update[read_index]; + } + } +} + +template +void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, + S *work_shape, cudaStream_t stream) { + ScatterNdKernel<<>>(indices, update, output, block_size, input_size, + output_size, indices_dim_0, indices_dim_1, + indices_stride, work_shape); + return; +} + +template void ScatterNd(int *indices, float *update, float *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); +template void ScatterNd(int *indices, half *update, half *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); +template void ScatterNd(int *indices, int *update, int *output, const size_t &block_size, + const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh new file mode 100644 index 0000000000..7573239743 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_SCATTER_ND_GPU_CU_H +#define MINDSPORE_SCATTER_ND_GPU_CU_H + +#include "runtime/device/gpu/cuda_common.h" + +template +void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, + S *work_shape, cudaStream_t stream); +#endif // MINDSPORE_SCATTER_ND_GPU_CU_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu new file mode 100644 index 0000000000..4c452b116d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh"
+
+template <typename T>
+__global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad,
+                          const T *momentum, const T *lr, T *param, T *accum, T *stat) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
+    T grad_new = grad[i];
+    if (weight_decay != static_cast<T>(0)) {
+      grad_new += param[i] * weight_decay;
+    }
+
+    if (momentum[0] != static_cast<T>(0)) {
+      // stat flags the first update of each element: it starts at one and is
+      // cleared once accum has been seeded with the first gradient.
+      if (stat[i] != static_cast<T>(0)) {
+        accum[i] = grad_new;
+        stat[i] = static_cast<T>(0);
+      } else {
+        accum[i] = accum[i] * momentum[0] + (1.0 - dampening) * grad_new;
+      }
+
+      if (nesterov) {
+        grad_new += accum[i] * momentum[0];
+      } else {
+        grad_new = accum[i];
+      }
+    }
+
+    param[i] -= lr[0] * grad_new;
+  }
+}
+
+template <typename T>
+void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum,
+         const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) {
+  SGDKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dampening, weight_decay, nesterov, grad, momentum,
+                                                               lr, param, accum, stat);
+}
+
+template void SGD(const int size, const float dampening, const float weight_decay, const bool nesterov,
+                  const float *lr, const float *momentum, const float *grad, float *param, float *accum, float *stat,
+                  cudaStream_t cuda_stream);
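A small frontend sketch of the update this kernel implements (illustrative values; the added tests/st/ops/gpu/test_sgd_op.py is not part of this hunk). dampening, weight_decay and nesterov are primitive attributes, while lr and momentum arrive as scalar tensors; stat starts at ones, so the first step copies the gradient into accum:

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    class SGDNet(nn.Cell):
        def __init__(self):
            super(SGDNet, self).__init__()
            self.sgd = P.SGD(dampening=0.0, weight_decay=0.0, nesterov=False)

        def construct(self, param, grad, lr, accum, momentum, stat):
            return self.sgd(param, grad, lr, accum, momentum, stat)

    param = Tensor(np.array([2.0, -0.5, 1.7, 4.0], np.float32))
    grad = Tensor(np.array([1.0, -1.0, 0.5, 2.0], np.float32))
    accum = Tensor(np.zeros(4, np.float32))
    stat = Tensor(np.ones(4, np.float32))
    lr = Tensor(0.01, mstype.float32)
    momentum = Tensor(0.9, mstype.float32)
    print(SGDNet()(param, grad, lr, accum, momentum, stat))
    # first step reduces to param - lr * grad: [1.99 -0.49 1.695 3.98]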
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh
new file mode 100644
index 0000000000..bc2fa3304d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum,
+         const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index 2881cb1251..f5fffc0a4b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -51,6 +51,10 @@ MS_REG_GPU_KERNEL_TWO(
   TensorAdd,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   BroadcastOpGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(
+  FloorDiv,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BroadcastOpGpuKernel, float, float)
 
 // fp16
 MS_REG_GPU_KERNEL_TWO(
@@ -85,6 +89,10 @@ MS_REG_GPU_KERNEL_TWO(
   TensorAdd,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   BroadcastOpGpuKernel, half, half)
+MS_REG_GPU_KERNEL_TWO(
+  FloorDiv,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  BroadcastOpGpuKernel, half, half)
 
 // int32
 MS_REG_GPU_KERNEL_TWO(
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
index b131aef58d..7cbc2f692e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
@@ -96,10 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
 
     static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
-      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
-      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
-      {"FloorDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
-      {"TensorAdd", BROADCAST_TYPE_ADD},
+      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
+      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
+      {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
+      {"FloorDiv", BROADCAST_TYPE_FLOORDIV},
     };
 
     auto iter = kBroadcastTypeMap.find(kernel_name);
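The two hunks above stop aliasing FloorDiv to the RealDiv functor and route it to the new BROADCAST_TYPE_FLOORDIV entry, so floor(x / y) is now computed instead of a plain division. A minimal sanity check (illustrative values; the added tests/st/ops/gpu/test_floordiv_op.py is not part of this hunk):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    floor_div = P.FloorDiv()
    x = Tensor(np.array([7.0, -7.0, 5.5], np.float32))
    y = Tensor(np.array([2.0, 2.0, 2.0], np.float32))
    print(floor_div(x, y))  # expected: [ 3. -4.  2.], i.e. floor(x / y) elementwise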
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc
new file mode 100644
index 0000000000..7b022699f7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.cc
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(SGD,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SGDGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h
new file mode 100644
index 0000000000..70a57cded0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh" +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class SGDGpuKernel : public GpuKernel { + public: + SGDGpuKernel() : size_(1), dampening_(0.0), weight_decay_(0.0), nesterov_(false) {} + ~SGDGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream) override { + T *param = GetDeviceAddress(inputs, 0); + T *grad = GetDeviceAddress(inputs, 1); + T *lr = GetDeviceAddress(inputs, 2); + T *accum = GetDeviceAddress(inputs, 3); + T *momentum = GetDeviceAddress(inputs, 4); + T *stat = GetDeviceAddress(inputs, 5); + + SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat, + reinterpret_cast(stream)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + dampening_ = GetAttr(kernel_node, "dampening"); + weight_decay_ = GetAttr(kernel_node, "weight_decay"); + nesterov_ = GetAttr(kernel_node, "nesterov"); + + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto &dim : input_shape) { + size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = size_ * sizeof(T); + input_size_list_.push_back(input_size); // parameter + input_size_list_.push_back(input_size); // gradient + input_size_list_.push_back(sizeof(T)); // lr + input_size_list_.push_back(input_size); // accum + input_size_list_.push_back(sizeof(T)); // momentum + input_size_list_.push_back(input_size); // stat + output_size_list_.push_back(input_size); + } + + private: + size_t size_; + float dampening_; + float weight_decay_; + bool nesterov_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc new file mode 100644 index 0000000000..d08b671241 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc
new file mode 100644
index 0000000000..d08b671241
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  BoundingBoxDecode,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BoundingBoxDecodeGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h
new file mode 100644
index 0000000000..0f1d9ac917
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class BoundingBoxDecodeGpuKernel : public GpuKernel {
+ public:
+  BoundingBoxDecodeGpuKernel() : rois_size_(0), deltas_size_(0), bboxes_size_(0), wh_ratio_clip_(0.016) {}
+
+  ~BoundingBoxDecodeGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *rois_addr = GetDeviceAddress<T>(inputs, 0);
+    T *deltas_addr = GetDeviceAddress<T>(inputs, 1);
+    T *bboxes_addr = GetDeviceAddress<T>(outputs, 0);
+
+    if (inputs[0]->size != inputs[1]->size) {
+      MS_LOG(ERROR) << "The size of rois (" << inputs[0]->size << ") must match the size of deltas ("
+                    << inputs[1]->size << ").";
+      return false;
+    }
+
+    const size_t coordinate = 4;
+    const size_t block_size = inputs[0]->size / sizeof(T);
+    if ((block_size % coordinate) != 0) {
+      MS_LOG(ERROR) << "The element count of the box input must be a multiple of 4.";
+      return false;
+    }
+
+    BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2],
+                      means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
+      return false;
+    }
+    rois_size_ = sizeof(T);
+    deltas_size_ = sizeof(T);
+    bboxes_size_ = sizeof(T);
+
+    auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < rois_shape.size(); i++) {
+      rois_size_ *= rois_shape[i];
+    }
+
+    auto deltas_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < deltas_shape.size(); i++) {
+      deltas_size_ *= deltas_shape[i];
+    }
+
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      bboxes_size_ *= output_shape[i];
+    }
+
+    InitSizeLists();
+
+    const size_t coordinate_size = 4;
+    auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
+    if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
+      means_ = GetAttr<std::vector<float>>(kernel_node, "means");
+    } else if (means->isa<FloatImm>()) {
+      float mean = GetAttr<float>(kernel_node, "means");
+      for (size_t i = 0; i < coordinate_size; i++) {
+        means_.emplace_back(mean);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
+    }
+
+    auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
+    if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
+      stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
+    } else if (stds->isa<FloatImm>()) {
+      float std = GetAttr<float>(kernel_node, "stds");
+      for (size_t i = 0; i < coordinate_size; i++) {
+        stds_.emplace_back(std);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "Attribute stds type is invalid.";
+    }
+
+    max_shape_ = GetAttr<std::vector<int>>(kernel_node, "max_shape");
+    wh_ratio_clip_ = GetAttr<float>(kernel_node, "wh_ratio_clip");
+
+    if (means_.size() < coordinate_size || stds_.size() < coordinate_size) {
+      MS_LOG(EXCEPTION) << "The size of means or stds is less than 4.";
+    }
+
+    if (max_shape_.size() < 2) {
+      MS_LOG(EXCEPTION) << "The size of max_shape is less than 2.";
+    }
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(rois_size_);
+    input_size_list_.push_back(deltas_size_);
+    output_size_list_.push_back(bboxes_size_);
+  }
+
+ private:
+  size_t rois_size_;
+  size_t deltas_size_;
+  size_t bboxes_size_;
+  std::vector<float> means_;
+  std::vector<float> stds_;
+  std::vector<int> max_shape_;
+  float wh_ratio_clip_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
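The decode arithmetic itself lives in boundingbox_decode_impl.cu. For reference, the following numpy sketch reproduces the expected values hard-coded in test_boundingbox_decode_op.py below; it assumes the mmdetection-style delta2bbox with the legacy +1 width/height convention, which is consistent with those values but is not taken verbatim from this patch:

    import numpy as np

    def delta2bbox_ref(rois, deltas, means, stds, max_shape, wh_ratio_clip=0.016):
        # Denormalize deltas, then clip dw/dh to |log(wh_ratio_clip)|.
        d = deltas * np.asarray(stds, np.float32) + np.asarray(means, np.float32)
        max_ratio = abs(np.log(wh_ratio_clip))
        dw = np.clip(d[..., 2], -max_ratio, max_ratio)
        dh = np.clip(d[..., 3], -max_ratio, max_ratio)
        # Proposal centers/sizes (+1 box convention).
        px = (rois[..., 0] + rois[..., 2]) * 0.5
        py = (rois[..., 1] + rois[..., 3]) * 0.5
        pw = rois[..., 2] - rois[..., 0] + 1.0
        ph = rois[..., 3] - rois[..., 1] + 1.0
        # Apply deltas and convert back to corners, clamped to max_shape.
        gw, gh = pw * np.exp(dw), ph * np.exp(dh)
        gx, gy = px + pw * d[..., 0], py + ph * d[..., 1]
        x1 = np.clip(gx - gw * 0.5 + 0.5, 0, max_shape[1] - 1)
        y1 = np.clip(gy - gh * 0.5 + 0.5, 0, max_shape[0] - 1)
        x2 = np.clip(gx + gw * 0.5 - 0.5, 0, max_shape[1] - 1)
        y2 = np.clip(gy + gh * 0.5 - 0.5, 0, max_shape[0] - 1)
        return np.stack([x1, y1, x2, y2], axis=-1)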
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc
new file mode 100644
index 0000000000..98ee8104e0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  BoundingBoxEncode,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BoundingBoxEncodeGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h
new file mode 100644
index 0000000000..564751cda4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h
@@ -0,0 +1,143 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class BoundingBoxEncodeGpuKernel : public GpuKernel {
+ public:
+  BoundingBoxEncodeGpuKernel() : anchor_size_(0), groundtruth_size_(0), deltas_size_(0) {}
+
+  ~BoundingBoxEncodeGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *anchor_addr = GetDeviceAddress<T>(inputs, 0);
+    T *groundtruth_addr = GetDeviceAddress<T>(inputs, 1);
+    T *deltas_addr = GetDeviceAddress<T>(outputs, 0);
+
+    if (inputs[0]->size != inputs[1]->size) {
+      MS_LOG(ERROR) << "The size of anchor boxes (" << inputs[0]->size << ") must match the size of groundtruth "
+                    << "boxes (" << inputs[1]->size << ").";
+      return false;
+    }
+
+    const size_t coordinate = 4;
+    const size_t block_size = inputs[0]->size / sizeof(T);
+    if ((block_size % coordinate) != 0) {
+      MS_LOG(ERROR) << "The element count of the box input must be a multiple of 4.";
+      return false;
+    }
+
+    BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], means_[1],
+                      means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3],
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
+      return false;
+    }
+    anchor_size_ = sizeof(T);
+    groundtruth_size_ = sizeof(T);
+    deltas_size_ = sizeof(T);
+
+    auto anchor_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < anchor_shape.size(); i++) {
+      anchor_size_ *= anchor_shape[i];
+    }
+
+    auto groundtruth_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < groundtruth_shape.size(); i++) {
+      groundtruth_size_ *= groundtruth_shape[i];
+    }
+
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      deltas_size_ *= output_shape[i];
+    }
+
+    InitSizeLists();
+
+    const size_t coordinate_size = 4;
+    auto means = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means");
+    if (means->isa<ValueTuple>() || means->isa<ValueList>()) {
+      means_ = GetAttr<std::vector<float>>(kernel_node, "means");
+    } else if (means->isa<FloatImm>()) {
+      float mean = GetAttr<float>(kernel_node, "means");
+      for (size_t i = 0; i < coordinate_size; i++) {
+        means_.emplace_back(mean);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
+    }
+
+    auto stds = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds");
+    if (stds->isa<ValueTuple>() || stds->isa<ValueList>()) {
+      stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
+    } else if (stds->isa<FloatImm>()) {
+      float std = GetAttr<float>(kernel_node, "stds");
+      for (size_t i = 0; i < coordinate_size; i++) {
+        stds_.emplace_back(std);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "Attribute stds type is invalid.";
+    }
+
+    if (means_.size() < coordinate_size || stds_.size() < coordinate_size) {
+      MS_LOG(EXCEPTION) << "The size of means or stds is less than 4.";
+    }
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(anchor_size_);
+    input_size_list_.push_back(groundtruth_size_);
+    output_size_list_.push_back(deltas_size_);
+  }
+
+ private:
+  size_t anchor_size_;
+  size_t groundtruth_size_;
+  size_t deltas_size_;
+  std::vector<float> means_;
+  std::vector<float> stds_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
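Both bounding-box kernels accept `means` and `stds` either as a 4-element tuple/list or as a single scalar that is broadcast to all four coordinates. In Python terms, the branching in the two Init() methods amounts to the following sketch (`normalize_box_attr` is a hypothetical helper, not part of this patch):

    def normalize_box_attr(value, n=4):
        # Mirrors the ValueTuple/ValueList vs. FloatImm branching in Init().
        if isinstance(value, (tuple, list)):
            return [float(v) for v in value]
        return [float(value)] * n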
diff --git a/tests/st/ops/gpu/test_boundingbox_decode_op.py b/tests/st/ops/gpu/test_boundingbox_decode_op.py
new file mode 100644
index 0000000000..8400ee02b9
--- /dev/null
+++ b/tests/st/ops/gpu/test_boundingbox_decode_op.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetBoundingBoxDecode(nn.Cell):
+    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
+        super(NetBoundingBoxDecode, self).__init__()
+        self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=means, stds=stds,
+                                          wh_ratio_clip=0.016)
+
+    def construct(self, anchor, groundtruth):
+        return self.decode(anchor, groundtruth)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_boundingbox_decode():
+    anchor = np.array([[4, 1, 2, 1], [2, 2, 2, 3]], np.float32)
+    deltas = np.array([[3, 1, 2, 2], [1, 2, 1, 4]], np.float32)
+    means = (0.1, 0.1, 0.2, 0.2)
+    stds = (2.0, 2.0, 3.0, 3.0)
+    anchor_box = Tensor(anchor, mindspore.float32)
+    deltas_box = Tensor(deltas, mindspore.float32)
+    expect_deltas = np.array([[28.6500, 0.0000, 0.0000, 33.8500],
+                              [0.0000, 0.0000, 15.8663, 72.7000]], np.float32)
+
+    error = np.ones(shape=[2, 4]) * 1.0e-4
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    boundingbox_decode = NetBoundingBoxDecode(means, stds)
+    output = boundingbox_decode(anchor_box, deltas_box)
+    diff = output.asnumpy() - expect_deltas
+    assert np.all(abs(diff) < error)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    boundingbox_decode = NetBoundingBoxDecode(means, stds)
+    output = boundingbox_decode(anchor_box, deltas_box)
+    diff = output.asnumpy() - expect_deltas
+    assert np.all(abs(diff) < error)
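The test above, like most tests in this patch, runs the same check once under GRAPH_MODE and once under PYNATIVE_MODE. If that pattern keeps recurring, it could be factored into a small helper along these lines (hypothetical sketch, not part of this patch; the tests below inline the pattern instead):

    import mindspore.context as context

    def run_in_both_modes(make_net, *tensor_args):
        # Runs a freshly built cell in graph and pynative mode, returning both outputs.
        outputs = []
        for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
            context.set_context(mode=mode, device_target='GPU')
            outputs.append(make_net()(*tensor_args).asnumpy())
        return outputs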
diff --git a/tests/st/ops/gpu/test_boundingbox_encode_op.py b/tests/st/ops/gpu/test_boundingbox_encode_op.py
new file mode 100644
index 0000000000..c34e0e0e8e
--- /dev/null
+++ b/tests/st/ops/gpu/test_boundingbox_encode_op.py
@@ -0,0 +1,80 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetBoundingBoxEncode(nn.Cell):
+    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
+        super(NetBoundingBoxEncode, self).__init__()
+        self.encode = P.BoundingBoxEncode(means=means, stds=stds)
+
+    def construct(self, anchor, groundtruth):
+        return self.encode(anchor, groundtruth)
+
+def bbox2delta(proposals, gt, means, stds):
+    px = (proposals[..., 0] + proposals[..., 2]) * 0.5
+    py = (proposals[..., 1] + proposals[..., 3]) * 0.5
+    pw = proposals[..., 2] - proposals[..., 0] + 1.0
+    ph = proposals[..., 3] - proposals[..., 1] + 1.0
+
+    gx = (gt[..., 0] + gt[..., 2]) * 0.5
+    gy = (gt[..., 1] + gt[..., 3]) * 0.5
+    gw = gt[..., 2] - gt[..., 0] + 1.0
+    gh = gt[..., 3] - gt[..., 1] + 1.0
+
+    dx = (gx - px) / pw
+    dy = (gy - py) / ph
+    dw = np.log(gw / pw)
+    dh = np.log(gh / ph)
+    means = np.array(means, np.float32)
+    stds = np.array(stds, np.float32)
+    deltas = np.stack([(dx - means[0]) / stds[0], (dy - means[1]) / stds[1],
+                       (dw - means[2]) / stds[2], (dh - means[3]) / stds[3]], axis=-1)
+
+    return deltas
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_boundingbox_encode():
+    anchor = np.array([[4, 1, 6, 9], [2, 5, 5, 9]]).astype(np.float32)
+    gt = np.array([[3, 2, 7, 7], [1, 5, 5, 8]]).astype(np.float32)
+    means = (0.1, 0.1, 0.2, 0.2)
+    stds = (2.0, 2.0, 3.0, 3.0)
+    anchor_box = Tensor(anchor, mindspore.float32)
+    groundtruth_box = Tensor(gt, mindspore.float32)
+    expect_deltas = bbox2delta(anchor, gt, means, stds)
+
+    error = np.ones(shape=[2, 4]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    boundingbox_encode = NetBoundingBoxEncode(means, stds)
+    output = boundingbox_encode(anchor_box, groundtruth_box)
+    diff = output.asnumpy() - expect_deltas
+    assert np.all(abs(diff) < error)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    boundingbox_encode = NetBoundingBoxEncode(means, stds)
+    output = boundingbox_encode(anchor_box, groundtruth_box)
+    diff = output.asnumpy() - expect_deltas
+    assert np.all(abs(diff) < error)
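Working the first test row through the bbox2delta reference above by hand makes the +1.0 box convention concrete:

    # anchor [4, 1, 6, 9] -> px, py, pw, ph = 5.0, 5.0, 3.0, 9.0
    # gt     [3, 2, 7, 7] -> gx, gy, gw, gh = 5.0, 4.5, 5.0, 6.0
    # dx = (5.0 - 5.0) / 3.0  =  0.0
    # dy = (4.5 - 5.0) / 9.0  ~ -0.0556
    # dw = log(5.0 / 3.0)     ~  0.5108
    # dh = log(6.0 / 9.0)     ~ -0.4055
    # normalized with means (0.1, 0.1, 0.2, 0.2) and stds (2, 2, 3, 3):
    #   ~ [-0.05, -0.0778, 0.1036, -0.2018]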
diff --git a/tests/st/ops/gpu/test_floordiv_op.py b/tests/st/ops/gpu/test_floordiv_op.py
new file mode 100644
index 0000000000..dc7d76807f
--- /dev/null
+++ b/tests/st/ops/gpu/test_floordiv_op.py
@@ -0,0 +1,116 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+class NetFloorDiv(nn.Cell):
+    def __init__(self):
+        super(NetFloorDiv, self).__init__()
+        self.floordiv = P.FloorDiv()
+
+    def construct(self, x, y):
+        return self.floordiv(x, y)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_floor_div():
+    x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
+    y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
+    x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
+    y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
+    x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32)
+    y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
+    x3_np = np.random.randint(1, 5, 1).astype(np.float32)
+    y3_np = np.random.randint(1, 5, 1).astype(np.float32)
+    x4_np = np.array(768).astype(np.float32)
+    y4_np = np.array(3072.5).astype(np.float32)
+    x5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
+    y5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
+    x6_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32)
+    y6_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32)
+
+    x0 = Tensor(x0_np)
+    y0 = Tensor(y0_np)
+    x1 = Tensor(x1_np)
+    y1 = Tensor(y1_np)
+    x2 = Tensor(x2_np)
+    y2 = Tensor(y2_np)
+    x3 = Tensor(x3_np)
+    y3 = Tensor(y3_np)
+    x4 = Tensor(x4_np)
+    y4 = Tensor(y4_np)
+    x5 = Tensor(x5_np)
+    y5 = Tensor(y5_np)
+    x6 = Tensor(x6_np)
+    y6 = Tensor(y6_np)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    floor_div = NetFloorDiv()
+    output0 = floor_div(x0, y0)
+    expect0 = np.floor_divide(x0_np, y0_np)
+    diff0 = output0.asnumpy() - expect0
+    error0 = np.ones(shape=expect0.shape) * 1.0e-5
+    assert np.all(diff0 < error0)
+    assert output0.shape == expect0.shape
+
+    output1 = floor_div(x1, y1)
+    expect1 = np.floor_divide(x1_np, y1_np)
+    diff1 = output1.asnumpy() - expect1
+    error1 = np.ones(shape=expect1.shape) * 1.0e-5
+    assert np.all(diff1 < error1)
+    assert output1.shape == expect1.shape
+
+    output2 = floor_div(x2, y2)
+    expect2 = np.floor_divide(x2_np, y2_np)
+    diff2 = output2.asnumpy() - expect2
+    error2 = np.ones(shape=expect2.shape) * 1.0e-5
+    assert np.all(diff2 < error2)
+    assert output2.shape == expect2.shape
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    output3 = floor_div(x3, y3)
+    expect3 = np.floor_divide(x3_np, y3_np)
+    diff3 = output3.asnumpy() - expect3
+    error3 = np.ones(shape=expect3.shape) * 1.0e-5
+    assert np.all(diff3 < error3)
+    assert output3.shape == expect3.shape
+
+    output4 = floor_div(x4, y4)
+    expect4 = np.floor_divide(x4_np, y4_np)
+    diff4 = output4.asnumpy() - expect4
+    error4 = np.ones(shape=expect4.shape) * 1.0e-5
+    assert np.all(diff4 < error4)
+    assert output4.shape == expect4.shape
+
+    output5 = floor_div(x5, y5)
+    expect5 = np.floor_divide(x5_np, y5_np)
+    diff5 = output5.asnumpy() - expect5
+    error5 = np.ones(shape=expect5.shape) * 1.0e-5
+    assert np.all(diff5 < error5)
+    assert output5.shape == expect5.shape
+
+    output6 = floor_div(x6, y6)
+    expect6 = np.floor_divide(x6_np, y6_np)
+    diff6 = output6.asnumpy() - expect6
+    error6 = np.ones(shape=expect6.shape) * 1.0e-5
+    assert np.all(diff6 < error6)
+    assert output6.shape == expect6.shape
diff --git a/tests/st/ops/gpu/test_gathernd_op.py b/tests/st/ops/gpu/test_gathernd_op.py
new file mode 100644
index 0000000000..c901eb08f2
--- /dev/null
+++ b/tests/st/ops/gpu/test_gathernd_op.py
@@ -0,0 +1,151 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+import mindspore.context as context
+
+class GatherNdNet(nn.Cell):
+    def __init__(self):
+        super(GatherNdNet, self).__init__()
+        self.gathernd = P.GatherNd()
+
+    def construct(self, x, indices):
+        return self.gathernd(x, indices)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gathernd0():
+    x = Tensor(np.arange(3 * 2, dtype=np.float32).reshape(3, 2))
+    indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32))
+    expect = np.array([3., 1.])
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gathernd = GatherNdNet()
+    output = gathernd(x, indices)
+
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gathernd1():
+    x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5))
+    indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)]
+                                for k in range(3)] for l in range(2)], dtype='i4'))
+    expect = np.array([[[[1., 3., 4.],
+                         [6., 8., 9.],
+                         [11., 13., 14.],
+                         [16., 18., 19.]],
+
+                        [[21., 23., 24.],
+                         [26., 28., 29.],
+                         [31., 33., 34.],
+                         [36., 38., 39.]],
+
+                        [[41., 43., 44.],
+                         [46., 48., 49.],
+                         [51., 53., 54.],
+                         [56., 58., 59.]]],
+
+                       [[[61., 63., 64.],
+                         [66., 68., 69.],
+                         [71., 73., 74.],
+                         [76., 78., 79.]],
+
+                        [[81., 83., 84.],
+                         [86., 88., 89.],
+                         [91., 93., 94.],
+                         [96., 98., 99.]],
+
+                        [[101., 103., 104.],
+                         [106., 108., 109.],
+                         [111., 113., 114.],
+                         [116., 118., 119.]]]])
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gather = GatherNdNet()
+    output = gather(x, indices)
+
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gathernd2():
+    x = Tensor(np.array([[4., 5., 4., 1., 5.],
+                         [4., 9., 5., 6., 4.],
+                         [9., 8., 4., 3., 6.],
+                         [0., 4., 2., 2., 8.],
+                         [1., 8., 6., 2., 8.],
+                         [8., 1., 9., 7., 3.],
+                         [7., 9., 2., 5., 7.],
+                         [9., 8., 6., 8., 5.],
+                         [3., 7., 2., 7., 4.],
+                         [4., 2., 8., 2., 9.]]).astype(np.float16))
+
+    indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32))
+    expect = np.array([[0., 0., 0., 0., 0.],
+                       [4., 9., 5., 6., 4.],
+                       [0., 0., 0., 0., 0.]])
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gathernd = GatherNdNet()
+    output = gathernd(x, indices)
+
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_gathernd3():
+    x = Tensor(np.array([[4, 5, 4, 1, 5],
+                         [4, 9, 5, 6, 4],
+                         [9, 8, 4, 3, 6],
+                         [0, 4, 2, 2, 8],
+                         [1, 8, 6, 2, 8],
+                         [8, 1, 9, 7, 3],
+                         [7, 9, 2, 5, 7],
+                         [9, 8, 6, 8, 5],
+                         [3, 7, 2, 7, 4],
+                         [4, 2, 8, 2, 9]]).astype(np.int32))
+
+    indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32))
+    expect = np.array([[0, 0, 0, 0, 0],
+                       [4, 9, 5, 6, 4],
+                       [0, 0, 0, 0, 0]])
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    gathernd = GatherNdNet()
+    output = gathernd(x, indices)
+
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
diff --git a/tests/st/ops/gpu/test_scatter_nd.py b/tests/st/ops/gpu/test_scatter_nd.py
new file mode 100644
index 0000000000..b201c7be2c
--- /dev/null
+++ b/tests/st/ops/gpu/test_scatter_nd.py
@@ -0,0 +1,50 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+class Net(nn.Cell):
+    def __init__(self, _shape):
+        super(Net, self).__init__()
+        self.shape = _shape
+        self.scatternd = P.ScatterNd()
+
+    def construct(self, indices, update):
+        return self.scatternd(indices, update, self.shape)
+
+def scatternd_net(indices, update, _shape, expect):
+    scatternd = Net(_shape)
+    output = scatternd(Tensor(indices), Tensor(update))
+    error = np.ones(shape=output.asnumpy().shape) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatternd():
+    arr_indices = np.array([[0, 1], [1, 1]]).astype(np.int32)
+    arr_update = np.array([3.2, 1.1]).astype(np.float32)
+    shape = (2, 2)
+    expect = np.array([[0., 3.2],
+                       [0., 1.1]])
+    scatternd_net(arr_indices, arr_update, shape, expect)
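For readers unfamiliar with the nd-indexing semantics these tests exercise, here is a hedged numpy sketch of what GatherNd and ScatterNd are assumed to compute. Each row of `indices` addresses the leading dimensions of the data tensor; the out-of-range behavior (zeros) is inferred from test_gathernd2/test_gathernd3 above, and the duplicate-index accumulation in the scatter sketch is a TensorFlow-style assumption not exercised by test_scatternd:

    import numpy as np

    def gather_nd_ref(x, indices):
        # out[i...] = x[tuple(indices[i...])]; out-of-range rows yield zeros.
        out = np.zeros(indices.shape[:-1] + x.shape[indices.shape[-1]:], x.dtype)
        for idx in np.ndindex(*indices.shape[:-1]):
            key = tuple(indices[idx])
            if all(0 <= k < d for k, d in zip(key, x.shape)):
                out[idx] = x[key]
        return out

    def scatter_nd_ref(indices, updates, shape):
        # Roughly the inverse of gather_nd_ref into a zero tensor of `shape`.
        out = np.zeros(shape, updates.dtype)
        for i, key in enumerate(indices):
            out[tuple(key)] += updates[i]  # duplicates accumulate (assumed)
        return out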
diff --git a/tests/st/ops/gpu/test_sgd_op.py b/tests/st/ops/gpu/test_sgd_op.py
new file mode 100644
index 0000000000..85d470f50d
--- /dev/null
+++ b/tests/st/ops/gpu/test_sgd_op.py
@@ -0,0 +1,73 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.nn import Dense
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import SGD
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+class NetSGD(nn.Cell):
+    def __init__(self):
+        super(NetSGD, self).__init__()
+        self.batch_size = 1
+        self.reshape = P.Reshape()
+        weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
+        self.fc1 = Dense(16, 10, weight_init=weight)
+
+    def construct(self, input_x):
+        output = self.reshape(input_x, (self.batch_size, -1))
+        output = self.fc1(output)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_SGD():
+    epoch = 3
+    net = NetSGD()
+    learning_rate = 0.1
+    momentum = 0.9
+    dampening = 0.0
+    weight_decay = 0.0
+    nesterov = True
+    loss_scale = 1.0
+
+    optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening,
+                    weight_decay, nesterov, loss_scale)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(net_with_criterion, optimizer)
+    train_network.set_train()
+    losses = []
+    for _ in range(epoch):
+        data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
+        label = Tensor(np.array([0]).astype(np.int32))
+        loss = train_network(data, label)
+        losses.append(loss.asnumpy())
+
+    # The loss must strictly decrease over the three epochs.
+    last_loss = 100.0
+    for loss in losses:
+        assert last_loss > loss
+        last_loss = loss
+    return losses
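As a quick sanity check of the optimizer wiring against the numpy sketch given after sgd_gpu_kernel.h earlier in this patch description: under those assumed semantics, the first update of a parameter p = 1.0 with gradient 0.5, lr = 0.1, momentum = 0.9, dampening = 0.0, weight_decay = 0.0 and nesterov = True works out to

    # stat > 0 on the first step, so accum is seeded with the raw gradient:
    #   accum = 0.5
    # nesterov: effective grad = 0.5 + 0.9 * 0.5 = 0.95
    #   param = 1.0 - 0.1 * 0.95 = 0.905

which is the kind of hand-checkable step the decreasing-loss assertion above verifies end to end.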