diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc
new file mode 100644
index 0000000000..6c187b96a0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.cc
@@ -0,0 +1,180 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+// Validates the node's I/O, then caches input shapes, output-shape strides, the non-negative axis and epsilon.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  CheckIONumber(kernel_node);
+  for (size_t i = 0; i < INPUT_SIZE; i++) {
+    input_shape_list_.emplace_back(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i));
+  }
+  auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  CheckInputShape(output_shape);
+
+  int output_dim_length = output_shape.size();
+  dim_elem_num_list_.resize(output_dim_length, 1);
+  for (int i = output_dim_length - 2; i >= 0; i--) {
+    dim_elem_num_list_[i] = output_shape[i + 1] * dim_elem_num_list_[i + 1];
+  }
+
+  int axis = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
+  int input_dim_length = SizeToInt(input_shape_list_[0].size());
+  axis_ = axis < 0 ? (axis + input_dim_length) : axis;
+  epsilon_ = static_cast<T>(AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon"));
+}
+
+// Fills outputs[0] element by element through CPUKernelUtils::ParallelFor; inputs are (input_x, y, dout).
+template <typename T>
+bool L2NormalizeGradCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs,
+                                         const std::vector<AddressPtr> &workspace,
+                                         const std::vector<AddressPtr> &outputs) {
+  auto input_x = reinterpret_cast<T *>(inputs[0]->addr);
+  auto y = reinterpret_cast<T *>(inputs[1]->addr);
+  auto dout = reinterpret_cast<T *>(inputs[2]->addr);
+  auto output = reinterpret_cast<T *>(outputs[0]->addr);
+  auto output_size = outputs[0]->size / sizeof(T);
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      std::vector<size_t> high_dim_index;
+      OneDimIndexToHighDimIndex(i, &high_dim_index);
+      std::vector<T> input_x_vector;
+      GetVector(&input_x_vector, high_dim_index, input_x);
+      std::vector<T> dout_vector;
+      GetVector(&dout_vector, high_dim_index, dout);
+      std::vector<T> y_vector;
+      GetVector(&y_vector, high_dim_index, y);
+      GetOutput(input_x_vector, y_vector, dout_vector, high_dim_index, &output[i]);
+    }
+  };
+  CPUKernelUtils::ParallelFor(task, output_size);
+  return true;
+}
+
+// Raises if any input shape differs from output_shape or if input_x has a zero-sized dimension.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::CheckInputShape(const std::vector<size_t> &output_shape) {
+  for (const auto &shape : input_shape_list_) {
+    if (output_shape != shape) {
+      MS_LOG(EXCEPTION) << "Input shape and output shape should be same.";
+    }
+  }
+  auto input_x_shape = input_shape_list_[0];
+  if (input_x_shape.size() != 0) {
+    if (std::any_of(input_x_shape.begin(), input_x_shape.end(), [](size_t i) { return i == 0; })) {
+      MS_LOG(EXCEPTION) << "L2NormalizeGradCPUKernel input is null.";
+    }
+  }
+}
+
+// Raises unless the node has exactly INPUT_SIZE inputs and OUTPUT_SIZE outputs.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::CheckIONumber(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but L2NormalizeGradCPUKernel needs 3 input.";
+  }
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != OUTPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but L2NormalizeGradCPUKernel needs 1 output.";
+  }
+}
+
+// Appends the per-dimension coordinates of flat offset one_dim_index to high_dim_index.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::OneDimIndexToHighDimIndex(size_t one_dim_index, std::vector<size_t> *high_dim_index) {
+  for (const auto &item : dim_elem_num_list_) {
+    high_dim_index->push_back(one_dim_index / item);
+    one_dim_index %= item;
+  }
+}
+
+// Writes the flat offset of high_dim_index to one_dim_index, using dim_elem_num_list_ as strides.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::HighDimIndexToOneDimIndex(size_t *one_dim_index,
+                                                            const std::vector<size_t> &high_dim_index) {
+  *one_dim_index = 0;
+  int len = high_dim_index.size();
+  for (int i = 0; i < len; i++) {
+    *one_dim_index += high_dim_index[i] * dim_elem_num_list_[i];
+  }
+}
+
+// Appends to x_vector the elements of x along axis_ that share the other coordinates of high_dim_index.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::GetVector(std::vector<T> *x_vector, const std::vector<size_t> &high_dim_index,
+                                            const T *x) {
+  auto x_shape = input_shape_list_[0];
+  for (size_t i = 0; i < x_shape[axis_]; i++) {
+    size_t oneDimIndex = 0;
+    std::vector<size_t> tmp_high_dim_index = high_dim_index;
+    tmp_high_dim_index[axis_] = i;
+    HighDimIndexToOneDimIndex(&oneDimIndex, tmp_high_dim_index);
+    x_vector->push_back(x[oneDimIndex]);
+  }
+}
+
+// Stores sum(x_vector[i] * y_vector[i]) in ss, accumulated by repeatedly folding the upper half onto the lower.
+// Both vectors must have the same, non-zero length.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::GetSumOfProduct(const std::vector<T> &x_vector, const std::vector<T> &y_vector,
+                                                  T *ss) {
+  size_t len = x_vector.size();
+  std::vector<T> tmp_vector(len);
+  for (size_t i = 0; i < len; i++) {
+    tmp_vector[i] = x_vector[i] * y_vector[i];
+  }
+  // An odd tail is folded into slot 0; skip len == 1, where the tail is slot 0 itself and would be doubled.
+  if (len > 1 && len % 2 == 1) {
+    tmp_vector[0] += tmp_vector[len - 1];
+  }
+  for (size_t stride = len / 2; stride > 0; stride >>= 1) {
+    for (size_t i = 0; i < stride; i++) {
+      tmp_vector[i] += tmp_vector[i + stride];
+    }
+    if (stride > 2 && stride % 2 == 1) {
+      tmp_vector[0] += tmp_vector[stride - 1];
+    }
+  }
+  *ss = tmp_vector[0];
+}
+
+// Writes (dout - y * sum(y * dout)) / max(sqrt(sum(x * x)), epsilon_) for the element at high_dim_index to output.
+template <typename T>
+void L2NormalizeGradCPUKernel<T>::GetOutput(const std::vector<T> &input_x_vector, const std::vector<T> &y_vector,
+                                            const std::vector<T> &dout_vector,
+                                            const std::vector<size_t> &high_dim_index, T *output) {
+  size_t axis_index = high_dim_index[axis_];
+  T dout = dout_vector[axis_index];
+  T y = y_vector[axis_index];
+  T tmp_sum1;
+  GetSumOfProduct(y_vector, dout_vector, &tmp_sum1);
+  T tmp_sum2;
+  GetSumOfProduct(input_x_vector, input_x_vector, &tmp_sum2);
+  tmp_sum2 = sqrt(tmp_sum2);
+  if (tmp_sum2 >= epsilon_) {
+    *output = (dout - y * tmp_sum1) / tmp_sum2;
+  } else {
+    *output = (dout - y * tmp_sum1) / epsilon_;
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.h
new file mode 100644
index 0000000000..982f1e6789
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/l2normalize_grad_cpu_kernel.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2NORMALIZE_GRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2NORMALIZE_GRAD_CPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+constexpr size_t INPUT_SIZE = 3;
+constexpr size_t OUTPUT_SIZE = 1;
+// CPU kernel for L2NormalizeGrad: takes (input_x, y, dout), all shaped like the single output it produces.
+template <typename T>
+class L2NormalizeGradCPUKernel : public CPUKernel {
+ public:
+  L2NormalizeGradCPUKernel() = default;
+  ~L2NormalizeGradCPUKernel() override = default;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+ private:
+  void CheckInputShape(const std::vector<size_t> &output_shape);
+  void CheckIONumber(const CNodePtr &kernel_node);
+  void OneDimIndexToHighDimIndex(size_t one_dim_index, std::vector<size_t> *high_dim_index);
+  void HighDimIndexToOneDimIndex(size_t *one_dim_index, const std::vector<size_t> &high_dim_index);
+  void GetVector(std::vector<T> *x_vector, const std::vector<size_t> &high_dim_index, const T *x);
+  void GetSumOfProduct(const std::vector<T> &x_vector, const std::vector<T> &y_vector, T *ss);
+  void GetOutput(const std::vector<T> &input_x_vector, const std::vector<T> &y_vector,
+                 const std::vector<T> &dout_vector, const std::vector<size_t> &high_dim_index, T *output);
+  std::vector<std::vector<size_t>> input_shape_list_;  // inferred shape of each input
+  std::vector<size_t> dim_elem_num_list_;              // element stride of each output dimension
+  int axis_{0};                                        // normalization axis, stored as a non-negative index
+  T epsilon_{0};                                       // lower bound applied to the L2-norm divisor
+};
+
+MS_REG_CPU_KERNEL_T(L2NormalizeGrad,
+                    KernelAttr()
+                      .AddInputAttr(kNumberTypeFloat32)
+                      .AddInputAttr(kNumberTypeFloat32)
+                      .AddInputAttr(kNumberTypeFloat32)
+                      .AddOutputAttr(kNumberTypeFloat32),
+                    L2NormalizeGradCPUKernel, float);
+
+MS_REG_CPU_KERNEL_T(L2NormalizeGrad,
+                    KernelAttr()
+                      .AddInputAttr(kNumberTypeFloat16)
+                      .AddInputAttr(kNumberTypeFloat16)
+                      .AddInputAttr(kNumberTypeFloat16)
+                      .AddOutputAttr(kNumberTypeFloat16),
+                    L2NormalizeGradCPUKernel, float16);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_L2NORMALIZE_GRAD_CPU_KERNEL_H_
diff --git a/tests/st/ops/cpu/test_l2normalize_grad_op.py b/tests/st/ops/cpu/test_l2normalize_grad_op.py
new file mode 100644
index 0000000000..70aba3b5c9
--- /dev/null
+++ b/tests/st/ops/cpu/test_l2normalize_grad_op.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.common.dtype as mstype
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class Net(nn.Cell):
+    def __init__(self, axis=0, epsilon=1e-4):
+        super(Net, self).__init__()
+        self.ops = G.L2NormalizeGrad(axis, epsilon)
+
+    def construct(self, input_x, output, dout):
+        return self.ops(input_x, output, dout)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu_training
+@pytest.mark.env_onecard
+def test_net01():
+    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+    axis = 1
+    net = Net(axis)
+    input_x = np.arange(24).astype(np.float32).reshape((2, 3, 4))
+    dout = np.arange(24, 48).astype(np.float32).reshape((2, 3, 4))
+    output = input_x / np.sqrt(np.sum(input_x**2, axis=axis, keepdims=True))
+    except_asn = (dout - output * np.sum(output * dout, axis=axis, keepdims=True)
+                  ) / np.sqrt(np.sum(input_x**2, axis=axis, keepdims=True))
+    input_x = Tensor(input_x, mstype.float32)
+    output = Tensor(output, mstype.float32)
+    dout = Tensor(dout, mstype.float32)
+    net_output = net(input_x, output, dout).asnumpy()
+    precision = np.ones(shape=(2, 3, 4), dtype=np.float32) * 1.0e-5
+    absolute_error = np.abs(except_asn-net_output)
+    assert np.all(absolute_error < precision)