| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/arrays/broadcast_to_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| BroadcastToGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| BroadcastToGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_TO_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_TO_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class BroadcastToGpuKernel : public GpuKernel { | |||
| public: | |||
| BroadcastToGpuKernel() {} | |||
| ~BroadcastToGpuKernel() = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output_shape_[0], output_shape_[1], | |||
| output_shape_[2], output_shape_[3], input_addr, output_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (input_shapes.size() > 4 || output_shapes.size() > 4) { | |||
| MS_LOG(EXCEPTION) << "BroadcastTo operation not support dim greater than 4"; | |||
| } | |||
| for (int i = input_shapes.size() - 1; i >= 0; i--) { | |||
| input_shape_[i] = input_shapes[i]; | |||
| } | |||
| for (int j = output_shapes.size() - 1; j >= 0; j--) { | |||
| output_shape_[j] = output_shapes[j]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T)); | |||
| output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[2] * output_shape_[3] * sizeof(T)); | |||
| } | |||
| private: | |||
| int input_shape_[4] = {1, 1, 1, 1}; | |||
| int output_shape_[4] = {1, 1, 1, 1}; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_TO_GPU_KERNEL_H_ | |||
| @@ -116,16 +116,16 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const | |||
| output); | |||
| case BROADCAST_TYPE_REALDIV: | |||
| return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||
| output); | |||
| output); | |||
| case BROADCAST_TYPE_MUL: | |||
| return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||
| output); | |||
| output); | |||
| case BROADCAST_TYPE_SUB: | |||
| return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||
| output); | |||
| output); | |||
| case BROADCAST_TYPE_ADD: | |||
| return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||
| output); | |||
| output); | |||
| } | |||
| } | |||
| @@ -176,6 +176,28 @@ void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, cons | |||
| NoBroadcastKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output); | |||
| } | |||
| template <typename T> | |||
| __global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, | |||
| const int o1, const int o2, const int o3, const T *input_addr, T *output_addr) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < o0 * o1 * o2 * o3; pos += blockDim.x * gridDim.x) { | |||
| int i = pos / (o1 * o2 * o3) % o0; | |||
| int j = pos / (o2 * o3) % o1; | |||
| int k = pos / o3 % o2; | |||
| int l = pos % o3; | |||
| int input_idx = Index(i, i0) * i1 * i2 * i3 + Index(j, i1) * i2 * i3 + Index(k, i2) * i3 + Index(l, i3); | |||
| output_addr[pos] = input_addr[input_idx]; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, | |||
| const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) { | |||
| int nums = o0 * o1 * o2 * o3; | |||
| BroadcastToKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(i0, i1, i2, i3, o0, o1, o2, o3, input_addr, | |||
| output_addr); | |||
| } | |||
| template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, | |||
| const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, | |||
| enum BroadcastOpType op, const float *input0, const float *input1, bool *output, | |||
| @@ -204,5 +226,11 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half * | |||
| bool *output, cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, | |||
| half *output, cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, | |||
| int *output, cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output, | |||
| cudaStream_t stream); | |||
| template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, | |||
| const int &o2, const int &o3, const float *input_addr, float *output_addr, | |||
| cudaStream_t stream); | |||
| template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, | |||
| const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream); | |||
| @@ -41,4 +41,8 @@ template <typename T, typename S> | |||
| void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, | |||
| cudaStream_t stream); | |||
| template <typename T> | |||
| void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1, | |||
| const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ | |||
| @@ -0,0 +1,40 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_broadcast(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||
| x_np = np.random.rand(3, 1, 5, 1).astype(np.float32) | |||
| shape = (3, 4, 5, 6) | |||
| output = P.BroadcastTo(shape)(Tensor(x_np)) | |||
| expect = np.broadcast_to(x_np, shape) | |||
| assert np.allclose(output.asnumpy(), expect) | |||
| x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16) | |||
| output = P.BroadcastTo(shape)(Tensor(x1_np)) | |||
| expect = np.broadcast_to(x1_np, shape) | |||
| assert np.allclose(output.asnumpy(), expect) | |||