Browse Source

repeat grad

tags/v1.1.0
jonwe 5 years ago
parent
commit
e896d38c34
7 changed files with 574 additions and 0 deletions
  1. +29
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc
  2. +119
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h
  3. +48
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu
  4. +26
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh
  5. +10
    -0
      mindspore/ops/_grad/grad_array_ops.py
  6. +21
    -0
      mindspore/ops/operations/_grad_ops.py
  7. +321
    -0
      tests/st/ops/gpu/test_repeat_elements_grad_op.py

+ 29
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>

#include "backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(RepeatElementsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
RepeatElementsGradGpuKernel, half)

MS_REG_GPU_KERNEL_ONE(RepeatElementsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
RepeatElementsGradGpuKernel, int32_t)
} // namespace kernel
} // namespace mindspore

+ 119
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h View File

@@ -0,0 +1,119 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_

#include "backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh"

#include <cuda_runtime.h>

#include <algorithm>
#include <vector>

#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class RepeatElementsGradGpuKernel : public GpuKernel {
public:
RepeatElementsGradGpuKernel()
: rep_(1), axis_(0), input_size_(1), output_size_(0), outer_size_(1), repeat_dim_size_(1), inner_size_(1) {}
~RepeatElementsGradGpuKernel() = default;

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *dy = GetDeviceAddress<T>(inputs, 0);
T *dx = GetDeviceAddress<T>(outputs, 0);

CalRepeatElementsGrad(dy, rep_, dx, outer_size_, repeat_dim_size_, inner_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));

return true;
}

bool Init(const CNodePtr &kernel_node) override {
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 1) {
MS_LOG(EXCEPTION) << input_count << " arguments were provided, but RepeatElementGradGpuKernel expects 1.";
}

std::vector<size_t> dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
int dy_dim = dy_shape.size();

axis_ = GetAttr<int>(kernel_node, "axis");
if (axis_ < 0) {
axis_ += dy_dim;
}
rep_ = GetAttr<int>(kernel_node, "rep");
if (axis_ >= dy_dim) {
axis_ = dy_dim - 1;
rep_ = 1;
}

for (int i = 0; i < dy_dim; i++) {
auto e = dy_shape[i];
input_size_ *= e;
input_shape_.push_back(e);
if (i < axis_) {
outer_size_ *= e;
} else if (i > axis_) {
inner_size_ *= e;
} else {
repeat_dim_size_ = e / rep_;
}
}

output_size_ = input_size_ / rep_;
output_shape_ = input_shape_;
output_shape_[axis_] /= rep_;

InitSizeLists();

return true;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(output_size_ * sizeof(T));
}

private:
int rep_;
int axis_;
size_t input_size_;
size_t output_size_;
int outer_size_;
int repeat_dim_size_;
int inner_size_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_

+ 48
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu View File

@@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>

#include "repeat_elements_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
__global__ void RepeatElementsGrad(const int dx_size, const T *dy, const int rep, T *dx, const int outer_size,
const int repeat_dim_size, const int inner_size) {
for (size_t t_id = blockIdx.x * blockDim.x + threadIdx.x; t_id < dx_size; t_id += gridDim.x * blockDim.x) {
int inner_id = t_id % inner_size;
int repeat_dim_id = t_id / inner_size % repeat_dim_size;
int outer_id = t_id / inner_size / repeat_dim_size;
T dx_i = static_cast<T>(0);
for (int i = 0; i < rep; i++) {
dx_i += dy[(outer_id * rep * repeat_dim_size * inner_size) + (repeat_dim_id * rep * inner_size) +
(i * inner_size) + inner_id];
}
dx[t_id] = dx_i;
}
}

template <typename T>
void CalRepeatElementsGrad(const T *dy, const int rep, T *dx, const int outer_size, const int repeat_dim_size,
const int inner_size, cudaStream_t cuda_stream) {
const int dx_size = outer_size * repeat_dim_size * inner_size;
RepeatElementsGrad<<<GET_BLOCKS(dx_size), GET_THREADS, 0, cuda_stream>>>(dx_size, dy, rep, dx, outer_size,
repeat_dim_size, inner_size);
}

template void CalRepeatElementsGrad<int>(const int *dy, const int rep, int *dx, const int outer_size,
const int repeat_dim_size, const int inner_size, cudaStream_t cuda_stream);
template void CalRepeatElementsGrad<half>(const half *dy, const int rep, half *dx, const int outer_size,
const int repeat_dim_size, const int inner_size, cudaStream_t cuda_stream);

