| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| GatherNd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||||
| GatherNdGpuFwdKernel, float, int) | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| GatherNd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), | |||||
| GatherNdGpuFwdKernel, half, int) | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| GatherNdGpuFwdKernel, int, int) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,162 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_GATHERND_GPU_KERNEL_H | |||||
| #define MINDSPORE_GATHERND_GPU_KERNEL_H | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T, typename S> | |||||
| class GatherNdGpuFwdKernel : public GpuKernel { | |||||
| public: | |||||
| GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr) {} | |||||
| ~GatherNdGpuFwdKernel() { | |||||
| if (dev_batch_strides_ != nullptr) { | |||||
| device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_strides_)); | |||||
| } | |||||
| if (dev_batch_indices_ != nullptr) { | |||||
| device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_indices_)); | |||||
| } | |||||
| } | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| VARIABLE_NOT_USED(workspace); | |||||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| S *indices_addr = GetDeviceAddress<S>(inputs, 1); | |||||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| GatherNd(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], dev_batch_strides_, | |||||
| dev_batch_indices_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| InitResource(); | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 2) { | |||||
| MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherNdGpuFwdKernel needs 2."; | |||||
| } | |||||
| input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| Reshape(); | |||||
| size_t dim_indices_last = dims_[dims_.size() - 1]; | |||||
| batch_strides_.resize(dim_indices_last, 0); | |||||
| batch_indices_.resize(dim_indices_last, 0); | |||||
    if (dim_indices_last > 0) {
      batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1];
      batch_indices_[dim_indices_last - 1] = dims_[1];
      for (size_t i = dim_indices_last - 1; i > 0; --i) {
        batch_strides_[i - 1] = input_shapes_[i - 1];
        batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
      }
    }
| size_t strides_len = sizeof(S) * batch_strides_.size(); | |||||
| void *dev_batch_strides_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(strides_len); | |||||
| if (dev_batch_strides_work == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_strides_work, size: " << strides_len; | |||||
| } | |||||
| dev_batch_strides_ = static_cast<S *>(dev_batch_strides_work); | |||||
| size_t indices_len = sizeof(S) * batch_indices_.size(); | |||||
| void *dev_batch_indices_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len); | |||||
| if (dev_batch_indices_work == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_indices_work, size: " << indices_len; | |||||
| } | |||||
| dev_batch_indices_ = static_cast<S *>(dev_batch_indices_work); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_strides_, &batch_strides_[0], strides_len, cudaMemcpyHostToDevice), | |||||
| "cudaMemcpy failed in GatherNdGpuFwdKernel::Init."); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_indices_, &batch_indices_[0], indices_len, cudaMemcpyHostToDevice), | |||||
| "cudaMemcpy failed in GatherNdGpuFwdKernel::Init."); | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| size_t size = GetSize(input_shapes_); | |||||
| input_size_list_.push_back(size); | |||||
| size = GetSize(indices_shapes_); | |||||
| input_size_list_.push_back(size); | |||||
| size = GetSize(output_shapes_); | |||||
| output_size_list_.push_back(size); | |||||
| } | |||||
| private: | |||||
| void Reshape() { | |||||
| size_t dim_of_indices = 1; | |||||
| for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); i++) { | |||||
| dim_of_indices *= indices_shapes_[i]; | |||||
| } | |||||
| size_t dim_after_indices = 1; | |||||
| size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)]; | |||||
| for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) { | |||||
| dim_after_indices *= input_shapes_[i]; | |||||
| } | |||||
| dims_.emplace_back(dim_of_indices); | |||||
| dims_.emplace_back(dim_after_indices); | |||||
| dims_.emplace_back(dim_indices_last); | |||||
| return; | |||||
| } | |||||
| size_t GetSize(const std::vector<size_t> &shape) const { | |||||
| if (shape.size() == 0) { | |||||
| return 0; | |||||
| } | |||||
| size_t result = sizeof(T); | |||||
| for (size_t i = 0; i < shape.size(); i++) { | |||||
| result *= shape[i]; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| std::vector<size_t> input_shapes_; | |||||
| std::vector<size_t> indices_shapes_; | |||||
| std::vector<size_t> output_shapes_; | |||||
| std::vector<size_t> dims_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| std::vector<S> batch_strides_; | |||||
| std::vector<S> batch_indices_; | |||||
| S *dev_batch_strides_; | |||||
| S *dev_batch_indices_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_GATHERND_GPU_KERNEL_H | |||||
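For orientation, a minimal host-side sketch of the bookkeeping done by Init()/Reshape() above, with illustrative shapes only (input {2, 3, 4}, indices {5, 2}). On the host, batch_strides_ holds the per-dimension sizes used for the bounds check and batch_indices_ holds the flat element strides; the CUDA function receives the same two buffers under swapped parameter names.

// Host-side sketch of GatherNd shape bookkeeping (illustrative values, not part of the kernel).
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> input_shape = {2, 3, 4};
  std::vector<size_t> indices_shape = {5, 2};

  // dims_[0]: number of index rows, dims_[1]: size of each gathered slice, dims_[2]: index depth.
  size_t dim_of_indices = 1;
  for (size_t i = 0; i + 1 < indices_shape.size(); ++i) dim_of_indices *= indices_shape[i];
  size_t dim_indices_last = indices_shape.back();
  size_t dim_after_indices = 1;
  for (size_t i = dim_indices_last; i < input_shape.size(); ++i) dim_after_indices *= input_shape[i];

  // batch_strides[k]: size of input dimension k (bounds check); batch_indices[k]: flat stride of component k.
  std::vector<int> batch_strides(dim_indices_last, 0);
  std::vector<int> batch_indices(dim_indices_last, 0);
  if (dim_indices_last > 0) {
    batch_strides[dim_indices_last - 1] = input_shape[dim_indices_last - 1];
    batch_indices[dim_indices_last - 1] = dim_after_indices;
    for (size_t i = dim_indices_last - 1; i > 0; --i) {
      batch_strides[i - 1] = input_shape[i - 1];
      batch_indices[i - 1] = batch_indices[i] * input_shape[i];
    }
  }
  std::cout << dim_of_indices << " rows, slice size " << dim_after_indices << ", depth " << dim_indices_last << "\n";
  // Expected: 5 rows, slice size 4, depth 2; batch_strides = {2, 3}, batch_indices = {12, 4}.
}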
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| ScatterNd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ScatterNdGpuFwdKernel, float, int) | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| ScatterNd, | |||||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| ScatterNdGpuFwdKernel, half, int) | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| ScatterNd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ScatterNdGpuFwdKernel, int, int) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,175 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T, typename S> | |||||
| class ScatterNdGpuFwdKernel : public GpuKernel { | |||||
| public: | |||||
| ScatterNdGpuFwdKernel() | |||||
| : input_size_(1), | |||||
| indices_size_(1), | |||||
| output_size_(1), | |||||
| block_size_(1), | |||||
| indices_stride_(nullptr), | |||||
| work_shape_(nullptr), | |||||
| indices_dim_0_(0), | |||||
| indices_dim_1_(0) {} | |||||
| ~ScatterNdGpuFwdKernel() { | |||||
| if (indices_stride_ != nullptr) { | |||||
| device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_)); | |||||
| } | |||||
| if (work_shape_ != nullptr) { | |||||
| device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_)); | |||||
| } | |||||
| } | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| VARIABLE_NOT_USED(workspace); | |||||
| S *indices = GetDeviceAddress<S>(inputs, 0); | |||||
| T *update = GetDeviceAddress<T>(inputs, 1); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
    // input_size_ and output_size_ are byte counts; the kernel iterates over elements.
    ScatterNd(indices, update, output, block_size_, input_size_ / sizeof(T), output_size_ / sizeof(T),
              indices_dim_0_, indices_dim_1_, indices_stride_, work_shape_,
              reinterpret_cast<cudaStream_t>(stream_ptr));
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 2) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 2 input."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; | |||||
| return false; | |||||
| } | |||||
| input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| vec_work_shape_ = GetAttr<std::vector<S>>(kernel_node, "shape"); | |||||
| GetSize(); | |||||
| size_t indices_len = sizeof(S) * vec_indices_stride_.size(); | |||||
| void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len); | |||||
| if (indices_stride_work == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len; | |||||
| } | |||||
| indices_stride_ = static_cast<S *>(indices_stride_work); | |||||
| size_t vec_work_len = sizeof(S) * vec_work_shape_.size(); | |||||
| void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len); | |||||
| if (work_shape_work == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len; | |||||
| } | |||||
| work_shape_ = static_cast<S *>(work_shape_work); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT( | |||||
| cudaMemcpy(indices_stride_, &vec_indices_stride_[0], indices_len, cudaMemcpyHostToDevice), | |||||
| "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init."); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice), | |||||
| "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init."); | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(indices_size_); | |||||
| input_size_list_.push_back(input_size_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| return; | |||||
| } | |||||
| void GetSize() { | |||||
| indices_size_ = sizeof(S); | |||||
| for (size_t i = 0; i < indices_shapes_.size(); i++) { | |||||
| indices_size_ *= indices_shapes_[i]; | |||||
| } | |||||
| input_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < input_shapes_.size(); i++) { | |||||
| input_size_ *= input_shapes_[i]; | |||||
| } | |||||
| output_size_ = sizeof(T); | |||||
| for (size_t i = 0; i < output_shapes_.size(); i++) { | |||||
| output_size_ *= output_shapes_[i]; | |||||
| } | |||||
| // calculate indices dim 0/1 | |||||
| indices_dim_0_ = indices_shapes_[0]; | |||||
| indices_dim_1_ = indices_shapes_[1]; | |||||
| // calculate block_size | |||||
| for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) { | |||||
| block_size_ *= output_shapes_[i]; | |||||
| } | |||||
| // calculate indices_stride | |||||
| for (size_t i = 0; i < indices_dim_1_; i++) { | |||||
| vec_indices_stride_.push_back(0); | |||||
| } | |||||
| vec_indices_stride_[indices_dim_1_ - 1] = block_size_; | |||||
| for (size_t i = indices_dim_1_ - 1; i > 0; --i) { | |||||
| vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i]; | |||||
| } | |||||
| } | |||||
| private: | |||||
| std::vector<size_t> input_shapes_; | |||||
| std::vector<size_t> indices_shapes_; | |||||
| std::vector<size_t> output_shapes_; | |||||
| std::vector<S> vec_indices_stride_; | |||||
| std::vector<S> vec_work_shape_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| size_t input_size_; | |||||
| size_t indices_size_; | |||||
| size_t output_size_; | |||||
| size_t block_size_; | |||||
| S *indices_stride_; | |||||
| S *work_shape_; | |||||
| size_t indices_dim_0_; | |||||
| size_t indices_dim_1_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H | |||||
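As a reference for GetSize() above, a minimal host-side sketch with illustrative shapes only (indices {4, 2}, output {5, 6, 3}): block_size_ is the number of elements in one updated slice and vec_indices_stride_[k] is the flat stride of output dimension k.

// Host-side sketch of ScatterNd index bookkeeping (illustrative values, not part of the kernel).
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> indices_shape = {4, 2};  // 4 index rows, 2 index components per row
  std::vector<size_t> output_shape = {5, 6, 3};

  // block_size: elements in one updated slice (output dims not addressed by the index).
  size_t block_size = 1;
  for (size_t i = indices_shape[1]; i < output_shape.size(); ++i) block_size *= output_shape[i];

  // indices_stride[k]: flat stride (in elements) of output dimension k.
  std::vector<int> indices_stride(indices_shape[1], 0);
  indices_stride[indices_shape[1] - 1] = block_size;
  for (size_t i = indices_shape[1] - 1; i > 0; --i) {
    indices_stride[i - 1] = indices_stride[i] * output_shape[i];
  }
  std::cout << "block_size = " << block_size << "\n";                                      // 3
  std::cout << "strides = {" << indices_stride[0] << ", " << indices_stride[1] << "}\n";   // {18, 3}
}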
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void BoundingBoxDecodeKernel(const size_t size, const T *rois, const T *deltas, T *bboxes, const float m1, | |||||
| const float m2, const float m3, const float m4, const float s1, const float s2, | |||||
| const float s3, const float s4, const int max_height, const int max_width, | |||||
| const float ratio_clip) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||||
| const size_t left_x = i * 4; | |||||
| const size_t left_y = i * 4 + 1; | |||||
| const size_t right_x = i * 4 + 2; | |||||
| const size_t right_y = i * 4 + 3; | |||||
| T dx = deltas[left_x] * s1 + m1; | |||||
| T dy = deltas[left_y] * s2 + m2; | |||||
| T dw = deltas[right_x] * s3 + m3; | |||||
| T dh = deltas[right_y] * s4 + m4; | |||||
| T max_ratio = abs(log(ratio_clip)); | |||||
| dw = dw > max_ratio ? max_ratio : (dw < (-max_ratio) ? (-max_ratio) : dw); | |||||
| dh = dh > max_ratio ? max_ratio : (dh < (-max_ratio) ? (-max_ratio) : dh); | |||||
| T px = (rois[left_x] + rois[right_x]) * 0.5f; | |||||
| T py = (rois[left_y] + rois[right_y]) * 0.5f; | |||||
| T pw = rois[right_x] - rois[left_x] + 1.0f; | |||||
| T ph = rois[right_y] - rois[left_y] + 1.0f; | |||||
| T gx = px + pw * dx; | |||||
| T gy = py + ph * dy; | |||||
| T gw = pw * exp(dw); | |||||
| T gh = ph * exp(dh); | |||||
| T x1 = gx - gw * 0.5f + 0.5f; | |||||
| T y1 = gy - gh * 0.5f + 0.5f; | |||||
| T x2 = gx + gw * 0.5f - 0.5f; | |||||
| T y2 = gy + gh * 0.5f - 0.5f; | |||||
| x1 = x1 > max_width ? max_width : (x1 < 0 ? 0 : x1); | |||||
| y1 = y1 > max_height ? max_height : (y1 < 0 ? 0 : y1); | |||||
| x2 = x2 > max_width ? max_width : (x2 < 0 ? 0 : x2); | |||||
| y2 = y2 > max_height ? max_height : (y2 < 0 ? 0 : y2); | |||||
| bboxes[left_x] = x1; | |||||
| bboxes[left_y] = y1; | |||||
| bboxes[right_x] = x2; | |||||
| bboxes[right_y] = y2; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2, | |||||
| const float &m3, const float &m4, const float &s1, const float &s2, const float &s3, | |||||
| const float &s4, const int &max_height, const int &max_width, const float &ratio_clip, | |||||
| cudaStream_t cuda_stream) { | |||||
| BoundingBoxDecodeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, rois, deltas, bboxes, m1, m2, m3, m4, | |||||
| s1, s2, s3, s4, max_height, max_width, | |||||
| ratio_clip); | |||||
| } | |||||
| template void BoundingBoxDecode<float>(const size_t size, const float *rois, const float *deltas, float *bboxes, | |||||
| const float &m1, const float &m2, const float &m3, const float &m4, | |||||
| const float &s1, const float &s2, const float &s3, const float &s4, | |||||
| const int &max_height, const int &max_width, const float &ratio_clip, | |||||
| cudaStream_t cuda_stream); | |||||
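A single-box CPU sketch of the same decode arithmetic, useful for spot-checking the kernel above (illustrative only; the clamping to [0, max_width] and [0, max_height] mirrors the device code).

// Host-side reference for one box: delta -> decoded corners, mirroring BoundingBoxDecodeKernel.
#include <algorithm>
#include <cmath>

void DecodeOneBox(const float rois[4], const float deltas[4], float bboxes[4], const float means[4],
                  const float stds[4], int max_height, int max_width, float ratio_clip) {
  float dx = deltas[0] * stds[0] + means[0];
  float dy = deltas[1] * stds[1] + means[1];
  float dw = deltas[2] * stds[2] + means[2];
  float dh = deltas[3] * stds[3] + means[3];
  float max_ratio = std::abs(std::log(ratio_clip));
  dw = std::min(std::max(dw, -max_ratio), max_ratio);
  dh = std::min(std::max(dh, -max_ratio), max_ratio);
  float px = (rois[0] + rois[2]) * 0.5f, py = (rois[1] + rois[3]) * 0.5f;
  float pw = rois[2] - rois[0] + 1.0f, ph = rois[3] - rois[1] + 1.0f;
  float gx = px + pw * dx, gy = py + ph * dy;
  float gw = pw * std::exp(dw), gh = ph * std::exp(dh);
  bboxes[0] = std::min(std::max(gx - gw * 0.5f + 0.5f, 0.0f), static_cast<float>(max_width));
  bboxes[1] = std::min(std::max(gy - gh * 0.5f + 0.5f, 0.0f), static_cast<float>(max_height));
  bboxes[2] = std::min(std::max(gx + gw * 0.5f - 0.5f, 0.0f), static_cast<float>(max_width));
  bboxes[3] = std::min(std::max(gy + gh * 0.5f - 0.5f, 0.0f), static_cast<float>(max_height));
}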
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2, | |||||
| const float &m3, const float &m4, const float &s1, const float &s2, const float &s3, | |||||
| const float &s4, const int &max_height, const int &max_width, const float &ratio_clip, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_ | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void BoundingBoxEncodeKernel(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, | |||||
| const float m1, const float m2, const float m3, const float m4, const float s1, | |||||
| const float s2, const float s3, const float s4) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||||
| const size_t left_x = i * 4; | |||||
| const size_t left_y = i * 4 + 1; | |||||
| const size_t right_x = i * 4 + 2; | |||||
| const size_t right_y = i * 4 + 3; | |||||
| T px = (anchor_box[left_x] + anchor_box[right_x]) * 0.5f; | |||||
| T py = (anchor_box[left_y] + anchor_box[right_y]) * 0.5f; | |||||
| T pw = anchor_box[right_x] - anchor_box[left_x] + 1.0f; | |||||
| T ph = anchor_box[right_y] - anchor_box[left_y] + 1.0f; | |||||
| T gx = (groundtruth_box[left_x] + groundtruth_box[right_x]) * 0.5f; | |||||
| T gy = (groundtruth_box[left_y] + groundtruth_box[right_y]) * 0.5f; | |||||
| T gw = groundtruth_box[right_x] - groundtruth_box[left_x] + 1.0f; | |||||
| T gh = groundtruth_box[right_y] - groundtruth_box[left_y] + 1.0f; | |||||
| T dx = (gx - px) / pw; | |||||
| T dy = (gy - py) / ph; | |||||
| T dw = log(gw / pw); | |||||
| T dh = log(gh / ph); | |||||
| deltas[left_x] = (dx - m1) / s1; | |||||
| deltas[left_y] = (dy - m2) / s2; | |||||
| deltas[right_x] = (dw - m3) / s3; | |||||
| deltas[right_y] = (dh - m4) / s4; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1, | |||||
| const float &m2, const float &m3, const float &m4, const float &s1, const float &s2, | |||||
| const float &s3, const float &s4, cudaStream_t cuda_stream) { | |||||
| BoundingBoxEncodeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, anchor_box, groundtruth_box, deltas, | |||||
| m1, m2, m3, m4, s1, s2, s3, s4); | |||||
| } | |||||
| template void BoundingBoxEncode<float>(const size_t size, const float *anchor_box, const float *groundtruth_box, | |||||
| float *deltas, const float &m1, const float &m2, const float &m3, | |||||
| const float &m4, const float &s1, const float &s2, const float &s3, | |||||
| const float &s4, cudaStream_t cuda_stream); | |||||
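The matching single-box CPU sketch for the encode direction (illustrative only), mirroring BoundingBoxEncodeKernel.

// Host-side reference for one anchor/ground-truth pair: corners -> normalized deltas.
#include <cmath>

void EncodeOneBox(const float anchor[4], const float gt[4], float deltas[4], const float means[4],
                  const float stds[4]) {
  float px = (anchor[0] + anchor[2]) * 0.5f, py = (anchor[1] + anchor[3]) * 0.5f;
  float pw = anchor[2] - anchor[0] + 1.0f, ph = anchor[3] - anchor[1] + 1.0f;
  float gx = (gt[0] + gt[2]) * 0.5f, gy = (gt[1] + gt[3]) * 0.5f;
  float gw = gt[2] - gt[0] + 1.0f, gh = gt[3] - gt[1] + 1.0f;
  deltas[0] = ((gx - px) / pw - means[0]) / stds[0];
  deltas[1] = ((gy - py) / ph - means[1]) / stds[1];
  deltas[2] = (std::log(gw / pw) - means[2]) / stds[2];
  deltas[3] = (std::log(gh / ph) - means[3]) / stds[3];
}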
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1, | |||||
| const float &m2, const float &m3, const float &m4, const float &s1, const float &s2, | |||||
| const float &s3, const float &s4, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_ | |||||
@@ -69,6 +69,25 @@ struct AddFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
};

template <typename T, typename S>
struct FloorDivFunc {
  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
    return floor(static_cast<float>(lhs) / static_cast<float>(rhs));
  }
};

template <>
struct FloorDivFunc<half, half> {
  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
    return __float2half(floor(__half2float(lhs) / __half2float(rhs)));
  }
};

template <>
struct FloorDivFunc<half, bool> {
  // invalid branch
  __device__ __forceinline__ bool operator()(const half &lhs, const half &rhs) { return false; }
};

template <>
struct PowerFunc<half, bool> {
  // invalid branch

@@ -77,6 +96,7 @@ struct PowerFunc<half, bool> {
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }

template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
                                                  const int &r0, const int &r1, const int &r2, const int &r3,

@@ -116,16 +136,19 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
                                                        output);
    case BROADCAST_TYPE_REALDIV:
      return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                        output);
    case BROADCAST_TYPE_MUL:
      return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                    output);
    case BROADCAST_TYPE_SUB:
      return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                    output);
    case BROADCAST_TYPE_ADD:
      return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                    output);
    case BROADCAST_TYPE_FLOORDIV:
      return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
                                                         output);
  }
}

@@ -167,6 +190,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
      return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_ADD:
      return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
    case BROADCAST_TYPE_FLOORDIV:
      return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
  }
}

@@ -195,7 +220,7 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
                 const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) {
  int nums = o0 * o1 * o2 * o3;
  BroadcastToKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr,
                                                                  output_addr);
}

template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,

@@ -226,9 +251,8 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *
                          bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
                          half *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
                          int *output, cudaStream_t stream);
template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                          const int &o2, const int &o3, const float *input_addr, float *output_addr,
                          cudaStream_t stream);

@@ -29,6 +29,7 @@ enum BroadcastOpType {
  BROADCAST_TYPE_MUL = 6,
  BROADCAST_TYPE_SUB = 7,
  BROADCAST_TYPE_ADD = 8,
  BROADCAST_TYPE_FLOORDIV = 9,
  BROADCAST_TYPE_INVALID = 0xffffffff,
};
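Because FloorDivFunc divides in float and then applies floor, the quotient rounds toward negative infinity rather than toward zero. A small host-side check of that semantics (illustrative only):

// Host-side check of the floor-division behaviour used by FloorDivFunc.
#include <cmath>
#include <cstdio>

int main() {
  int pairs[][2] = {{7, 2}, {-7, 2}, {7, -2}, {-7, -2}};
  for (auto &p : pairs) {
    int q = static_cast<int>(std::floor(static_cast<float>(p[0]) / static_cast<float>(p[1])));
    std::printf("%d floordiv %d = %d\n", p[0], p[1], q);  // 3, -4, -4, 3
  }
}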
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh" | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T, typename S> | |||||
| __global__ void GatherNdKernel(T *input, S *indices, T *output, const size_t output_dim0, const size_t output_dim1, | |||||
| const size_t indices_dim1, S *batch_indices, S *batch_strides) { | |||||
| int num = output_dim0 * output_dim1; | |||||
| int i, j; | |||||
| for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; | |||||
| write_index += blockDim.x * gridDim.x) { | |||||
| i = write_index / output_dim1 % output_dim0; | |||||
| j = write_index % output_dim1; | |||||
| bool out_of_bound = false; | |||||
| int read_index = 0; | |||||
| int indices_i = 0; | |||||
| for (size_t k = 0; k < indices_dim1; k++) { | |||||
| size_t ind = indices_dim1 * i + k; | |||||
| indices_i = indices[ind]; | |||||
| out_of_bound |= !(indices_i < batch_indices[k]); | |||||
| read_index += indices_i * batch_strides[k]; | |||||
| } | |||||
| read_index += j; | |||||
| if (!out_of_bound) { | |||||
| output[write_index] = input[read_index]; | |||||
| } else { | |||||
| output[write_index] = 0; | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T, typename S> | |||||
| void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, | |||||
| const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream) { | |||||
| int size = output_dim0 * output_dim1; | |||||
| GatherNdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1, | |||||
| indices_dim1, batch_indices, batch_strides); | |||||
| return; | |||||
| } | |||||
| template void GatherNd<float, int>(float *input, int *indices, float *output, const size_t &output_dim0, | |||||
| const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, | |||||
| int *batch_strides, cudaStream_t stream); | |||||
| template void GatherNd<half, int>(half *input, int *indices, half *output, const size_t &output_dim0, | |||||
| const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, | |||||
| int *batch_strides, cudaStream_t stream); | |||||
| template void GatherNd<int, int>(int *input, int *indices, int *output, const size_t &output_dim0, | |||||
| const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices, | |||||
| int *batch_strides, cudaStream_t stream); | |||||
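A CPU reference of the gather loop above, handy for validating small cases (illustrative only; batch_indices are the per-component bounds and batch_strides the flat strides, following the device signature).

// CPU reference of GatherNdKernel; out-of-bound index rows produce zeros, as on the device.
#include <cstddef>
#include <vector>

template <typename T, typename S>
void GatherNdCpu(const std::vector<T> &input, const std::vector<S> &indices, std::vector<T> *output,
                 size_t output_dim0, size_t output_dim1, size_t indices_dim1, const std::vector<S> &batch_indices,
                 const std::vector<S> &batch_strides) {
  output->assign(output_dim0 * output_dim1, static_cast<T>(0));
  for (size_t i = 0; i < output_dim0; ++i) {
    bool out_of_bound = false;
    size_t read_base = 0;
    for (size_t k = 0; k < indices_dim1; ++k) {
      S idx = indices[i * indices_dim1 + k];
      out_of_bound |= !(idx < batch_indices[k]);
      read_base += static_cast<size_t>(idx) * batch_strides[k];
    }
    for (size_t j = 0; j < output_dim1; ++j) {
      (*output)[i * output_dim1 + j] = out_of_bound ? static_cast<T>(0) : input[read_base + j];
    }
  }
}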
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_GATHERND_GPU_CU_H | |||||
| #define MINDSPORE_GATHERND_GPU_CU_H | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T, typename S> | |||||
| void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1, | |||||
| const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream); | |||||
| #endif // MINDSPORE_GATHERND_GPU_CU_H | |||||
| @@ -0,0 +1,68 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh" | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T, typename S> | |||||
| __global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size, | |||||
| const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1, | |||||
| S *indices_stride, S *work_shape) { | |||||
| int i, j; | |||||
| for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size; | |||||
| read_index += blockDim.x * gridDim.x) { | |||||
| int write_index = 0; | |||||
| bool out_bound = false; | |||||
| i = read_index / block_size; | |||||
| j = read_index % block_size; | |||||
| for (size_t k = 0; k < indices_dim_1; k++) { | |||||
| S indices_i = indices[i * indices_dim_1 + k]; | |||||
| out_bound |= indices_i >= work_shape[k]; | |||||
| write_index += indices_i * indices_stride[k]; | |||||
| } | |||||
| write_index += j; | |||||
| out_bound |= write_index >= output_size; | |||||
| if (!out_bound) { | |||||
| output[write_index] = update[read_index]; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T, typename S> | |||||
| void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, | |||||
| const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, | |||||
| S *work_shape, cudaStream_t stream) { | |||||
| ScatterNdKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size, | |||||
| output_size, indices_dim_0, indices_dim_1, | |||||
| indices_stride, work_shape); | |||||
| return; | |||||
| } | |||||
| template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size, | |||||
| const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, | |||||
| const size_t &indices_dim_1, int *indices_stride, int *work_shape, | |||||
| cudaStream_t stream); | |||||
| template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size, | |||||
| const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, | |||||
| const size_t &indices_dim_1, int *indices_stride, int *work_shape, | |||||
| cudaStream_t stream); | |||||
| template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size, | |||||
| const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, | |||||
| const size_t &indices_dim_1, int *indices_stride, int *work_shape, | |||||
| cudaStream_t stream); | |||||
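A CPU reference of the scatter loop above (illustrative only). It zeroes the destination explicitly; the device kernel only writes the addressed positions, so it relies on the output buffer starting from zero.

// CPU reference of ScatterNdKernel; input_size and output_size are element counts.
#include <cstddef>
#include <vector>

template <typename T, typename S>
void ScatterNdCpu(const std::vector<S> &indices, const std::vector<T> &update, std::vector<T> *output,
                  size_t block_size, size_t input_size, size_t output_size, size_t indices_dim_1,
                  const std::vector<S> &indices_stride, const std::vector<S> &work_shape) {
  output->assign(output_size, static_cast<T>(0));
  for (size_t read_index = 0; read_index < input_size; ++read_index) {
    size_t i = read_index / block_size;
    size_t j = read_index % block_size;
    size_t write_index = 0;
    bool out_bound = false;
    for (size_t k = 0; k < indices_dim_1; ++k) {
      S idx = indices[i * indices_dim_1 + k];
      out_bound |= idx >= work_shape[k];
      write_index += static_cast<size_t>(idx) * indices_stride[k];
    }
    write_index += j;
    out_bound |= write_index >= output_size;
    if (!out_bound) {
      (*output)[write_index] = update[read_index];
    }
  }
}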
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_SCATTER_ND_GPU_CU_H | |||||
| #define MINDSPORE_SCATTER_ND_GPU_CU_H | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T, typename S> | |||||
| void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, | |||||
| const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, | |||||
| S *work_shape, cudaStream_t stream); | |||||
| #endif // MINDSPORE_SCATTER_ND_GPU_CU_H | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad, | |||||
| const T *momentum, const T *lr, T *param, T *accum, T *stat) { | |||||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { | |||||
| T grad_new = grad[i]; | |||||
| if (weight_decay != static_cast<T>(0)) { | |||||
| grad_new += param[i] * weight_decay; | |||||
| } | |||||
| if (momentum[0] != static_cast<T>(0)) { | |||||
      // stat flags the first update of each element: seed the accumulator with the
      // gradient once, then apply the momentum/dampening recurrence.
      if (stat[i] > static_cast<T>(0)) {
        accum[i] = grad_new;
        stat[i] = static_cast<T>(0);
      } else {
        accum[i] = accum[i] * momentum[0] + (static_cast<T>(1) - dampening) * grad_new;
      }
| if (nesterov) { | |||||
| grad_new += accum[i] * momentum[0]; | |||||
| } else { | |||||
| grad_new = accum[i]; | |||||
| } | |||||
| } | |||||
| param[i] -= lr[0] * grad_new; | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum, | |||||
| const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) { | |||||
| SGDKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dampening, weight_decay, nesterov, grad, momentum, | |||||
| lr, param, accum, stat); | |||||
| } | |||||
| template void SGD(const int size, const float dampening, const float weight_decay, const bool nesterov, const float *lr, | |||||
| const float *momentum, const float *grad, float *param, float *accum, float *stat, | |||||
| cudaStream_t cuda_stream); | |||||
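A CPU reference of one update step, mirroring the kernel above (illustrative only; a nonzero stat marks the first update of an element).

// Host-side reference of the SGD update rule with momentum, dampening, weight decay and Nesterov.
#include <cstddef>
#include <vector>

void SgdCpu(float dampening, float weight_decay, bool nesterov, float lr, float momentum,
            const std::vector<float> &grad, std::vector<float> *param, std::vector<float> *accum,
            std::vector<float> *stat) {
  for (size_t i = 0; i < param->size(); ++i) {
    float g = grad[i];
    if (weight_decay != 0.0f) {
      g += (*param)[i] * weight_decay;
    }
    if (momentum != 0.0f) {
      if ((*stat)[i] > 0.0f) {
        (*accum)[i] = g;  // first step: seed the accumulator with the gradient
        (*stat)[i] = 0.0f;
      } else {
        (*accum)[i] = (*accum)[i] * momentum + (1.0f - dampening) * g;
      }
      g = nesterov ? g + (*accum)[i] * momentum : (*accum)[i];
    }
    (*param)[i] -= lr * g;
  }
}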
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_ | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| template <typename T> | |||||
| void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum, | |||||
| const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_ | |||||
@@ -51,6 +51,10 @@ MS_REG_GPU_KERNEL_TWO(
  TensorAdd,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
  FloorDiv,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BroadcastOpGpuKernel, float, float)

// fp16
MS_REG_GPU_KERNEL_TWO(

@@ -85,6 +89,10 @@ MS_REG_GPU_KERNEL_TWO(
  TensorAdd,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
  FloorDiv,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  BroadcastOpGpuKernel, half, half)

// int32
MS_REG_GPU_KERNEL_TWO(

@@ -96,10 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel {
    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
      {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS},     {"Maximum", BROADCAST_TYPE_MAXIMUM},
      {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},     {"RealDiv", BROADCAST_TYPE_REALDIV},
      {"Mul", BROADCAST_TYPE_MUL},         {"Sub", BROADCAST_TYPE_SUB},       {"TensorAdd", BROADCAST_TYPE_ADD},
      {"FloorDiv", BROADCAST_TYPE_FLOORDIV},
    };
    auto iter = kBroadcastTypeMap.find(kernel_name);
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(SGD, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddInputAttr(kNumberTypeFloat32) | |||||
| .AddOutputAttr(kNumberTypeFloat32), | |||||
| SGDGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,88 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class SGDGpuKernel : public GpuKernel { | |||||
| public: | |||||
| SGDGpuKernel() : size_(1), dampening_(0.0), weight_decay_(0.0), nesterov_(false) {} | |||||
| ~SGDGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, void *stream) override { | |||||
| T *param = GetDeviceAddress<T>(inputs, 0); | |||||
| T *grad = GetDeviceAddress<T>(inputs, 1); | |||||
| T *lr = GetDeviceAddress<T>(inputs, 2); | |||||
| T *accum = GetDeviceAddress<T>(inputs, 3); | |||||
| T *momentum = GetDeviceAddress<T>(inputs, 4); | |||||
| T *stat = GetDeviceAddress<T>(inputs, 5); | |||||
| SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat, | |||||
| reinterpret_cast<cudaStream_t>(stream)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| dampening_ = GetAttr<float>(kernel_node, "dampening"); | |||||
| weight_decay_ = GetAttr<float>(kernel_node, "weight_decay"); | |||||
| nesterov_ = GetAttr<bool>(kernel_node, "nesterov"); | |||||
| auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| for (auto &dim : input_shape) { | |||||
| size_ *= dim; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| size_t input_size = size_ * sizeof(T); | |||||
| input_size_list_.push_back(input_size); // parameter | |||||
| input_size_list_.push_back(input_size); // gradient | |||||
| input_size_list_.push_back(sizeof(T)); // lr | |||||
| input_size_list_.push_back(input_size); // accum | |||||
| input_size_list_.push_back(sizeof(T)); // momentum | |||||
| input_size_list_.push_back(input_size); // stat | |||||
| output_size_list_.push_back(input_size); | |||||
| } | |||||
| private: | |||||
| size_t size_; | |||||
| float dampening_; | |||||
| float weight_decay_; | |||||
| bool nesterov_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_ | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| BoundingBoxDecode, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| BoundingBoxDecodeGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,152 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class BoundingBoxDecodeGpuKernel : public GpuKernel { | |||||
| public: | |||||
| BoundingBoxDecodeGpuKernel() : rois_size_(0), deltas_size_(0), bboxes_size_(0), wh_ratio_clip_(0.016) {} | |||||
| ~BoundingBoxDecodeGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
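| // Launch validates that rois and deltas occupy the same number of bytes and contain a multiple of four | |||||
| // coordinates, then calls the BoundingBoxDecode CUDA implementation with the number of boxes, the means/stds | |||||
| // normalization constants, max_shape and wh_ratio_clip. | |||||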
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *rois_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| T *deltas_addr = GetDeviceAddress<T>(inputs, 1); | |||||
| T *bboxes_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| if (inputs[0]->size != inputs[1]->size) { | |||||
| MS_LOG(ERROR) << "Rois box size must equal with deltas box size -" << inputs[1]->size << ", but got" | |||||
| << inputs[0]->size; | |||||
| return false; | |||||
| } | |||||
| const size_t coordinate = 4; | |||||
| const size_t block_size = inputs[0]->size / sizeof(T); | |||||
| if ((block_size % coordinate) != 0) { | |||||
| MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; | |||||
| return false; | |||||
| } | |||||
| BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2], | |||||
| means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 2) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs."; | |||||
| return false; | |||||
| } | |||||
| rois_size_ = sizeof(T); | |||||
| deltas_size_ = sizeof(T); | |||||
| bboxes_size_ = sizeof(T); | |||||
| auto rois_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < rois_shape.size(); i++) { | |||||
| rois_size_ *= rois_shape[i]; | |||||
| } | |||||
| auto deltas_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| for (size_t i = 0; i < deltas_shape.size(); i++) { | |||||
| deltas_size_ *= deltas_shape[i]; | |||||
| } | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||||
| bboxes_size_ *= output_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| const size_t coordinate_size = 4; | |||||
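| // "means" and "stds" may be provided either as a tuple/list of four floats or as a single scalar; | |||||
| // a scalar is broadcast to all four coordinates. | |||||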
| if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() || | |||||
| AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) { | |||||
| means_ = GetAttr<std::vector<float>>(kernel_node, "means"); | |||||
| } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) { | |||||
| float mean = GetAttr<float>(kernel_node, "means"); | |||||
| for (size_t i = 0; i < coordinate_size; i++) { | |||||
| means_.emplace_back(mean); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Attribute means type is invalid."; | |||||
| } | |||||
| if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() || | |||||
| AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) { | |||||
| stds_ = GetAttr<std::vector<float>>(kernel_node, "stds"); | |||||
| } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) { | |||||
| float std = GetAttr<float>(kernel_node, "stds"); | |||||
| for (size_t i = 0; i < coordinate_size; i++) { | |||||
| stds_.emplace_back(std); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Attribute stds type is invalid."; | |||||
| } | |||||
| max_shape_ = GetAttr<std::vector<int>>(kernel_node, "max_shape"); | |||||
| wh_ratio_clip_ = GetAttr<float>(kernel_node, "wh_ratio_clip"); | |||||
| if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { | |||||
| MS_LOG(EXCEPTION) << "The size of means or stds is less than 4."; | |||||
| } | |||||
| if (max_shape_.size() < 2) { | |||||
| MS_LOG(EXCEPTION) << "The size of max_shape is less than 2."; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(rois_size_); | |||||
| input_size_list_.push_back(deltas_size_); | |||||
| output_size_list_.push_back(bboxes_size_); | |||||
| } | |||||
| private: | |||||
| size_t rois_size_; | |||||
| size_t deltas_size_; | |||||
| size_t bboxes_size_; | |||||
| std::vector<float> means_; | |||||
| std::vector<float> stds_; | |||||
| std::vector<int> max_shape_; | |||||
| float wh_ratio_clip_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE( | |||||
| BoundingBoxEncode, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| BoundingBoxEncodeGpuKernel, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,143 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class BoundingBoxEncodeGpuKernel : public GpuKernel { | |||||
| public: | |||||
| BoundingBoxEncodeGpuKernel() : anchor_size_(0), groundtruth_size_(0), deltas_size_(0) {} | |||||
| ~BoundingBoxEncodeGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
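| // Launch checks that the anchor and groundtruth tensors match in size and hold a multiple of four | |||||
| // coordinates, then calls the BoundingBoxEncode CUDA implementation with the number of boxes and the | |||||
| // means/stds normalization constants. | |||||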
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *anchor_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| T *groundtruth_addr = GetDeviceAddress<T>(inputs, 1); | |||||
| T *deltas_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| if (inputs[0]->size != inputs[1]->size) { | |||||
| MS_LOG(ERROR) << "Anchor box size must equal with groundtruth box size -" << inputs[1]->size << ", but got" | |||||
| << inputs[0]->size; | |||||
| return false; | |||||
| } | |||||
| const size_t coordinate = 4; | |||||
| const size_t block_size = inputs[0]->size / sizeof(T); | |||||
| if ((block_size % coordinate) != 0) { | |||||
| MS_LOG(ERROR) << "The size of the box must be a multiple of 4."; | |||||
| return false; | |||||
| } | |||||
| BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], means_[1], | |||||
| means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3], | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 2) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs."; | |||||
| return false; | |||||
| } | |||||
| anchor_size_ = sizeof(T); | |||||
| groundtruth_size_ = sizeof(T); | |||||
| deltas_size_ = sizeof(T); | |||||
| auto anchor_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < anchor_shape.size(); i++) { | |||||
| anchor_size_ *= anchor_shape[i]; | |||||
| } | |||||
| auto groundtruth_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| for (size_t i = 0; i < groundtruth_shape.size(); i++) { | |||||
| groundtruth_size_ *= groundtruth_shape[i]; | |||||
| } | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||||
| deltas_size_ *= output_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| const size_t coordinate_size = 4; | |||||
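| // As in the decode kernel, "means" and "stds" may be a tuple/list of four floats or a single scalar, | |||||
| // which is broadcast to all four coordinates. | |||||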
| if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() || | |||||
| AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) { | |||||
| means_ = GetAttr<std::vector<float>>(kernel_node, "means"); | |||||
| } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) { | |||||
| float mean = GetAttr<float>(kernel_node, "means"); | |||||
| for (size_t i = 0; i < coordinate_size; i++) { | |||||
| means_.emplace_back(mean); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Attribute means type is invalid."; | |||||
| } | |||||
| if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() || | |||||
| AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) { | |||||
| stds_ = GetAttr<std::vector<float>>(kernel_node, "stds"); | |||||
| } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) { | |||||
| float std = GetAttr<float>(kernel_node, "stds"); | |||||
| for (size_t i = 0; i < coordinate_size; i++) { | |||||
| stds_.emplace_back(std); | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Attribute stds type is invalid."; | |||||
| } | |||||
| if (means_.size() < coordinate_size || stds_.size() < coordinate_size) { | |||||
| MS_LOG(EXCEPTION) << "The size of means or stds is less than 4."; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(anchor_size_); | |||||
| input_size_list_.push_back(groundtruth_size_); | |||||
| output_size_list_.push_back(deltas_size_); | |||||
| } | |||||
| private: | |||||
| size_t anchor_size_; | |||||
| size_t groundtruth_size_; | |||||
| size_t deltas_size_; | |||||
| std::vector<float> means_; | |||||
| std::vector<float> stds_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H | |||||
| @@ -0,0 +1,60 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class NetBoundingBoxDecode(nn.Cell): | |||||
| def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): | |||||
| super(NetBoundingBoxDecode, self).__init__() | |||||
| self.decode = P.BoundingBoxDecode(max_shape=(768, 1280), means=means, stds=stds, | |||||
| wh_ratio_clip=0.016) | |||||
| def construct(self, anchor, groundtruth): | |||||
| return self.decode(anchor, groundtruth) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_boundingbox_decode(): | |||||
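| # Compares the GPU BoundingBoxDecode kernel against precomputed reference deltas, in both graph and pynative mode. | |||||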
| anchor = np.array([[4, 1, 2, 1], [2, 2, 2, 3]], np.float32) | |||||
| deltas = np.array([[3, 1, 2, 2], [1, 2, 1, 4]], np.float32) | |||||
| means = (0.1, 0.1, 0.2, 0.2) | |||||
| stds = (2.0, 2.0, 3.0, 3.0) | |||||
| anchor_box = Tensor(anchor, mindspore.float32) | |||||
| deltas_box = Tensor(deltas, mindspore.float32) | |||||
| expect_deltas = np.array([[28.6500, 0.0000, 0.0000, 33.8500], | |||||
| [0.0000, 0.0000, 15.8663, 72.7000]], np.float32) | |||||
| error = np.ones(shape=[2, 4]) * 1.0e-4 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| boundingbox_decode = NetBoundingBoxDecode(means, stds) | |||||
| output = boundingbox_decode(anchor_box, deltas_box) | |||||
| diff = output.asnumpy() - expect_deltas | |||||
| assert np.all(abs(diff) < error) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||||
| boundingbox_decode = NetBoundingBoxDecode(means, stds) | |||||
| output = boundingbox_decode(anchor_box, deltas_box) | |||||
| diff = output.asnumpy() - expect_deltas | |||||
| assert np.all(abs(diff) < error) | |||||
| @@ -0,0 +1,80 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class NetBoundingBoxEncode(nn.Cell): | |||||
| def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): | |||||
| super(NetBoundingBoxEncode, self).__init__() | |||||
| self.encode = P.BoundingBoxEncode(means=means, stds=stds) | |||||
| def construct(self, anchor, groundtruth): | |||||
| return self.encode(anchor, groundtruth) | |||||
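| # NumPy reference implementation of the box encoding, used to generate the expected deltas. | |||||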
| def bbox2delta(proposals, gt, means, stds): | |||||
| px = (proposals[..., 0] + proposals[..., 2]) * 0.5 | |||||
| py = (proposals[..., 1] + proposals[..., 3]) * 0.5 | |||||
| pw = proposals[..., 2] - proposals[..., 0] + 1.0 | |||||
| ph = proposals[..., 3] - proposals[..., 1] + 1.0 | |||||
| gx = (gt[..., 0] + gt[..., 2]) * 0.5 | |||||
| gy = (gt[..., 1] + gt[..., 3]) * 0.5 | |||||
| gw = gt[..., 2] - gt[..., 0] + 1.0 | |||||
| gh = gt[..., 3] - gt[..., 1] + 1.0 | |||||
| dx = (gx - px) / pw | |||||
| dy = (gy - py) / ph | |||||
| dw = np.log(gw / pw) | |||||
| dh = np.log(gh / ph) | |||||
| means = np.array(means, np.float32) | |||||
| stds = np.array(stds, np.float32) | |||||
| deltas = np.stack([(dx - means[0]) / stds[0], (dy - means[1]) / stds[1], | |||||
| (dw - means[2]) / stds[2], (dh - means[3]) / stds[3]], axis=-1) | |||||
| return deltas | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_boundingbox_encode(): | |||||
| anchor = np.array([[4, 1, 6, 9], [2, 5, 5, 9]]).astype(np.float32) | |||||
| gt = np.array([[3, 2, 7, 7], [1, 5, 5, 8]]).astype(np.float32) | |||||
| means = (0.1, 0.1, 0.2, 0.2) | |||||
| stds = (2.0, 2.0, 3.0, 3.0) | |||||
| anchor_box = Tensor(anchor, mindspore.float32) | |||||
| groundtruth_box = Tensor(gt, mindspore.float32) | |||||
| expect_deltas = bbox2delta(anchor, gt, means, stds) | |||||
| error = np.ones(shape=[2, 4]) * 1.0e-6 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| boundingbox_encode = NetBoundingBoxEncode(means, stds) | |||||
| output = boundingbox_encode(anchor_box, groundtruth_box) | |||||
| diff = output.asnumpy() - expect_deltas | |||||
| assert np.all(abs(diff) < error) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||||
| boundingbox_encode = NetBoundingBoxEncode(means, stds) | |||||
| output = boundingbox_encode(anchor_box, groundtruth_box) | |||||
| diff = output.asnumpy() - expect_deltas | |||||
| assert np.all(abs(diff) < error) | |||||
| @@ -0,0 +1,116 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class NetFloorDiv(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetFloorDiv, self).__init__() | |||||
| self.floordiv = P.FloorDiv() | |||||
| def construct(self, x, y): | |||||
| return self.floordiv(x, y) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_floor_div(): | |||||
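| # Exercises FloorDiv with same-shape, broadcast, scalar, float16, float32 and int32 operands, | |||||
| # comparing against np.floor_divide in both graph and pynative mode. | |||||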
| x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) | |||||
| y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) | |||||
| x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) | |||||
| y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32) | |||||
| x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32) | |||||
| y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32) | |||||
| x3_np = np.random.randint(1, 5, 1).astype(np.float32) | |||||
| y3_np = np.random.randint(1, 5, 1).astype(np.float32) | |||||
| x4_np = np.array(768).astype(np.float32) | |||||
| y4_np = np.array(3072.5).astype(np.float32) | |||||
| x5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) | |||||
| y5_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16) | |||||
| x6_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32) | |||||
| y6_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32) | |||||
| x0 = Tensor(x0_np) | |||||
| y0 = Tensor(y0_np) | |||||
| x1 = Tensor(x1_np) | |||||
| y1 = Tensor(y1_np) | |||||
| x2 = Tensor(x2_np) | |||||
| y2 = Tensor(y2_np) | |||||
| x3 = Tensor(x3_np) | |||||
| y3 = Tensor(y3_np) | |||||
| x4 = Tensor(x4_np) | |||||
| y4 = Tensor(y4_np) | |||||
| x5 = Tensor(x5_np) | |||||
| y5 = Tensor(y5_np) | |||||
| x6 = Tensor(x6_np) | |||||
| y6 = Tensor(y6_np) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| floor_div = NetFloorDiv() | |||||
| output0 = floor_div(x0, y0) | |||||
| expect0 = np.floor_divide(x0_np, y0_np) | |||||
| diff0 = output0.asnumpy() - expect0 | |||||
| error0 = np.ones(shape=expect0.shape) * 1.0e-5 | |||||
| assert np.all(diff0 < error0) | |||||
| assert output0.shape == expect0.shape | |||||
| output1 = floor_div(x1, y1) | |||||
| expect1 = np.floor_divide(x1_np, y1_np) | |||||
| diff1 = output1.asnumpy() - expect1 | |||||
| error1 = np.ones(shape=expect1.shape) * 1.0e-5 | |||||
| assert np.all(diff1 < error1) | |||||
| assert output1.shape == expect1.shape | |||||
| output2 = floor_div(x2, y2) | |||||
| expect2 = np.floor_divide(x2_np, y2_np) | |||||
| diff2 = output2.asnumpy() - expect2 | |||||
| error2 = np.ones(shape=expect2.shape) * 1.0e-5 | |||||
| assert np.all(diff2 < error2) | |||||
| assert output2.shape == expect2.shape | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||||
| output3 = floor_div(x3, y3) | |||||
| expect3 = np.floor_divide(x3_np, y3_np) | |||||
| diff3 = output3.asnumpy() - expect3 | |||||
| error3 = np.ones(shape=expect3.shape) * 1.0e-5 | |||||
| assert np.all(diff3 < error3) | |||||
| assert output3.shape == expect3.shape | |||||
| output4 = floor_div(x4, y4) | |||||
| expect4 = np.floor_divide(x4_np, y4_np) | |||||
| diff4 = output4.asnumpy() - expect4 | |||||
| error4 = np.ones(shape=expect4.shape) * 1.0e-5 | |||||
| assert np.all(diff4 < error4) | |||||
| assert output4.shape == expect4.shape | |||||
| output5 = floor_div(x5, y5) | |||||
| expect5 = np.floor_divide(x5_np, y5_np) | |||||
| diff5 = output5.asnumpy() - expect5 | |||||
| error5 = np.ones(shape=expect5.shape) * 1.0e-5 | |||||
| assert np.all(diff5 < error5) | |||||
| assert output5.shape == expect5.shape | |||||
| output6 = floor_div(x6, y6) | |||||
| expect6 = np.floor_divide(x6_np, y6_np) | |||||
| diff6 = output6.asnumpy() - expect6 | |||||
| error6 = np.ones(shape=expect6.shape) * 1.0e-5 | |||||
| assert np.all(diff6 < error6) | |||||
| assert output6.shape == expect6.shape | |||||
| @@ -0,0 +1,151 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| import mindspore.nn as nn | |||||
| import mindspore.context as context | |||||
| class GatherNdNet(nn.Cell): | |||||
| def __init__(self): | |||||
| super(GatherNdNet, self).__init__() | |||||
| self.gathernd = P.GatherNd() | |||||
| def construct(self, x, indices): | |||||
| return self.gathernd(x, indices) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_gathernd0(): | |||||
| x = Tensor(np.arange(3 * 2, dtype=np.float32).reshape(3, 2)) | |||||
| indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32)) | |||||
| expect = np.array([3., 1.]) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| gathernd = GatherNdNet() | |||||
| output = gathernd(x, indices) | |||||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||||
| diff = output.asnumpy() - expect | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_gathernd1(): | |||||
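| # Gathers columns 1, 3 and 4 along the last axis of a 2x3x4x5 tensor using fully specified 4-d indices. | |||||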
| x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.float32).reshape(2, 3, 4, 5)) | |||||
| indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)] | |||||
| for k in range(3)] for l in range(2)], dtype='i4')) | |||||
| expect = np.array([[[[1., 3., 4.], | |||||
| [6., 8., 9.], | |||||
| [11., 13., 14.], | |||||
| [16., 18., 19.]], | |||||
| [[21., 23., 24.], | |||||
| [26., 28., 29.], | |||||
| [31., 33., 34.], | |||||
| [36., 38., 39.]], | |||||
| [[41., 43., 44.], | |||||
| [46., 48., 49.], | |||||
| [51., 53., 54.], | |||||
| [56., 58., 59.]]], | |||||
| [[[61., 63., 64.], | |||||
| [66., 68., 69.], | |||||
| [71., 73., 74.], | |||||
| [76., 78., 79.]], | |||||
| [[81., 83., 84.], | |||||
| [86., 88., 89.], | |||||
| [91., 93., 94.], | |||||
| [96., 98., 99.]], | |||||
| [[101., 103., 104.], | |||||
| [106., 108., 109.], | |||||
| [111., 113., 114.], | |||||
| [116., 118., 119.]]]]) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| gather = GatherNdNet() | |||||
| output = gather(x, indices) | |||||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||||
| diff = output.asnumpy() - expect | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_gathernd2(): | |||||
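| # Indices 4000 and 300000 are out of range for the 10-row input; the expected output keeps zeros for those rows. | |||||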
| x = Tensor(np.array([[4., 5., 4., 1., 5.], | |||||
| [4., 9., 5., 6., 4.], | |||||
| [9., 8., 4., 3., 6.], | |||||
| [0., 4., 2., 2., 8.], | |||||
| [1., 8., 6., 2., 8.], | |||||
| [8., 1., 9., 7., 3.], | |||||
| [7., 9., 2., 5., 7.], | |||||
| [9., 8., 6., 8., 5.], | |||||
| [3., 7., 2., 7., 4.], | |||||
| [4., 2., 8., 2., 9.]]).astype(np.float16)) | |||||
| indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32)) | |||||
| expect = np.array([[0., 0., 0., 0., 0.], | |||||
| [4., 9., 5., 6., 4.], | |||||
| [0., 0., 0., 0., 0.]]) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| gathernd = GatherNdNet() | |||||
| output = gathernd(x, indices) | |||||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||||
| diff = output.asnumpy() - expect | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_gathernd3(): | |||||
| x = Tensor(np.array([[4, 5, 4, 1, 5], | |||||
| [4, 9, 5, 6, 4], | |||||
| [9, 8, 4, 3, 6], | |||||
| [0, 4, 2, 2, 8], | |||||
| [1, 8, 6, 2, 8], | |||||
| [8, 1, 9, 7, 3], | |||||
| [7, 9, 2, 5, 7], | |||||
| [9, 8, 6, 8, 5], | |||||
| [3, 7, 2, 7, 4], | |||||
| [4, 2, 8, 2, 9]] | |||||
| ).astype(np.int32)) | |||||
| indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32)) | |||||
| expect = np.array([[0, 0, 0, 0, 0], | |||||
| [4, 9, 5, 6, 4], | |||||
| [0, 0, 0, 0, 0]]) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| gathernd = GatherNdNet() | |||||
| output = gathernd(x, indices) | |||||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||||
| diff = output.asnumpy() - expect | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||
| @@ -0,0 +1,50 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| class Net(nn.Cell): | |||||
| def __init__(self, _shape): | |||||
| super(Net, self).__init__() | |||||
| self.shape = _shape | |||||
| self.scatternd = P.ScatterNd() | |||||
| def construct(self, indices, update): | |||||
| return self.scatternd(indices, update, self.shape) | |||||
| def scatternd_net(indices, update, _shape, expect): | |||||
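| # Helper: runs ScatterNd on the GPU and checks the result element-wise against the expected array. | |||||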
| scatternd = Net(_shape) | |||||
| output = scatternd(Tensor(indices), Tensor(update)) | |||||
| error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 | |||||
| diff = output.asnumpy() - expect | |||||
| assert np.all(diff < error) | |||||
| assert np.all(-diff < error) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_scatternd(): | |||||
| arr_indices = np.array([[0, 1], [1, 1]]).astype(np.int32) | |||||
| arr_update = np.array([3.2, 1.1]).astype(np.float32) | |||||
| shape = (2, 2) | |||||
| expect = np.array([[0., 3.2], | |||||
| [0., 1.1]]) | |||||
| scatternd_net(arr_indices, arr_update, shape, expect) | |||||
| @@ -0,0 +1,73 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import Dense | |||||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import SGD | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| class NetSGD(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetSGD, self).__init__() | |||||
| self.batch_size = 1 | |||||
| self.reshape = P.Reshape() | |||||
| weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) | |||||
| self.fc1 = Dense(16, 10, weight_init=weight) | |||||
| def construct(self, input_x): | |||||
| output = self.reshape(input_x, (self.batch_size, -1)) | |||||
| output = self.fc1(output) | |||||
| return output | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_SGD(): | |||||
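| # Trains a single Dense layer for three steps with the SGD optimizer and checks that the loss strictly decreases. | |||||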
| epoch = 3 | |||||
| net = NetSGD() | |||||
| learning_rate = 0.1 | |||||
| momentum = 0.9 | |||||
| dampening = 0.0 | |||||
| weight_decay = 0.0 | |||||
| nesterov = True | |||||
| loss_scale = 1.0 | |||||
| optimizer = SGD(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum, dampening, | |||||
| weight_decay, nesterov, loss_scale) | |||||
| criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||||
| net_with_criterion = WithLossCell(net, criterion) | |||||
| train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer | |||||
| train_network.set_train() | |||||
| losses = [] | |||||
| for _ in range(epoch): | |||||
| data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01) | |||||
| label = Tensor(np.array([0]).astype(np.int32)) | |||||
| loss = train_network(data, label) | |||||
| losses.append(loss.asnumpy()) | |||||
| last_loss = 100.0 | |||||
| for loss in losses: | |||||
| assert last_loss > loss | |||||
| last_loss = loss | |||||
| return losses | |||||