| @@ -24,11 +24,5 @@ MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutp | |||
| SliceGpuFwdKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SliceGpuFwdKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SliceGpuFwdKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SliceGpuFwdKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| SliceGpuFwdKernel, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -41,14 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel { | |||
| } | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| if (is_strided_slice_) { | |||
| CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], | |||
| input_shape_[1], input_shape_[2], input_shape_[3], input, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], | |||
| input_shape_[1], input_shape_[2], input_shape_[3], input, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| @@ -29,11 +29,5 @@ MS_REG_GPU_KERNEL_ONE( | |||
| SliceGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SliceGradGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SliceGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| SliceGradGpuKernel, int) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SliceGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -38,13 +38,8 @@ class SliceGradGpuKernel : public GpuKernel { | |||
| T *dy = GetDeviceAddress<T>(inputs, 0); | |||
| T *dx = GetDeviceAddress<T>(outputs, 0); | |||
| FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (is_strided_slice_) { | |||
| CalStridedSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, strides_, dx, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } else { | |||
| CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| } | |||
| CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -140,7 +135,7 @@ class SliceGradGpuKernel : public GpuKernel { | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| }; | |||
| }; // namespace kernel | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| StridedSliceGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| StridedSliceGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| StridedSliceGpuKernel, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,190 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H | |||
| #include <vector> | |||
| #include <bitset> | |||
| #include <algorithm> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int MAX_DIMS = 4; | |||
| template <typename T> | |||
| class StridedSliceGpuKernel : public GpuKernel { | |||
| public: | |||
| StridedSliceGpuKernel() : null_output_(false) {} | |||
| ~StridedSliceGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| if (null_output_) { | |||
| return true; | |||
| } | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| StridedSlice(input_shape_, begin_, strides_, output_shape_, input, output, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape_.size() > MAX_DIMS) { | |||
| MS_LOG(ERROR) << "StridedSlice support support dims less than " << input_shape_.size(); | |||
| return false; | |||
| } | |||
| FillEmptyDims(kernel_node); | |||
| ParseMasks(kernel_node); | |||
| FillOutputDim(); | |||
| null_output_ = IsNullOutput(); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T)); | |||
| output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T)); | |||
| } | |||
| private: | |||
| void FillEmptyDims(const CNodePtr &kernel_node) { | |||
| begin_ = GetAttr<std::vector<int>>(kernel_node, "begin"); | |||
| end_ = GetAttr<std::vector<int>>(kernel_node, "end"); | |||
| strides_ = GetAttr<std::vector<int>>(kernel_node, "strides"); | |||
| for (size_t i = 0; i < MAX_DIMS; i++) { | |||
| if (i < begin_.size()) { | |||
| begin_[i] = | |||
| std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1)); | |||
| } else { | |||
| begin_.push_back(0); | |||
| } | |||
| if (i < end_.size()) { | |||
| end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1); | |||
| } else { | |||
| end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); | |||
| } | |||
| if (i >= strides_.size()) { | |||
| strides_.push_back(1); | |||
| } | |||
| if (i >= input_shape_.size()) { | |||
| input_shape_.push_back(1); | |||
| } | |||
| } | |||
| } | |||
| void ParseMasks(const CNodePtr &kernel_node) { | |||
| auto begin_mask_int = GetAttr<int>(kernel_node, "begin_mask"); | |||
| auto begin_mask = Dec2Bin(begin_mask_int); | |||
| for (size_t i = 0; i < begin_mask.size(); i++) { | |||
| if (begin_mask[i]) { | |||
| begin_[i] = 0; | |||
| } | |||
| } | |||
| auto end_mask_int = GetAttr<int>(kernel_node, "end_mask"); | |||
| auto end_mask = Dec2Bin(end_mask_int); | |||
| for (size_t j = 0; j < end_mask.size(); j++) { | |||
| if (end_mask[j]) { | |||
| end_[j] = input_shape_[j]; | |||
| } | |||
| } | |||
| auto ellipsis_mask_int = GetAttr<int>(kernel_node, "ellipsis_mask"); | |||
| auto ellipsis_mask = Dec2Bin(ellipsis_mask_int); | |||
| for (size_t k = 0; k < ellipsis_mask.size(); k++) { | |||
| if (ellipsis_mask[k]) { | |||
| begin_[k] = 0; | |||
| end_[k] = input_shape_[k]; | |||
| strides_[k] = 1; | |||
| } | |||
| } | |||
| auto shrink_axis_mask_str = GetAttr<int>(kernel_node, "shrink_axis_mask"); | |||
| auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); | |||
| for (size_t l = 0; l < shrink_axis_mask.size(); l++) { | |||
| if (shrink_axis_mask[l]) { | |||
| end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; | |||
| strides_[l] = end_[l] > begin_[l] ? 1 : -1; | |||
| } | |||
| } | |||
| } | |||
| std::vector<bool> Dec2Bin(const int &mask) { | |||
| auto mask_str = std::bitset<MAX_DIMS>(mask).to_string(); | |||
| int dim_idx = 0; | |||
| std::vector<bool> result = {false, false, false, false}; | |||
| for (int i = mask_str.size() - 1; i >= 0; i--) { | |||
| if (mask_str[i] == '1') { | |||
| result[dim_idx] = true; | |||
| } | |||
| dim_idx++; | |||
| } | |||
| return result; | |||
| } | |||
| void FillOutputDim() { | |||
| for (int i = 0; i < MAX_DIMS; i++) { | |||
| if (begin_[i] <= end_[i] && strides_[i] > 0) { | |||
| output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1); | |||
| } else if (begin_[i] > end_[i] && strides_[i] < 0) { | |||
| output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1); | |||
| } else { | |||
| output_shape_.push_back(0); | |||
| } | |||
| } | |||
| } | |||
| bool IsNullOutput() { | |||
| for (int i = 0; i < MAX_DIMS; i++) { | |||
| if (begin_[i] >= end_[i] && strides_[i] > 0) { | |||
| return true; | |||
| } | |||
| if (begin_[i] < end_[i] && strides_[i] < 0) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| std::vector<int> begin_; | |||
| std::vector<int> end_; | |||
| std::vector<int> strides_; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<int> output_shape_; | |||
| int null_output_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GPU_KERNEL_H | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| StridedSliceGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| StridedSliceGradGpuKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| StridedSliceGradGpuKernel, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,191 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H | |||
| #include <vector> | |||
| #include <bitset> | |||
| #include <algorithm> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| constexpr int MAX_DIMS = 4; | |||
| template <typename T> | |||
| class StridedSliceGradGpuKernel : public GpuKernel { | |||
| public: | |||
| StridedSliceGradGpuKernel() : null_output_(false) {} | |||
| ~StridedSliceGradGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *dy = GetDeviceAddress<T>(inputs, 0); | |||
| T *dx = GetDeviceAddress<T>(outputs, 0); | |||
| FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| if (null_output_) { | |||
| return true; | |||
| } | |||
| StridedSliceGrad(output_shape_, begin_, strides_, input_shape_, dy, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| input_shape_ = GetAttr<std::vector<int>>(kernel_node, "shapex"); | |||
| if (input_shape_.size() > MAX_DIMS) { | |||
| MS_LOG(ERROR) << "StridedSliceGrad support support dims less than " << input_shape_.size(); | |||
| return false; | |||
| } | |||
| FillEmptyDims(kernel_node); | |||
| ParseMasks(kernel_node); | |||
| FillOutputDim(); | |||
| null_output_ = IsNullOutput(); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T)); | |||
| output_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T)); | |||
| } | |||
| private: | |||
| void FillEmptyDims(const CNodePtr &kernel_node) { | |||
| begin_ = GetAttr<std::vector<int>>(kernel_node, "begin"); | |||
| end_ = GetAttr<std::vector<int>>(kernel_node, "end"); | |||
| strides_ = GetAttr<std::vector<int>>(kernel_node, "strides"); | |||
| for (size_t i = 0; i < MAX_DIMS; i++) { | |||
| if (i < begin_.size()) { | |||
| begin_[i] = | |||
| std::min(begin_[i] < 0 ? SizeToInt(begin_[i] + input_shape_[i]) : begin_[i], SizeToInt(input_shape_[i] - 1)); | |||
| } else { | |||
| begin_.push_back(0); | |||
| } | |||
| if (i < end_.size()) { | |||
| end_[i] = std::max(end_[i] < 0 ? end_[i] + SizeToInt(input_shape_[i]) : end_[i], -1); | |||
| } else { | |||
| end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); | |||
| } | |||
| if (i >= strides_.size()) { | |||
| strides_.push_back(1); | |||
| } | |||
| if (i >= input_shape_.size()) { | |||
| input_shape_.push_back(1); | |||
| } | |||
| } | |||
| } | |||
| void ParseMasks(const CNodePtr &kernel_node) { | |||
| auto begin_mask_int = GetAttr<int>(kernel_node, "begin_mask"); | |||
| auto begin_mask = Dec2Bin(begin_mask_int); | |||
| for (size_t i = 0; i < begin_mask.size(); i++) { | |||
| if (begin_mask[i]) { | |||
| begin_[i] = 0; | |||
| } | |||
| } | |||
| auto end_mask_int = GetAttr<int>(kernel_node, "end_mask"); | |||
| auto end_mask = Dec2Bin(end_mask_int); | |||
| for (size_t j = 0; j < end_mask.size(); j++) { | |||
| if (end_mask[j]) { | |||
| end_[j] = input_shape_[j]; | |||
| } | |||
| } | |||
| auto ellipsis_mask_int = GetAttr<int>(kernel_node, "ellipsis_mask"); | |||
| auto ellipsis_mask = Dec2Bin(ellipsis_mask_int); | |||
| for (size_t k = 0; k < ellipsis_mask.size(); k++) { | |||
| if (ellipsis_mask[k]) { | |||
| begin_[k] = 0; | |||
| end_[k] = input_shape_[k]; | |||
| strides_[k] = 1; | |||
| } | |||
| } | |||
| auto shrink_axis_mask_str = GetAttr<int>(kernel_node, "shrink_axis_mask"); | |||
| auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); | |||
| for (size_t l = 0; l < shrink_axis_mask.size(); l++) { | |||
| if (shrink_axis_mask[l]) { | |||
| end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; | |||
| strides_[l] = end_[l] > begin_[l] ? 1 : -1; | |||
| } | |||
| } | |||
| } | |||
| std::vector<bool> Dec2Bin(const int &mask) { | |||
| auto mask_str = std::bitset<MAX_DIMS>(mask).to_string(); | |||
| int dim_idx = 0; | |||
| std::vector<bool> result = {false, false, false, false}; | |||
| for (int i = mask_str.size() - 1; i >= 0; i--) { | |||
| if (mask_str[i] == '1') { | |||
| result[dim_idx] = true; | |||
| } | |||
| dim_idx++; | |||
| } | |||
| return result; | |||
| } | |||
| void FillOutputDim() { | |||
| for (int i = 0; i < MAX_DIMS; i++) { | |||
| if (begin_[i] <= end_[i] && strides_[i] > 0) { | |||
| output_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1); | |||
| } else if (begin_[i] > end_[i] && strides_[i] < 0) { | |||
| output_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1); | |||
| } else { | |||
| output_shape_.push_back(0); | |||
| } | |||
| } | |||
| } | |||
| bool IsNullOutput() { | |||
| for (int i = 0; i < MAX_DIMS; i++) { | |||
| if (begin_[i] >= end_[i] && strides_[i] > 0) { | |||
| return true; | |||
| } | |||
| if (begin_[i] < end_[i] && strides_[i] < 0) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| std::vector<int> begin_; | |||
| std::vector<int> end_; | |||
| std::vector<int> strides_; | |||
| std::vector<int> input_shape_; | |||
| std::vector<int> output_shape_; | |||
| int null_output_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_STRIDED_SLICE_GRAD_GPU_KERNEL_H | |||
| @@ -21,48 +21,29 @@ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" | |||
| template <typename T> | |||
| __global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, | |||
| const int l1, const int l2, const int l3, const int l4, | |||
| const int d1, const int d2, const int d3, const int d4, | |||
| const T *input, T *output) { | |||
| __global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, | |||
| const int l3, const int l4, const int d1, const int d2, const int d3, const int d4, | |||
| const T *input, T *output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { | |||
| int i = pos / (l2 * l3 * l4) % l1; | |||
| int j = pos / (l3 * l4) % l2; | |||
| int k = pos / l4 % l3; | |||
| int o = pos % l4; | |||
| int offset = (i + s1) * (d2 * d3 * d4) + | |||
| (j + s2) * (d3 * d4) + | |||
| (k + s3) * d4 + | |||
| (o + s4); | |||
| int offset = (i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4); | |||
| output[pos] = input[offset]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { | |||
| __global__ void SliceGrad(const T *dy, int p, int start, int length, T *output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { | |||
| output[start + pos] = dy[p + pos]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast<float>(ended - begin) / stride); | |||
| pos += blockDim.x * gridDim.x) { | |||
| output[p + pos] = input[start + pos * stride]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast<float>(ended - begin) / stride); | |||
| pos += blockDim.x * gridDim.x) { | |||
| dx[start + pos * stride] = dy[p + pos]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void FillArray(T* addr, const size_t len, const float value) { | |||
| __global__ void FillArray(T *addr, const size_t len, const float value) { | |||
| T value_ = static_cast<T>(value); | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < len; pos += blockDim.x * gridDim.x) { | |||
| addr[pos] = value_; | |||
| @@ -70,21 +51,20 @@ __global__ void FillArray(T* addr, const size_t len, const float value) { | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream) { | |||
| void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream) { | |||
| FillArray<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(addr, input_size, value); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, | |||
| const int l1, const int l2, const int l3, const int l4, | |||
| const int d1, const int d2, const int d3, const int d4, | |||
| const T *input, T *output, cudaStream_t stream) { | |||
| Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, | |||
| d1, d2, d3, d4, input, output); | |||
| void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3, | |||
| const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output, | |||
| cudaStream_t stream) { | |||
| Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4, d1, d2, d3, d4, | |||
| input, output); | |||
| } | |||
| template <typename T> | |||
| void CalSliceGrad(const size_t input_size, const T* dy, const std::vector<int> in_shape, const std::vector<int> begin, | |||
| const std::vector<int> size, T* output, cudaStream_t cuda_stream) { | |||
| void CalSliceGrad(const size_t input_size, const T *dy, const std::vector<int> in_shape, const std::vector<int> begin, | |||
| const std::vector<int> size, T *output, cudaStream_t cuda_stream) { | |||
| int block = in_shape[1] * in_shape[2] * in_shape[3]; | |||
| int map = in_shape[2] * in_shape[3]; | |||
| int w = in_shape[3]; | |||
| @@ -100,92 +80,100 @@ void CalSliceGrad(const size_t input_size, const T* dy, const std::vector<int> i | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalStridedSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, | |||
| const std::vector<int> begin, const std::vector<int> end, const std::vector<int> strides, | |||
| T* output, cudaStream_t cuda_stream) { | |||
| int block = in_shape[1] * in_shape[2] * in_shape[3]; | |||
| int map = in_shape[2] * in_shape[3]; | |||
| int w = in_shape[3]; | |||
| int ended = end[3]; | |||
| int p = 0; | |||
| int start = 0; | |||
| for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0])); i += std::abs(strides[0])) { | |||
| for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1])); j += std::abs(strides[1])) { | |||
| for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2])); k += std::abs(strides[2])) { | |||
| start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + | |||
| (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; | |||
| StridedSlice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, start, begin[3], strides[3], | |||
| ended, output); | |||
| p = p + std::ceil(static_cast<float>(end[3] - begin[3]) / strides[3]); | |||
| } | |||
| } | |||
| __global__ void StridedSliceKernel(const int b0, const int b1, const int b2, const int b3, const int s0, const int s1, | |||
| const int s2, const int s3, const int i0, const int i1, const int i2, const int i3, | |||
| const int o0, const int o1, const int o2, const int o3, const T *input_addr, | |||
| T *output_addr) { | |||
| int output_num = o0 * o1 * o2 * o3; | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { | |||
| int i = pos / (o1 * o2 * o3) % o0; | |||
| int j = pos / (o2 * o3) % o1; | |||
| int k = pos / o3 % o2; | |||
| int l = pos % o3; | |||
| int input_idx = (i * s0 + b0) * i1 * i2 * i3 + (j * s1 + b1) * i2 * i3 + (k * s2 + b2) * i3 + (l * s3 + b3); | |||
| output_addr[pos] = input_addr[input_idx]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector<int> in_shape, | |||
| const std::vector<int> begin, const std::vector<int> end, const std::vector<int> strides, | |||
| T* dx, cudaStream_t cuda_stream) { | |||
| int block = in_shape[1] * in_shape[2] * in_shape[3]; | |||
| int map = in_shape[2] * in_shape[3]; | |||
| int w = in_shape[3]; | |||
| int ended = end[3]; | |||
| int p = 0; | |||
| int start = 0; | |||
| for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0] + 1)); i += std::abs(strides[0])) { | |||
| for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1] + 1)); | |||
| j += std::abs(strides[1])) { | |||
| for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2] + 1)); | |||
| k += std::abs(strides[2])) { | |||
| start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + | |||
| (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; | |||
| StridedSliceGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(dy, p, start, begin[3], strides[3], | |||
| ended, dx); | |||
| p = p + std::ceil(static_cast<float>(end[3] - begin[3]) / strides[3]); | |||
| } | |||
| } | |||
| void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin, | |||
| const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output, | |||
| cudaStream_t cuda_stream) { | |||
| int size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]; | |||
| StridedSliceKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( | |||
| begin[0], begin[1], begin[2], begin[3], strides[0], strides[1], strides[2], strides[3], input_shape[0], | |||
| input_shape[1], input_shape[2], input_shape[3], output_shape[0], output_shape[1], output_shape[2], output_shape[3], | |||
| input, output); | |||
| } | |||
| template <typename T> | |||
| __global__ void StridedSliceGradKernel(const int b0, const int b1, const int b2, const int b3, const int s0, | |||
| const int s1, const int s2, const int s3, const int i0, const int i1, | |||
| const int i2, const int i3, const int o0, const int o1, const int o2, | |||
| const int o3, const T *dy, T *dx) { | |||
| int output_num = o0 * o1 * o2 * o3; | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { | |||
| int i = pos / (o1 * o2 * o3) % o0; | |||
| int j = pos / (o2 * o3) % o1; | |||
| int k = pos / o3 % o2; | |||
| int l = pos % o3; | |||
| int input_idx = (i * s0 + b0) * i1 * i2 * i3 + (j * s1 + b1) * i2 * i3 + (k * s2 + b2) * i3 + (l * s3 + b3); | |||
| dx[input_idx] = dy[pos]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides, | |||
| const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream) { | |||
| int size = dy_shape[0] * dy_shape[1] * dy_shape[2] * dy_shape[3]; | |||
| StridedSliceGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( | |||
| begin[0], begin[1], begin[2], begin[3], strides[0], strides[1], strides[2], strides[3], dx_shape[0], dx_shape[1], | |||
| dx_shape[2], dx_shape[3], dy_shape[0], dy_shape[1], dy_shape[2], dy_shape[3], dy, dx); | |||
| } | |||
// Explicit instantiations for the element types registered by the GPU
// Slice/StridedSlice kernels (float32, float16, int32).
// NOTE(review): this region previously contained unmerged-diff residue —
// truncated duplicate instantiations and instantiations of the removed
// CalStridedSlice/CalStridedSliceGrad templates — which does not compile.
template void FillDeviceArray<float>(const size_t input_size, float *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
                            const float *input, float *output, cudaStream_t stream);
template void CalSliceGrad<float>(const size_t input_size, const float *dy, const std::vector<int> in_shape,
                                  const std::vector<int> begin, const std::vector<int> size, float *output,
                                  cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
                            const half *input, half *output, cudaStream_t stream);
template void CalSliceGrad<half>(const size_t input_size, const half *dy, const std::vector<int> in_shape,
                                 const std::vector<int> begin, const std::vector<int> size, half *output,
                                 cudaStream_t cuda_stream);
template void FillDeviceArray<int>(const size_t input_size, int *addr, const float value, cudaStream_t cuda_stream);
template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2,
                            const int l3, const int l4, const int d1, const int d2, const int d3, const int d4,
                            const int *input, int *output, cudaStream_t stream);