+ 26
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh View File

@@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_

#include <cuda_runtime.h>

template <typename T>
void CalRepeatElementsGrad(const T *dy, const int rep, T *dx, const int outer_size, const int repeat_dim_size,
const int inner_size, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_

+ 10
- 0
mindspore/ops/_grad/grad_array_ops.py View File

@@ -848,3 +848,13 @@ def get_bprop_unique(self):
dx = op(dout, out) dx = op(dout, out)
return (dx,) return (dx,)
return bprop return bprop


@bprop_getters.register(P.RepeatElements)
def get_bprop_repeat_elements(self):
"""Generate bprop for RepeatElements"""
op = G.RepeatElementsGrad(self.rep, self.axis)
def bprop(x, y, dy):
dx = op(dy)
return (dx,)
return bprop

+ 21
- 0
mindspore/ops/operations/_grad_ops.py View File

@@ -1731,3 +1731,24 @@ class LRNGrad(PrimitiveWithInfer):


def infer_shape(self, grads, x, y): def infer_shape(self, grads, x, y):
return x return x


class RepeatElementsGrad(PrimitiveWithInfer):
"""Gradients of RepeatElements operation."""

@prim_attr_register
def __init__(self, rep, axis=0):
self.init_prim_io_names(inputs=['dy'], outputs=['dx'])
validator.check_value_type("rep", rep, [int], self.name)
validator.check_value_type("axis", axis, [int], self.name)
self.rep = rep
self.axis = axis

def infer_dtype(self, dy_type):
validator.check_type_name("dy_type", dy_type, [mstype.float16, mstype.float32, mstype.int32], self.name)
return dy_type

def infer_shape(self, dy_shape):
dx_shape = dy_shape
dx_shape[self.axis] = dy_shape[self.axis] // self.rep
return dx_shape

+ 321
- 0
tests/st/ops/gpu/test_repeat_elements_grad_op.py View File

@@ -0,0 +1,321 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.nn as nn
import mindspore.context as context


class RepeatElementsNet(nn.Cell):
def __init__(self, rep, axis):
super(RepeatElementsNet, self).__init__()
self.repeat_elements = P.RepeatElements(rep, axis)

def construct(self, x):
return self.repeat_elements(x)


class RepeatElementsGradNet(nn.Cell):
def __init__(self, rep, axis):
super(RepeatElementsGradNet, self).__init__()
self.repeat_elements_grad = G.RepeatElementsGrad(rep, axis)

def construct(self, dy):
return self.repeat_elements_grad(dy)


def repeat_elements(x, rep, axis):
repeat_elements_net = RepeatElementsNet(rep, axis)
return repeat_elements_net(Tensor(x.astype(np.int32))).asnumpy()


def repeat_elements_grad(dy, rep, axis):
repeat_elements_grad_net = RepeatElementsGradNet(rep, axis)
return repeat_elements_grad_net(Tensor(dy.astype(np.int32))).asnumpy()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_1d_one_element_rep_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1)

ms_out = repeat_elements_grad(a, 1, 0)
np_out = a.repeat(1, 0)
np.testing.assert_array_equal(np_out, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_1d_one_element_rep_many():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1, 2)

y = repeat_elements(a, 5, 0)
print(y)
ms_out = repeat_elements_grad(y, 5, 0)
print(ms_out)
np.testing.assert_array_equal(a*5, ms_out)

