Merge pull request !20231 from caifubi/master-tensor-copy-slices-Generalizationtags/v1.4.0
| @@ -19,16 +19,19 @@ | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include <numeric> | |||
| #include <functional> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/slice_copy_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class TensorCopySlicesGpuKernel : public GpuKernel { | |||
| public: | |||
| TensorCopySlicesGpuKernel() : input_size_(0), update_size_(0), output_size_(0), offset_(0), copy_size_(0) {} | |||
| TensorCopySlicesGpuKernel() : input_size_(0), update_size_(0), output_size_(0) {} | |||
| ~TensorCopySlicesGpuKernel() {} | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| @@ -37,20 +40,12 @@ class TensorCopySlicesGpuKernel : public GpuKernel { | |||
| T *update_addr = GetDeviceAddress<T>(inputs, 1); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| if (inputs[1]->size != copy_size_) { | |||
| MS_LOG(EXCEPTION) << "Invalid update size:" << inputs[1]->size << " copy_size_:" << copy_size_; | |||
| } | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "TensorCopySlices cudaMemcpyAsync outputs failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(output_addr + offset_, update_addr, inputs[1]->size, | |||
| cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "TensorCopySlices cudaMemcpyAsync outputs failed"); | |||
| CopySlices(update_shape_, begin_, strides_, output_shape_, update_addr, output_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -73,42 +68,58 @@ class TensorCopySlicesGpuKernel : public GpuKernel { | |||
| return false; | |||
| } | |||
| auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto update_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| CastShapeSizeToLong(input_shapes, &input_shapes_); | |||
| CastShapeSizeToLong(update_shapes, &update_shapes_); | |||
| CastShapeSizeToLong(output_shapes, &output_shapes_); | |||
| if (input_shape_.size() > kMaxDims) { | |||
| MS_LOG(ERROR) << "StridedSlice support dims no more than " << kMaxDims << ", but the input shape is " | |||
| << input_shape_.size(); | |||
| return false; | |||
| } | |||
| GetSize(); | |||
| InitSizeLists(); | |||
| begin_ = GetAttr<std::vector<int64_t>>(kernel_node, kAttrBegin); | |||
| end_ = GetAttr<std::vector<int64_t>>(kernel_node, kAttrEnd); | |||
| strides_ = GetAttr<std::vector<int64_t>>(kernel_node, kAttrStrides); | |||
| auto begin = GetAttr<std::vector<int64_t>>(kernel_node, kAttrBegin); | |||
| auto end = GetAttr<std::vector<int64_t>>(kernel_node, kAttrEnd); | |||
| auto strides = GetAttr<std::vector<int64_t>>(kernel_node, kAttrStrides); | |||
| FillEmptyDims(kernel_node); | |||
| output_shape_ = input_shape_; | |||
| FillUpdateDim(); | |||
| CheckAtrrAndShapeValid(kernel_node); | |||
| CheckSliceValid(begin, end, strides, input_shapes_); | |||
| auto dim_offset = CalDimOffset(input_shapes_); | |||
| offset_ = CalOffset(begin, end, strides, dim_offset); | |||
| copy_size_ = GetCopySize(dim_offset, begin, end) * sizeof(T); | |||
| GetSize(); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void CheckAtrrAndShapeValid(const CNodePtr &kernel_node) { | |||
| auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| size_t total_update_num = std::accumulate(update_shape.begin(), update_shape.end(), 1, std::multiplies<size_t>()); | |||
| if (begin_.size() != end_.size() || end_.size() != strides_.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid attr begin:" << begin_ << " end:" << end_ << " strides:" << strides_; | |||
| } | |||
| auto len = begin_.size(); | |||
| size_t total_input_num = 1; | |||
| for (size_t i = 0; i < len; ++i) { | |||
| total_input_num *= ((end_[i] - begin_[i]) / strides_[i]); | |||
| } | |||
| if (total_input_num != total_update_num) { | |||
| MS_LOG(EXCEPTION) << "Invalid update_shape:" << update_shape << ". Maybe you need to broadcast it."; | |||
| } | |||
| } | |||
| void GetSize() { | |||
| input_size_ = sizeof(T); | |||
| for (size_t i = 0; i < input_shapes_.size(); i++) { | |||
| input_size_ *= LongToSize(input_shapes_[i]); | |||
| for (size_t i = 0; i < input_shape_.size(); i++) { | |||
| input_size_ *= input_shape_[i]; | |||
| } | |||
| update_size_ = sizeof(T); | |||
| for (size_t i = 0; i < update_shapes_.size(); i++) { | |||
| update_size_ *= LongToSize(update_shapes_[i]); | |||
| for (size_t i = 0; i < update_shape_.size(); i++) { | |||
| update_size_ *= update_shape_[i]; | |||
| } | |||
| output_size_ = sizeof(T); | |||
| for (size_t i = 0; i < output_shapes_.size(); i++) { | |||
| output_size_ *= LongToSize(output_shapes_[i]); | |||
| for (size_t i = 0; i < output_shape_.size(); i++) { | |||
| output_size_ *= output_shape_[i]; | |||
| } | |||
| } | |||
| @@ -119,21 +130,61 @@ class TensorCopySlicesGpuKernel : public GpuKernel { | |||
| return; | |||
| } | |||
| void FillEmptyDims(const CNodePtr &kernel_node) { | |||
| for (size_t i = 0; i < kMaxDims; i++) { | |||
| if (i < begin_.size()) { | |||
| int64_t dim = input_shape_[i]; | |||
| begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1); | |||
| } else { | |||
| begin_.push_back(0); | |||
| } | |||
| if (i < end_.size()) { | |||
| int64_t dim = input_shape_[i]; | |||
| end_[i] = std::max(end_[i] < 0 ? end_[i] + dim : std::min(end_[i], dim), static_cast<int64_t>(-1)); | |||
| } else { | |||
| end_.push_back(i < input_shape_.size() ? input_shape_[i] : 1); | |||
| } | |||
| if (i >= strides_.size()) { | |||
| strides_.push_back(1); | |||
| } | |||
| if (i >= input_shape_.size()) { | |||
| input_shape_.push_back(1); | |||
| } | |||
| } | |||
| } | |||
| void FillUpdateDim() { | |||
| for (size_t i = 0; i < kMaxDims; i++) { | |||
| if (begin_[i] <= end_[i] && strides_[i] > 0) { | |||
| update_shape_.push_back((end_[i] - 1 - begin_[i]) / strides_[i] + 1); | |||
| } else if (begin_[i] > end_[i] && strides_[i] < 0) { | |||
| update_shape_.push_back((end_[i] - begin_[i] + 1) / strides_[i] + 1); | |||
| } else { | |||
| update_shape_.push_back(0); | |||
| } | |||
| } | |||
| } | |||
| private: | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| std::vector<int64_t> input_shapes_; | |||
| std::vector<int64_t> update_shapes_; | |||
| std::vector<int64_t> output_shapes_; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> update_shape_; | |||
| std::vector<size_t> output_shape_; | |||
| std::vector<int64_t> begin_; | |||
| std::vector<int64_t> end_; | |||
| std::vector<int64_t> strides_; | |||
| size_t input_size_; | |||
| size_t update_size_; | |||
| size_t output_size_; | |||
| size_t offset_; | |||
| size_t copy_size_; | |||
| inline static size_t kMaxDims = 8; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,132 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cuda_runtime.h> | |||
| #include <stdio.h> | |||
| #include <stdint.h> | |||
| #include <algorithm> | |||
| #include <numeric> | |||
| #include <functional> | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/slice_copy_impl.cuh" | |||
| namespace { | |||
| constexpr size_t kMaxDim = 8; | |||
| } | |||
// Fixed-capacity POD wrapper so a std::vector can be passed to a CUDA kernel
// by value (the launch copies it into the kernel's parameter space).
template <typename T, size_t N>
class VectorWrapper {
 public:
  // Copies at most N elements. The original copied v.size() elements
  // unconditionally, overflowing `data` whenever v.size() > N; it also left
  // the tail indeterminate when v.size() < N — `data()` value-initializes it.
  explicit VectorWrapper(const std::vector<T> &v) : data() {
    std::copy(v.begin(), v.begin() + std::min(v.size(), N), data);
  }
  ~VectorWrapper() {}

  // Device-side element access; index must be < N.
  __device__ T &operator[](size_t index) { return data[index]; }

 private:
  T data[N];
};
// Scatters the flattened update tensor into the strided slice of the output
// tensor. All vectors are padded to kMaxDim entries; u holds the update shape,
// u_offset/o_offset its and the output's row-major strides (innermost == 1,
// per CalculateOffset). Grid-stride loop: one thread per update element.
template <typename T>
__global__ void CopySlicesKernel(VectorWrapper<int64_t, kMaxDim> begins, VectorWrapper<int64_t, kMaxDim> stride,
                                 VectorWrapper<size_t, kMaxDim> u, VectorWrapper<size_t, kMaxDim> u_offset,
                                 VectorWrapper<size_t, kMaxDim> o_offset, const T *update_addr, T *output_addr) {
  const size_t update_num = u[0] * u_offset[0];  // total update elements
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < update_num; pos += blockDim.x * gridDim.x) {
    // Decompose the flat update index into per-dimension coordinates, then map
    // each coordinate through (coord * stride + begin) into output index space.
    size_t output_idx = 0;
#pragma unroll
    for (size_t d = 0; d < kMaxDim; d++) {
      size_t coord = pos / u_offset[d] % u[d];
      output_idx += (coord * stride[d] + begins[d]) * o_offset[d];
    }
    output_addr[output_idx] = update_addr[pos];
  }
}
// Computes row-major strides ("offsets") for a shape padded to kMaxDim dims:
// offset[d] is the number of elements spanned by one step along dimension d
// (innermost dimension gets 1). Requires shape.size() >= kMaxDim.
std::vector<size_t> CalculateOffset(const std::vector<size_t> &shape) {
  std::vector<size_t> offset(kMaxDim);
  size_t running = 1;
  for (size_t d = kMaxDim; d-- > 0;) {
    offset[d] = running;
    running *= shape[d];
  }
  return offset;
}
// Host-side launcher: copies `update` into the strided slice of `output`
// described by `begin`/`stride`. All shape vectors must already be padded to
// kMaxDim dimensions. Asynchronous with respect to the host; ordered on
// `cuda_stream`.
template <typename T>
void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape, const T *update, T *output,
                cudaStream_t cuda_stream) {
  // Accumulate in size_t: the original `1` initial value made std::accumulate
  // compute the element-count product in int, which can overflow.
  size_t size =
    std::accumulate(update_shape.begin(), update_shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
  VectorWrapper<size_t, kMaxDim> o_offset(CalculateOffset(output_shape));
  VectorWrapper<size_t, kMaxDim> u_offset(CalculateOffset(update_shape));
  VectorWrapper<int64_t, kMaxDim> begins(begin);
  VectorWrapper<int64_t, kMaxDim> strides(stride);
  VectorWrapper<size_t, kMaxDim> update_shapes(update_shape);
  CopySlicesKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(begins, strides, update_shapes, u_offset,
                                                                      o_offset, update, output);
}
// Explicit instantiations for every dtype the TensorCopySlices GPU kernel
// registers, so the definitions in this translation unit are available to the
// kernel object file at link time.
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const bool *update, bool *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const double *update, double *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const float *update, float *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const half *update, half *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const int64_t *update, int64_t *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape, const int *update,
                         int *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const short *update, short *output, cudaStream_t cuda_stream);  // NOLINT
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const int8_t *update, int8_t *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const uint64_t *update, uint64_t *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const uint32_t *update, uint32_t *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const uint16_t *update, uint16_t *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const unsigned char *update, unsigned char *output, cudaStream_t cuda_stream);
template void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                         const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape,
                         const char *update, char *output, cudaStream_t cuda_stream);
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_COPY_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_COPY_IMPL_CUH_ | |||
| #include <cuda_runtime.h> | |||
| #include <vector> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
// Copies the flattened `update` tensor into the strided slice of `output`
// described by `begin`/`stride`. All shape vectors are expected to be padded
// to 8 dimensions (kMaxDim in the implementation). The kernel launch is
// asynchronous with respect to the host and is ordered on `cuda_stream`.
template <typename T>
void CopySlices(const std::vector<size_t> &update_shape, const std::vector<int64_t> &begin,
                const std::vector<int64_t> &stride, const std::vector<size_t> &output_shape, const T *update,
                T *output, cudaStream_t cuda_stream);
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SLICE_COPY_IMPL_CUH_ | |||