template void CalSliceGrad<int>(const size_t input_size, const int *dy, const std::vector<int> in_shape,
                                const std::vector<int> begin, const std::vector<int> size, int *output,
                                cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
                           const std::vector<int> &strides, const std::vector<int> &output_shape, const float *input,
                           float *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
                           const std::vector<int> &strides, const std::vector<int> &output_shape, const half *input,
                           half *output, cudaStream_t cuda_stream);
template void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
                           const std::vector<int> &strides, const std::vector<int> &output_shape, const int *input,
                           int *output, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const float *dy,
                               float *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const half *dy,
                               half *dx, cudaStream_t cuda_stream);
template void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin,
                               const std::vector<int> &strides, const std::vector<int> &dx_shape, const int *dy,
                               int *dx, cudaStream_t cuda_stream);
| @@ -21,23 +21,20 @@ | |||
| #include <vector> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
// Host-side entry points implemented in slice_impl.cu; all are asynchronous on
// the given CUDA stream.
// NOTE(review): duplicated legacy declarations (old-style Slice4DKernel /
// CalSliceGrad spellings plus the removed CalStridedSlice / CalStridedSliceGrad)
// from an unmerged diff were dropped; only the current API is declared.

// Copies the window starting at (s1..s4) with lengths (l1..l4) out of a 4-D
// input with dimensions (d1..d4).
template <typename T>
void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, const int l1, const int l2, const int l3,
                   const int l4, const int d1, const int d2, const int d3, const int d4, const T *input, T *output,
                   cudaStream_t stream);