y = repeat_elements(a, 513, 0)
ms_out = repeat_elements_grad(y, 513, 0)
np.testing.assert_array_equal(a*513, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_1d_rep_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(24)

ms_out = repeat_elements_grad(a, 1, 0)
np_out = a.repeat(1, 0)
np.testing.assert_array_equal(np_out, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_1d_rep_many():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(4)

y = repeat_elements(a, 3, 0)
ms_out = repeat_elements_grad(y, 3, 0)
np.testing.assert_array_equal(a*3, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_2d_one_element_rep_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1).reshape(1, 1)

ms_out = repeat_elements_grad(a, 1, 0)
np_out = a.repeat(1, 0)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 1)
np_out = a.repeat(1, 1)
np.testing.assert_array_equal(np_out, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_2d_one_element_rep_many():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1).reshape(1, 1)

y = repeat_elements(a, 13, 0)
ms_out = repeat_elements_grad(y, 13, 0)
np.testing.assert_array_equal(a*13, ms_out)

y = repeat_elements(a, 13, 1)
ms_out = repeat_elements_grad(y, 13, 1)
np.testing.assert_array_equal(a*13, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_2d_rep_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(24).reshape(12, 2)

ms_out = repeat_elements_grad(a, 1, 0)
np_out = a.repeat(1, 0)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 1)
np_out = a.repeat(1, 1)
np.testing.assert_array_equal(np_out, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_2d_rep_many():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(24).reshape(8, 3)

y = repeat_elements(a, 23, 0)
ms_out = repeat_elements_grad(y, 23, 0)
np.testing.assert_array_equal(a*23, ms_out)

y = repeat_elements(a, 23, 1)
ms_out = repeat_elements_grad(y, 23, 1)
np.testing.assert_array_equal(a*23, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_5d_one_element_rep_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1).reshape(1, 1, 1, 1, 1)

ms_out = repeat_elements_grad(a, 1, 0)
np_out = a.repeat(1, 0)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 1)
np_out = a.repeat(1, 1)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 2)
np_out = a.repeat(1, 2)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 3)
np_out = a.repeat(1, 3)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 4)
np_out = a.repeat(1, 4)
np.testing.assert_array_equal(np_out, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_5d_one_element_rep_many():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1).reshape(1, 1, 1, 1, 1)

y = repeat_elements(a, 19, 0)
ms_out = repeat_elements_grad(y, 19, 0)
np.testing.assert_array_equal(a, ms_out)

y = repeat_elements(a, 19, 1)
ms_out = repeat_elements_grad(y, 19, 1)
np.testing.assert_array_equal(a, ms_out)

y = repeat_elements(a, 19, 2)
ms_out = repeat_elements_grad(y, 19, 2)
np.testing.assert_array_equal(a, ms_out)

y = repeat_elements(a, 19, 3)
ms_out = repeat_elements_grad(y, 19, 3)
np.testing.assert_array_equal(a, ms_out)

y = repeat_elements(a, 19, 4)
ms_out = repeat_elements_grad(y, 19, 4)
np.testing.assert_array_equal(a, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_5d_rep_1():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(224).reshape(8, 2, 1, 7, 2)

ms_out = repeat_elements_grad(a, 1, 0)
np_out = a.repeat(1, 0)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 1)
np_out = a.repeat(1, 1)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 2)
np_out = a.repeat(1, 2)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 3)
np_out = a.repeat(1, 3)
np.testing.assert_array_equal(np_out, ms_out)

ms_out = repeat_elements_grad(a, 1, 4)
np_out = a.repeat(1, 4)
np.testing.assert_array_equal(np_out, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_5d_rep_many():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(224).reshape(1, 7, 4, 4, 2)

y = repeat_elements(a, 7, 0)
ms_out = repeat_elements_grad(y, 7, 0)
np.testing.assert_array_equal(a*7, ms_out)

y = repeat_elements(a, 7, 1)
ms_out = repeat_elements_grad(y, 7, 1)
np.testing.assert_array_equal(a*7, ms_out)

y = repeat_elements(a, 7, 2)
ms_out = repeat_elements_grad(y, 7, 2)
np.testing.assert_array_equal(a*7, ms_out)

y = repeat_elements(a, 7, 3)
ms_out = repeat_elements_grad(y, 7, 3)
np.testing.assert_array_equal(a*7, ms_out)

y = repeat_elements(a, 7, 4)
ms_out = repeat_elements_grad(y, 7, 4)
np.testing.assert_array_equal(a*7, ms_out)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_repeat_elements_grad_half():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 1, 1, 4, 3)

y = repeat_elements(a, 4, 0)
ms_out = repeat_elements_grad(y, 4, 0)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 1)
ms_out = repeat_elements_grad(y, 4, 1)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 2)
ms_out = repeat_elements_grad(y, 4, 2)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 3)
ms_out = repeat_elements_grad(y, 4, 3)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 4)
ms_out = repeat_elements_grad(y, 4, 4)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 5)
ms_out = repeat_elements_grad(y, 4, 5)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 6)
ms_out = repeat_elements_grad(y, 4, 6)
np.testing.assert_array_equal(a*4, ms_out)

y = repeat_elements(a, 4, 7)
ms_out = repeat_elements_grad(y, 4, 7)
np.testing.assert_array_equal(a*4, ms_out)

Loading…
Cancel
Save