From e896d38c34e3cdf92260e67df24b8b2cc85512ff Mon Sep 17 00:00:00 2001
From: jonwe
Date: Thu, 29 Oct 2020 16:29:01 -0400
Subject: [PATCH] Add RepeatElementsGrad GPU kernel and bprop for RepeatElements

---
 .../arrays/repeat_elements_grad_gpu_kernel.cc |  29 ++
 .../arrays/repeat_elements_grad_gpu_kernel.h  | 122 +++++++
 .../cuda_impl/repeat_elements_grad_impl.cu    |  51 +++
 .../cuda_impl/repeat_elements_grad_impl.cuh   |  26 ++
 mindspore/ops/_grad/grad_array_ops.py         |  11 +
 mindspore/ops/operations/_grad_ops.py         |  21 ++
 .../ops/gpu/test_repeat_elements_grad_op.py   | 327 ++++++++++++++++++
 7 files changed, 587 insertions(+)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh
 create mode 100644 tests/st/ops/gpu/test_repeat_elements_grad_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..c3f364ad6f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cstdint>
+
+#include "backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(RepeatElementsGrad,
+                      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      RepeatElementsGradGpuKernel, half)
+
+MS_REG_GPU_KERNEL_ONE(RepeatElementsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      RepeatElementsGradGpuKernel, int32_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h
new file mode 100644
index 0000000000..ee51fbf898
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h
@@ -0,0 +1,122 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_
+
+#include "backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh"
+
+#include <cuda_runtime.h>
+
+#include <algorithm>
+#include <vector>
+
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
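+// The kernel views dy as [outer_size, repeat_dim_size * rep, inner_size] and
+// sums every group of `rep` slices along the repeated axis into one dx slice,
+// so dx has shape [outer_size, repeat_dim_size, inner_size].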
+template <typename T>
+class RepeatElementsGradGpuKernel : public GpuKernel {
+ public:
+  RepeatElementsGradGpuKernel()
+      : rep_(1), axis_(0), input_size_(1), output_size_(0), outer_size_(1), repeat_dim_size_(1), inner_size_(1) {}
+  ~RepeatElementsGradGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dy = GetDeviceAddress<T>(inputs, 0);
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+
+    CalRepeatElementsGrad(dy, rep_, dx, outer_size_, repeat_dim_size_, inner_size_,
+                          reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_count != 1) {
+      MS_LOG(EXCEPTION) << input_count << " arguments were provided, but RepeatElementsGradGpuKernel expects 1.";
+    }
+
+    std::vector<size_t> dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    int dy_dim = dy_shape.size();
+
+    axis_ = GetAttr<int>(kernel_node, "axis");
+    if (axis_ < 0) {
+      axis_ += dy_dim;
+    }
+    rep_ = GetAttr<int>(kernel_node, "rep");
+    if (axis_ >= dy_dim) {
+      axis_ = dy_dim - 1;
+      rep_ = 1;
+    }
+
+    for (int i = 0; i < dy_dim; i++) {
+      auto e = dy_shape[i];
+      input_size_ *= e;
+      input_shape_.push_back(e);
+      if (i < axis_) {
+        outer_size_ *= e;
+      } else if (i > axis_) {
+        inner_size_ *= e;
+      } else {
+        repeat_dim_size_ = e / rep_;
+      }
+    }
+
+    output_size_ = input_size_ / rep_;
+    output_shape_ = input_shape_;
+    output_shape_[axis_] /= rep_;
+
+    InitSizeLists();
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(output_size_ * sizeof(T));
+  }
+
+ private:
+  int rep_;
+  int axis_;
+  size_t input_size_;
+  size_t output_size_;
+  int outer_size_;
+  int repeat_dim_size_;
+  int inner_size_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> output_shape_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu
new file mode 100644
index 0000000000..4c125e6ed7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cuda_runtime.h>
+
+#include "repeat_elements_grad_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
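+// One thread per dx element: t_id is decomposed into (outer, repeat_dim,
+// inner) coordinates, and the gradients of the `rep` dy elements that were
+// copied from that position in the forward pass are accumulated.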
+template <typename T>
+__global__ void RepeatElementsGrad(const int dx_size, const T *dy, const int rep, T *dx, const int outer_size,
+                                   const int repeat_dim_size, const int inner_size) {
+  for (size_t t_id = blockIdx.x * blockDim.x + threadIdx.x; t_id < dx_size; t_id += gridDim.x * blockDim.x) {
+    int inner_id = t_id % inner_size;
+    int repeat_dim_id = t_id / inner_size % repeat_dim_size;
+    int outer_id = t_id / inner_size / repeat_dim_size;
+    T dx_i = static_cast<T>(0);
+    for (int i = 0; i < rep; i++) {
+      dx_i += dy[(outer_id * rep * repeat_dim_size * inner_size) + (repeat_dim_id * rep * inner_size) +
+                 (i * inner_size) + inner_id];
+    }
+    dx[t_id] = dx_i;
+  }
+}
+
+template <typename T>
+void CalRepeatElementsGrad(const T *dy, const int rep, T *dx, const int outer_size, const int repeat_dim_size,
+                           const int inner_size, cudaStream_t cuda_stream) {
+  const int dx_size = outer_size * repeat_dim_size * inner_size;
+  RepeatElementsGrad<<<GET_BLOCKS(dx_size), GET_THREADS, 0, cuda_stream>>>(dx_size, dy, rep, dx, outer_size,
+                                                                           repeat_dim_size, inner_size);
+}
+
+template void CalRepeatElementsGrad<int>(const int *dy, const int rep, int *dx, const int outer_size,
+                                         const int repeat_dim_size, const int inner_size, cudaStream_t cuda_stream);
+template void CalRepeatElementsGrad<half>(const half *dy, const int rep, half *dx, const int outer_size,
+                                          const int repeat_dim_size, const int inner_size, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh
new file mode 100644
index 0000000000..0cb46f1bf6
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_
+
+#include <cuda_runtime.h>
+
+template <typename T>
+void CalRepeatElementsGrad(const T *dy, const int rep, T *dx, const int outer_size, const int repeat_dim_size,
+                           const int inner_size, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index 8255c7088c..43b5131c53 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -848,3 +848,14 @@ def get_bprop_unique(self):
         dx = op(dout, out)
         return (dx,)
     return bprop
+
+
+@bprop_getters.register(P.RepeatElements)
+def get_bprop_repeat_elements(self):
+    """Generate bprop for RepeatElements"""
+    op = G.RepeatElementsGrad(self.rep, self.axis)
+    # dx sums dy over each group of `rep` repeated elements along `axis`.
+    def bprop(x, y, dy):
+        dx = op(dy)
+        return (dx,)
+    return bprop
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index cc61c9c4c7..1bc8624a3f 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -1731,3 +1731,24 @@ class LRNGrad(PrimitiveWithInfer):
 
     def infer_shape(self, grads, x, y):
         return x
+
+
+class RepeatElementsGrad(PrimitiveWithInfer):
+    """Gradients of RepeatElements operation."""
+
+    @prim_attr_register
+    def __init__(self, rep, axis=0):
+        self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
+        validator.check_value_type("rep", rep, [int], self.name)
+        validator.check_value_type("axis", axis, [int], self.name)
+        self.rep = rep
+        self.axis = axis
+
+    def infer_dtype(self, dy_type):
+        validator.check_type_name("dy_type", dy_type, [mstype.float16, mstype.float32, mstype.int32], self.name)
+        return dy_type
+
+    def infer_shape(self, dy_shape):
+        dx_shape = dy_shape.copy()  # copy so the caller's shape list is not mutated
+        dx_shape[self.axis] = dy_shape[self.axis] // self.rep
+        return dx_shape
diff --git a/tests/st/ops/gpu/test_repeat_elements_grad_op.py b/tests/st/ops/gpu/test_repeat_elements_grad_op.py
new file mode 100644
index 0000000000..038ee115ec
--- /dev/null
+++ b/tests/st/ops/gpu/test_repeat_elements_grad_op.py
@@ -0,0 +1,327 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.ops.operations import _grad_ops as G
+import mindspore.nn as nn
+import mindspore.context as context
+
+
+class RepeatElementsNet(nn.Cell):
+    def __init__(self, rep, axis):
+        super(RepeatElementsNet, self).__init__()
+        self.repeat_elements = P.RepeatElements(rep, axis)
+
+    def construct(self, x):
+        return self.repeat_elements(x)
+
+
+class RepeatElementsGradNet(nn.Cell):
+    def __init__(self, rep, axis):
+        super(RepeatElementsGradNet, self).__init__()
+        self.repeat_elements_grad = G.RepeatElementsGrad(rep, axis)
+
+    def construct(self, dy):
+        return self.repeat_elements_grad(dy)
+
+
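+# Helpers: run the forward RepeatElements op to build dy, then run
+# RepeatElementsGrad on it; results are checked against numpy expectations.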
+# ============================================================================ + +import numpy as np +import pytest + +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +import mindspore.context as context + + +class RepeatElementsNet(nn.Cell): + def __init__(self, rep, axis): + super(RepeatElementsNet, self).__init__() + self.repeat_elements = P.RepeatElements(rep, axis) + + def construct(self, x): + return self.repeat_elements(x) + + +class RepeatElementsGradNet(nn.Cell): + def __init__(self, rep, axis): + super(RepeatElementsGradNet, self).__init__() + self.repeat_elements_grad = G.RepeatElementsGrad(rep, axis) + + def construct(self, dy): + return self.repeat_elements_grad(dy) + + +def repeat_elements(x, rep, axis): + repeat_elements_net = RepeatElementsNet(rep, axis) + return repeat_elements_net(Tensor(x.astype(np.int32))).asnumpy() + + +def repeat_elements_grad(dy, rep, axis): + repeat_elements_grad_net = RepeatElementsGradNet(rep, axis) + return repeat_elements_grad_net(Tensor(dy.astype(np.int32))).asnumpy() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_grad_1d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1) + + ms_out = repeat_elements_grad(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_grad_1d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1, 2) + + y = repeat_elements(a, 5, 0) + print(y) + ms_out = repeat_elements_grad(y, 5, 0) + print(ms_out) + np.testing.assert_array_equal(a*5, ms_out) + + y = repeat_elements(a, 513, 0) + ms_out = repeat_elements_grad(y, 513, 0) + np.testing.assert_array_equal(a*513, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_grad_1d_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(24) + + ms_out = repeat_elements_grad(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_grad_1d_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(4) + + y = repeat_elements(a, 3, 0) + ms_out = repeat_elements_grad(y, 3, 0) + np.testing.assert_array_equal(a*3, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_grad_2d_one_element_rep_1(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1) + + ms_out = repeat_elements_grad(a, 1, 0) + np_out = a.repeat(1, 0) + np.testing.assert_array_equal(np_out, ms_out) + + ms_out = repeat_elements_grad(a, 1, 1) + np_out = a.repeat(1, 1) + np.testing.assert_array_equal(np_out, ms_out) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_repeat_elements_grad_2d_one_element_rep_many(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + a = np.arange(1).reshape(1, 1) + + y = repeat_elements(a, 13, 0) + ms_out = repeat_elements_grad(y, 13, 0) + np.testing.assert_array_equal(a*13, ms_out) + + y = 
+    ms_out = repeat_elements_grad(y, 13, 1)
+    np.testing.assert_array_equal(a*13, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_2d_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(24).reshape(12, 2)
+
+    ms_out = repeat_elements_grad(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 1)
+    np_out = a.repeat(1, 1)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_2d_rep_many():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(24).reshape(8, 3)
+
+    y = repeat_elements(a, 23, 0)
+    ms_out = repeat_elements_grad(y, 23, 0)
+    np.testing.assert_array_equal(a*23, ms_out)
+
+    y = repeat_elements(a, 23, 1)
+    ms_out = repeat_elements_grad(y, 23, 1)
+    np.testing.assert_array_equal(a*23, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_5d_one_element_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1).reshape(1, 1, 1, 1, 1)
+
+    ms_out = repeat_elements_grad(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 1)
+    np_out = a.repeat(1, 1)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 2)
+    np_out = a.repeat(1, 2)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 3)
+    np_out = a.repeat(1, 3)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 4)
+    np_out = a.repeat(1, 4)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_5d_one_element_rep_many():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1).reshape(1, 1, 1, 1, 1)
+
+    y = repeat_elements(a, 19, 0)
+    ms_out = repeat_elements_grad(y, 19, 0)
+    np.testing.assert_array_equal(a, ms_out)
+
+    y = repeat_elements(a, 19, 1)
+    ms_out = repeat_elements_grad(y, 19, 1)
+    np.testing.assert_array_equal(a, ms_out)
+
+    y = repeat_elements(a, 19, 2)
+    ms_out = repeat_elements_grad(y, 19, 2)
+    np.testing.assert_array_equal(a, ms_out)
+
+    y = repeat_elements(a, 19, 3)
+    ms_out = repeat_elements_grad(y, 19, 3)
+    np.testing.assert_array_equal(a, ms_out)
+
+    y = repeat_elements(a, 19, 4)
+    ms_out = repeat_elements_grad(y, 19, 4)
+    np.testing.assert_array_equal(a, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_5d_rep_1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(224).reshape(8, 2, 1, 7, 2)
+
+    ms_out = repeat_elements_grad(a, 1, 0)
+    np_out = a.repeat(1, 0)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 1)
+    np_out = a.repeat(1, 1)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 2)
+    np_out = a.repeat(1, 2)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 3)
+    np_out = a.repeat(1, 3)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+    ms_out = repeat_elements_grad(a, 1, 4)
+    np_out = a.repeat(1, 4)
+    np.testing.assert_array_equal(np_out, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_5d_rep_many():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(224).reshape(1, 7, 4, 4, 2)
+
+    y = repeat_elements(a, 7, 0)
+    ms_out = repeat_elements_grad(y, 7, 0)
+    np.testing.assert_array_equal(a*7, ms_out)
+
+    y = repeat_elements(a, 7, 1)
+    ms_out = repeat_elements_grad(y, 7, 1)
+    np.testing.assert_array_equal(a*7, ms_out)
+
+    y = repeat_elements(a, 7, 2)
+    ms_out = repeat_elements_grad(y, 7, 2)
+    np.testing.assert_array_equal(a*7, ms_out)
+
+    y = repeat_elements(a, 7, 3)
+    ms_out = repeat_elements_grad(y, 7, 3)
+    np.testing.assert_array_equal(a*7, ms_out)
+
+    y = repeat_elements(a, 7, 4)
+    ms_out = repeat_elements_grad(y, 7, 4)
+    np.testing.assert_array_equal(a*7, ms_out)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_repeat_elements_grad_half():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 1, 1, 4, 3)
+
+    y = repeat_elements(a, 4, 0)
+    ms_out = repeat_elements_grad(y, 4, 0)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 1)
+    ms_out = repeat_elements_grad(y, 4, 1)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 2)
+    ms_out = repeat_elements_grad(y, 4, 2)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 3)
+    ms_out = repeat_elements_grad(y, 4, 3)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 4)
+    ms_out = repeat_elements_grad(y, 4, 4)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 5)
+    ms_out = repeat_elements_grad(y, 4, 5)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 6)
+    ms_out = repeat_elements_grad(y, 4, 6)
+    np.testing.assert_array_equal(a*4, ms_out)
+
+    y = repeat_elements(a, 4, 7)
+    ms_out = repeat_elements_grad(y, 4, 7)
+    np.testing.assert_array_equal(a*4, ms_out)