// Backward pass of Slice: scatters `input` (the incoming gradient) into `output`.
template <typename T>
void CalSliceGrad(const size_t input_size, const T *input, const std::vector<int> in_shape,
                  const std::vector<int> begin, const std::vector<int> size, T *output, cudaStream_t cuda_stream);
// Forward StridedSlice over an arbitrary-rank input described by input_shape.
template <typename T>
void StridedSlice(const std::vector<size_t> &input_shape, const std::vector<int> &begin,
                  const std::vector<int> &strides, const std::vector<int> &output_shape, const T *input, T *output,
                  cudaStream_t cuda_stream);
// Backward StridedSlice: scatters dy (shaped dy_shape) into dx (shaped dx_shape).
// Only strided positions of dx are written; callers pre-fill dx (FillDeviceArray).
template <typename T>
void StridedSliceGrad(const std::vector<int> &dy_shape, const std::vector<int> &begin, const std::vector<int> &strides,
                      const std::vector<int> &dx_shape, const T *dy, T *dx, cudaStream_t cuda_stream);
// Fills `addr[0..input_size)` with `value` (value is passed as float for all T).
template <typename T>
void FillDeviceArray(const size_t input_size, T *addr, const float value, cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ | |||
| @@ -19,31 +19,258 @@ import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| from mindspore.ops import composite as C | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
class StridedSliceGrad(nn.Cell):
    """Legacy cell invoking G.StridedSliceGrad with fixed slice parameters.

    NOTE(review): kept because test_slice below still instantiates it; it
    appears superseded by StridedSliceNet + GradData — confirm and remove.
    """

    def __init__(self):
        super(StridedSliceGrad, self).__init__()
        self.ssg = G.StridedSliceGrad()
        self.shape = P.Shape()

    @ms_function
    def construct(self, dy, x):
        return self.ssg(dy, self.shape(x), (2, 0, 0), (3, 2, 3), (1, 1, 1))


class StridedSliceNet(nn.Cell):
    """Applies P.StridedSlice with the begin/end/strides (and optional masks)
    supplied at construction time.

    NOTE(review): the original text interleaved the removed StridedSliceGrad
    cell with this class (leaving a class without construct and two dangling
    construct defs); the two classes are untangled here.
    """

    def __init__(self, begin, end, stride, begin_mask=0, end_mask=0, ellipsis_mask=0):
        super(StridedSliceNet, self).__init__()
        self.begin = begin
        self.end = end
        self.strides = stride
        self.slice = P.StridedSlice(begin_mask, end_mask, ellipsis_mask)

    @ms_function
    def construct(self, x):
        return self.slice(x, self.begin, self.end, self.strides)
class GradData(nn.Cell):
    """Wraps a network and, when called, returns the gradients of that network
    with respect to all of its inputs (no sensitivity parameter)."""

    def __init__(self, network):
        super(GradData, self).__init__()
        self.network = network
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=False)

    def construct(self, x):
        grad_fn = self.grad(self.network)
        return grad_fn(x)
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
def test_slice():
    # NOTE(review): legacy gradient test driven by the StridedSliceGrad cell
    # above; appears superseded by test_strided_slice_grad below — confirm and
    # remove together with that cell.
    x = Tensor(np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.float32))
    dy = Tensor(np.array([[[5., 1., 5.], [6., 1., 8.]]]).astype(np.float32))
    ssg = StridedSliceGrad()
    output = ssg(dy, x)
    # dy is scattered into the last batch slab; all other positions stay zero.
    expect = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]], [[5, 1, 5], [6, 1, 8]]]
    assert (output.asnumpy() == expect).all()
def test_strided_slice_grad():
    """Backward StridedSlice on GPU.

    For each slice spec the gradient of a sum-like reduction is a tensor of
    ones at the sliced positions and zeros elsewhere; each case checks that
    scatter pattern. Covers positive strides, negative strides, negative begin
    indices, begin/end/ellipsis masks, and a 3-D input.
    """
    x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32))

    # begin=(1, 0, 0, 2), end=(2, 2, 2, 4), strides=(1, 1, 1, 1)
    net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    dx = GradData(net)(x)
    expect = np.array([[[[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]],
                       [[[0., 0., 1., 1., 0.],
                         [0., 0., 1., 1., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 1., 1., 0.],
                         [0., 0., 1., 1., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]]])
    assert np.allclose(dx[0].asnumpy(), expect)

    # Negative stride on the last axis: begin=5, end=1, stride=-2.
    net = StridedSliceNet((1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
    dx = GradData(net)(x)
    expect = np.array([[[[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]],
                       [[[0., 0., 1., 0., 1.],
                         [0., 0., 1., 0., 1.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 1., 0., 1.],
                         [0., 0., 1., 0., 1.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]]])
    assert np.allclose(dx[0].asnumpy(), expect)

    # Negative begin index with stride -1: begin=-1, end=1.
    net = StridedSliceNet((1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
    dx = GradData(net)(x)
    expect = np.array([[[[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]],
                       [[[0., 0., 1., 1., 1.],
                         [0., 0., 1., 1., 1.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 1., 1., 1.],
                         [0., 0., 1., 1., 1.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]]])
    assert np.allclose(dx[0].asnumpy(), expect)

    # Disabled: front-end shape inference fails for this spec ("ME infer fault").
    # y = GradData()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
    # expect = np.array([[[[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]]],
    #                    [[[0., 0., 0., 0., 0.],
    #                      [0., 1., 0., 1., 0.],
    #                      [0., 1., 0., 1., 0.],
    #                      [0., 1., 0., 1., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 1., 0., 1., 0.],
    #                      [0., 1., 0., 1., 0.],
    #                      [0., 1., 0., 1., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]]]])
    # assert np.allclose(y.asnumpy(), expect)

    # Disabled variant with begin_mask/end_mask only.
    # y = Grad(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    # expect = np.array([[[[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]]],
    #                    [[[0., 0., 1., 1., 0.],
    #                      [0., 0., 1., 1., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]],
    #                     [[0., 0., 1., 1., 0.],
    #                      [0., 0., 1., 1., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]],
    #                     [[0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.],
    #                      [0., 0., 0., 0., 0.]]]])
    # assert np.allclose(y.asnumpy(), expect)

    # begin_mask/end_mask/ellipsis_mask together.
    net = StridedSliceNet((1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1),
                          begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
    dx = GradData(net)(x)
    expect = np.array([[[[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]],
                        [[0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.],
                         [0., 0., 0., 0., 0.]]],
                       [[[1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.]],
                        [[1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.]],
                        [[1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.],
                         [1., 1., 1., 1., 0.]]]])
    assert np.allclose(dx[0].asnumpy(), expect)

    # 3-D input with a non-unit stride on the last axis.
    x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
    net = StridedSliceNet((1, 0, 0), (2, -3, 3), (1, 1, 3))
    dx = GradData(net)(x)
    expect = np.array([[[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]],
                       [[1., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]],
                       [[0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0.]]])
    assert np.allclose(dx[0].asnumpy(), expect)
| @@ -17,29 +17,79 @@ import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
class StridedSlice(nn.Cell):
    """Applies P.StridedSlice with fixed parameters
    begin=(2, 0, 0), end=(3, 2, 3), strides=(1, 1, 1)."""

    def __init__(self):
        super(StridedSlice, self).__init__()
        self.stridedslice = P.StridedSlice()

    def construct(self, x):
        sliced = self.stridedslice(x, (2, 0, 0), (3, 2, 3), (1, 1, 1))
        return sliced
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
def test_slice():
    """Forward StridedSlice over an int32 tensor, checked against a
    precomputed result."""
    data = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 7, 8]]]).astype(np.int32)
    net = StridedSlice()
    result = net(Tensor(data))
    expect = [[[5., 5., 5.],
               [6., 7., 8.]]]
    assert (result.asnumpy() == expect).all()
def test_stridedslice():
    """Forward StridedSlice on GPU against precomputed results.

    Covers positive strides, negative strides, negative begin indices,
    begin/end/ellipsis masks, a 3-D input, and the Tensor slicing syntax.
    """
    x = Tensor(np.arange(0, 2*3*4*5).reshape(2, 3, 4, 5).astype(np.float32))

    # Unit strides.
    y = P.StridedSlice()(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    expect = np.array([[[[62, 63],
                         [67, 68]],
                        [[82, 83],
                         [87, 88]]]])
    assert np.allclose(y.asnumpy(), expect)

    # Negative stride on the last axis.
    y = P.StridedSlice()(x, (1, 0, 0, 5), (2, 2, 2, 1), (1, 1, 1, -2))
    expect = np.array([[[[64, 62],
                         [69, 67]],
                        [[84, 82],
                         [89, 87]]]])
    assert np.allclose(y.asnumpy(), expect)

    # Negative begin index with stride -1.
    y = P.StridedSlice()(x, (1, 0, 0, -1), (2, 2, 2, 1), (1, 1, 1, -1))
    expect = np.array([[[[64, 63, 62],
                         [69, 68, 67]],
                        [[84, 83, 82],
                         [89, 88, 87]]]])
    assert np.allclose(y.asnumpy(), expect)

    # Disabled: front-end shape inference fails for this spec ("ME infer fault").
    # y = P.StridedSlice()(x, (1, 0, -1, -2), (2, 2, 0, -5), (1, 1, -1, -2))
    # expect = np.array([[[[78, 76],
    #                      [73, 71],
    #                      [68, 66]],
    #                     [[98, 96],
    #                      [93, 91],
    #                      [88, 86]]]])
    # assert np.allclose(y.asnumpy(), expect)

    # Disabled variant with begin_mask/end_mask only.
    # y = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010)(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    # expect = np.array([[[[ 62,  63],
    #                      [ 67,  68]],
    #                     [[ 82,  83],
    #                      [ 87,  88]],
    #                     [[102, 103],
    #                      [107, 108]]]])
    # assert np.allclose(y.asnumpy(), expect)

    # begin_mask/end_mask/ellipsis_mask together.
    op = P.StridedSlice(begin_mask=0b1000, end_mask=0b0010, ellipsis_mask=0b0100)
    y = op(x, (1, 0, 0, 2), (2, 2, 2, 4), (1, 1, 1, 1))
    expect = np.array([[[[60, 61, 62, 63],
                         [65, 66, 67, 68],
                         [70, 71, 72, 73],
                         [75, 76, 77, 78]],
                        [[80, 81, 82, 83],
                         [85, 86, 87, 88],
                         [90, 91, 92, 93],
                         [95, 96, 97, 98]],
                        [[100, 101, 102, 103],
                         [105, 106, 107, 108],
                         [110, 111, 112, 113],
                         [115, 116, 117, 118]]]])
    assert np.allclose(y.asnumpy(), expect)

    # 3-D input with a non-unit stride on the last axis.
    x = Tensor(np.arange(0, 3*4*5).reshape(3, 4, 5).astype(np.float32))
    y = P.StridedSlice()(x, (1, 0, 0), (2, -3, 3), (1, 1, 3))
    expect = np.array([[[20]]])
    assert np.allclose(y.asnumpy(), expect)

    # Tensor slicing syntax lowers to StridedSlice; compare against numpy.
    x_np = np.arange(0, 4*5).reshape(4, 5).astype(np.float32)
    y = Tensor(x_np)[:, ::-1]
    expect = x_np[:, ::-1]
    assert np.allclose(y.asnumpy(), expect)