From: @zhao_ting_v. Reviewed-by: @wuxuejian, @liangchenghui. Signed-off-by: @liangchenghui. Tags: v1.2.0-rc1
| @@ -76,6 +76,16 @@ void Reciprocal(const T *in, T *out, size_t start, size_t end) { | |||
| out[i] = static_cast<T>(1.0 / in[i]); | |||
| } | |||
| } | |||
// Element-wise GELU using the tanh approximation:
//   gelu(x) = x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) / 2
// Reads in[start, end) and writes the matching positions of out.
template <typename T>
void Gelu(const T *in, T *out, size_t start, size_t end) {
  constexpr double kSqrt2OverPi = 0.7978845608;  // sqrt(2 / pi)
  constexpr double kCubicCoeff = 0.044715;
  for (size_t i = start; i < end; i++) {
    const T x = in[i];
    // The tanh argument is evaluated in double precision. The previous code
    // cast x to T (a no-op) and relied on the double literals for promotion;
    // the explicit cast states the intent and yields the same value for float.
    const double x_d = static_cast<double>(x);
    const T tanh_res = static_cast<T>(std::tanh(kSqrt2OverPi * (x_d + kCubicCoeff * x_d * x_d * x_d)));
    out[i] = x * (static_cast<T>(1.0) + tanh_res) / static_cast<T>(2.0);
  }
}
| } // namespace | |||
| void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| @@ -95,6 +105,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = FLOOR; | |||
| } else if (kernel_name == prim::kPrimReciprocal->name()) { | |||
| operate_type_ = RECIPROCAL; | |||
| } else if (kernel_name == prim::kPrimGelu->name()) { | |||
| operate_type_ = GELU; | |||
| } | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| } | |||
| @@ -150,6 +162,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| threads.emplace_back(std::thread(Floor<T>, input, output, start, end)); | |||
| } else if (operate_type_ == RECIPROCAL) { | |||
| threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end)); | |||
| } else if (operate_type_ == GELU) { | |||
| threads.emplace_back(std::thread(Gelu<T>, input, output, start, end)); | |||
| } | |||
| start += once_compute_size; | |||
| } | |||
| @@ -62,6 +62,8 @@ MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutput | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -89,6 +89,8 @@ enum OperateType { | |||
| GREATER, | |||
| GREATEREQUAL, | |||
| RECIPROCAL, | |||
| GELU, | |||
| GELUGRAD, | |||
| }; | |||
| class CPUKernel : public kernel::KernelMod { | |||
| @@ -78,6 +78,18 @@ void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, si | |||
| } | |||
| } | |||
| template <typename T> | |||
| void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| T x = input2[i]; | |||
| auto double_x = static_cast<T>(x); | |||
| T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x)); | |||
| T mul_right = (T)(0.7978845608 + 0.1070322244 * double_x * double_x); | |||
| T y_res = (((T)1.0 + tanh_res) + x * ((T)1.0 - tanh_res * tanh_res) * mul_right) / (T)2.0; | |||
| out[i] = input1[i] * y_res; | |||
| } | |||
| } | |||
| void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -93,6 +105,8 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = TANHGRAD; | |||
| } else if (kernel_name == "SqrtGrad") { | |||
| operate_type_ = SQRTGRAD; | |||
| } else if (kernel_name == "GeluGrad") { | |||
| operate_type_ = GELUGRAD; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << kernel_name; | |||
| } | |||
| @@ -172,6 +186,8 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c | |||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::TanhGrad<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == SQRTGRAD) { | |||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == GELUGRAD) { | |||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad<T>, this, input1, input2, output, start, end)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| @@ -47,6 +47,8 @@ class EltWiseGradCPUKernel : public CPUKernel { | |||
| void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| std::vector<size_t> input_shape0_; | |||
| std::vector<size_t> input_shape1_; | |||
| std::vector<size_t> input_element_num0_; | |||
| @@ -81,6 +83,13 @@ MS_REG_CPU_KERNEL( | |||
| TanhGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| EltWiseGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(GeluGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| EltWiseGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,105 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void LayerNormCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_norm_axis"); | |||
| auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_params_axis"); | |||
| if (begin_norm_axis < 0) { | |||
| begin_norm_axis += x_shape.size(); | |||
| } | |||
| if (begin_params_axis < 0) { | |||
| begin_params_axis += x_shape.size(); | |||
| } | |||
| for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { | |||
| block_num_ *= x_shape[i]; | |||
| } | |||
| for (size_t i = IntToSize(begin_norm_axis); i < x_shape.size(); i++) { | |||
| block_size_ *= x_shape[i]; | |||
| } | |||
| for (size_t i = IntToSize(begin_params_axis); i < x_shape.size(); i++) { | |||
| param_num_ *= x_shape[i]; | |||
| } | |||
| if (block_num_ <= 0 || block_size_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "LayerNormCPUKernel input shape error, input shape: " << x_shape; | |||
| } | |||
| } | |||
| bool LayerNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64"; | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| size_t f_size = sizeof(T); | |||
| if (inputs[1]->size != f_size * param_num_ || inputs[2]->size != f_size * param_num_) { | |||
| MS_LOG(EXCEPTION) << "The product of gamma and beta's shape must be " << param_num_; | |||
| } | |||
| if (outputs[1]->size != f_size * block_num_ || outputs[2]->size != f_size * block_num_) { | |||
| MS_LOG(EXCEPTION) << "The product of mean and var's shape must be " << block_num_; | |||
| } | |||
| auto x = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto gamma = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto beta = reinterpret_cast<T *>(inputs[2]->addr); | |||
| auto y = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto mean = reinterpret_cast<T *>(outputs[1]->addr); | |||
| auto var = reinterpret_cast<T *>(outputs[2]->addr); | |||
| for (size_t i = 0; i < block_num_; ++i) { | |||
| T sum = (T)0.0; | |||
| T square_sum = (T)0.0; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| sum += x[j]; | |||
| square_sum += x[j] * x[j]; | |||
| } | |||
| T block_mean = sum / block_size_; | |||
| T block_var = square_sum / block_size_ - block_mean * block_mean; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast<double>(block_var) + eps_) * gamma[param_shift] + | |||
| beta[param_shift]; | |||
| } | |||
| mean[i] = block_mean; | |||
| var[i] = block_var; | |||
| } | |||
| } | |||
| void LayerNormCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(EXCEPTION) << "LayerNormCPUKernel needs 3 inputs, but gets " << input_num; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 3) { | |||
| MS_LOG(EXCEPTION) << "LayerNormCPUKernel expects 3 output, but gets" << output_num; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class LayerNormCPUKernel : public CPUKernel { | |||
| public: | |||
| LayerNormCPUKernel() = default; | |||
| ~LayerNormCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| TypeId dtype_{kTypeUnknown}; | |||
| float eps_{1e-12}; | |||
| size_t block_num_{1}; | |||
| size_t block_size_{1}; | |||
| size_t param_num_{1}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(LayerNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| LayerNormCPUKernel); | |||
| MS_REG_CPU_KERNEL(LayerNorm, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| LayerNormCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,124 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void LayerNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_norm_axis"); | |||
| auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_params_axis"); | |||
| if (begin_norm_axis < 0) { | |||
| begin_norm_axis += x_shape.size(); | |||
| } | |||
| if (begin_params_axis < 0) { | |||
| begin_params_axis += x_shape.size(); | |||
| } | |||
| for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { | |||
| block_num_ *= x_shape[i]; | |||
| } | |||
| for (size_t i = IntToSize(begin_norm_axis); i < x_shape.size(); i++) { | |||
| block_size_ *= x_shape[i]; | |||
| } | |||
| for (size_t i = 0; i < IntToSize(begin_params_axis); i++) { | |||
| param_size_ *= x_shape[i]; | |||
| } | |||
| for (size_t i = begin_params_axis; i < x_shape.size(); i++) { | |||
| param_num_ *= x_shape[i]; | |||
| } | |||
| if (block_num_ <= 0 || block_size_ <= 0) { | |||
| MS_LOG(EXCEPTION) << "LayerNormGradCPUKernel input shape error, input shape: " << x_shape; | |||
| } | |||
| } | |||
| bool LayerNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &workspace, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, workspace, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat64) { | |||
| LaunchKernel<float>(inputs, workspace, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "input dtype only support float16, float32, float64"; | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| auto x = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto dy = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto var = reinterpret_cast<T *>(inputs[2]->addr); | |||
| auto mean = reinterpret_cast<T *>(inputs[3]->addr); | |||
| auto gamma = reinterpret_cast<T *>(inputs[4]->addr); | |||
| auto dx = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto dg = reinterpret_cast<T *>(outputs[1]->addr); | |||
| auto db = reinterpret_cast<T *>(outputs[2]->addr); | |||
| for (size_t i = 0; i < param_num_; ++i) { | |||
| T dgamma = (T)0.0; | |||
| T dbeta = (T)0.0; | |||
| for (size_t j = i; j < param_size_ * param_num_; j += param_num_) { | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]); | |||
| dbeta += dy[j]; | |||
| } | |||
| dg[i] = dgamma; | |||
| db[i] = dbeta; | |||
| } | |||
| for (size_t i = 0; i < block_num_; ++i) { | |||
| T sum1 = (T)0.0; | |||
| T sum2 = (T)0.0; | |||
| T sum3 = (T)0.0; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| auto dxm = x[j] - mean[norm_shift]; | |||
| auto dyg = dy[j] * gamma[param_shift]; | |||
| sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -1.5); | |||
| sum2 += dyg; | |||
| sum3 += (T)(-2.0) * dxm; | |||
| } | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5); | |||
| auto dx1 = dy[j] * gamma[param_shift] * var_sqrt; | |||
| auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]); | |||
| auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_); | |||
| dx[j] = dx1 + dx2 + dx3; | |||
| } | |||
| } | |||
| } | |||
| void LayerNormGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 5) { | |||
| MS_LOG(EXCEPTION) << "LayerNormGradCPUKernel needs 5 inputs, but gets " << input_num; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 3) { | |||
| MS_LOG(EXCEPTION) << "LayerNormGradCPUKernel expects 3 output, but gets" << output_num; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class LayerNormGradCPUKernel : public CPUKernel { | |||
| public: | |||
| LayerNormGradCPUKernel() = default; | |||
| ~LayerNormGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| TypeId dtype_{kTypeUnknown}; | |||
| float eps_{1e-12}; | |||
| size_t block_num_{1}; | |||
| size_t block_size_{1}; | |||
| size_t param_num_{1}; | |||
| size_t param_size_{1}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(LayerNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| LayerNormGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(LayerNormGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| LayerNormGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LAYER_NORM_GRAD_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class GeluNet(nn.Cell):
    """Cell that applies the ``P.Gelu`` primitive to its single input."""

    def __init__(self):
        super().__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        out = self.gelu(x)
        return out
class Grad(nn.Cell):
    """Wraps ``network`` and returns its gradients for a given input and sensitivity.

    The gradient operation is built with ``get_all=True`` and ``sens_param=True``.
    """

    def __init__(self, network):
        super().__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_data, sens):
        grad_fn = self.grad(self.network)
        return grad_fn(input_data, sens)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelugrad():
    """Gelu gradient on a fixed 10-element input matches precomputed values."""
    x_values = [0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736,
                0.9607501, 0.14567593, 0.12261796, 0.37054458, 0.46421242]
    dy_values = [0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393,
                 0.03287048, 0.55681044, 0.966908, 0.06015943, 0.6099489]
    expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032,
              0.0352667, 0.34266686, 0.57757664, 0.04707306, 0.51536125]

    x_ms = Tensor(np.array(x_values).astype(np.float32))
    dy_ms = Tensor(np.array(dy_values).astype(np.float32))
    grad_net = Grad(GeluNet())
    output = grad_net(x_ms, dy_ms)
    assert np.allclose(output[0].asnumpy(), expect)
| @@ -0,0 +1,93 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class GeluNet(nn.Cell):
    """Cell that applies the ``P.Gelu`` primitive to its single input."""

    def __init__(self):
        super().__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        out = self.gelu(x)
        return out
def GeluCompute(x):
    """NumPy reference for GELU using the tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) element-wise.
    """
    scale = np.sqrt(2 / np.pi)
    cubic = x + 0.044715 * x * x * x
    gate = 1.0 + np.tanh(scale * cubic)
    return 0.5 * x * gate
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_1d():
    """Gelu on a random float32 vector of 50 elements matches the NumPy reference."""
    data = np.random.random((50,)).astype(np.float32)
    expected = GeluCompute(data)

    actual = GeluNet()(Tensor(data))

    assert np.allclose(expected, actual.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_2d():
    """Gelu on a random float32 (50, 40) matrix matches the NumPy reference."""
    data = np.random.random((50, 40)).astype(np.float32)
    expected = GeluCompute(data)

    actual = GeluNet()(Tensor(data))

    assert np.allclose(expected, actual.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_4d():
    """Gelu on a random float32 (32, 3, 224, 224) tensor matches the NumPy reference."""
    data = np.random.random((32, 3, 224, 224)).astype(np.float32)
    expected = GeluCompute(data)

    actual = GeluNet()(Tensor(data))

    assert np.allclose(expected, actual.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gelu_neg():
    """Gelu on random non-positive float32 inputs matches the NumPy reference."""
    data = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1
    expected = GeluCompute(data)

    actual = GeluNet()(Tensor(data))

    assert np.allclose(expected, actual.asnumpy())
| @@ -0,0 +1,221 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class LayerNormGradNet(nn.Cell):
    """Cell that forwards its five inputs, in order, to ``G.LayerNormGrad``.

    NOTE(review): the tests in this file call the net as
    ``net(x, dy, var, mean, gamma)``, so the first two parameter names of
    ``construct`` are swapped relative to the values they receive; the
    arguments are passed through positionally, so behaviour is unaffected.
    """

    def __init__(self, begin_norm_axis, begin_params_axis):
        super().__init__()
        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, dy, x, var, mean, gamma):
        grads = self.norm(dy, x, var, mean, gamma)
        return grads
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    """NumPy reference for the layer-norm backward pass.

    Negative axes are counted from the end of ``x.shape``. Mean and variance
    are taken over the axes from ``begin_norm_axis`` onward; ``dg`` and ``db``
    are reduced over the axes before ``begin_params_axis``.

    Returns:
        Tuple ``(dx, dg, db, mean, var)``; all reductions keep their dims.
    """
    rank = len(x.shape)
    if begin_norm_axis < 0:
        begin_norm_axis += rank
    if begin_params_axis < 0:
        begin_params_axis += rank
    norm_axes = tuple(range(begin_norm_axis, rank))
    param_axes = tuple(range(0, begin_params_axis))

    # Number of elements in one normalization block.
    num = 1
    for dim in x.shape[begin_norm_axis:]:
        num *= dim

    mean = np.mean(x, axis=norm_axes, keepdims=True)
    var = np.var(x, axis=norm_axes, keepdims=True)
    # Give gamma leading singleton dims so it broadcasts against x.
    gamma = gamma.reshape((1,) * begin_params_axis + tuple(x.shape[begin_params_axis:]))

    centered = x - mean
    inv_std = np.power(var + epsilon, -0.5)

    dg = np.sum(dy * inv_std * centered, axis=param_axes, keepdims=True)
    db = np.sum(dy, axis=param_axes, keepdims=True)

    sum1 = np.sum((-0.5) * dy * gamma * centered * np.power(var + epsilon, -1.5), axis=norm_axes,
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=norm_axes, keepdims=True)
    sum3 = np.sum(-2.0 * centered, axis=norm_axes, keepdims=True)

    dx1 = dy * gamma * inv_std
    dx2 = sum1 * 2.0 / num * centered
    dx3 = ((-1.0) * inv_std * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    return dx1 + dx2 + dx3, dg, db, mean, var
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad0():
    """LayerNormGrad on a (4096, 3072) input, axes 1/1, matches the NumPy reference."""
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(4096, 3072).astype(np.float32)
    dy_np = np.random.randn(4096, 3072).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    # Same epsilon as the CPU kernel (eps_ = 1e-12); the previous 10e-12 was 1e-11.
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)

    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad1():
    """LayerNormGrad on a (640, 768) input, axes 1/1, matches the NumPy reference."""
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(640, 768).astype(np.float32)
    dy_np = np.random.randn(640, 768).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    # Same epsilon as the CPU kernel (eps_ = 1e-12); the previous 10e-12 was 1e-11.
    epsilon = 1e-12
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)

    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)

    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-4, atol=1e-4)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-4, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-4, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad2():
    """Compare LayerNormGradNet with LayerNormGradReference on (32, 128, 768) float32 data.

    begin_norm_axis and begin_params_axis are both -1 (the last axis).
    """
    norm_axis, params_axis = -1, -1
    shape = (32, 128, 768)
    x = np.random.randn(*shape).astype(np.float32)
    dy = np.random.randn(*shape).astype(np.float32)
    gamma = np.random.randn(*shape[params_axis:]).astype(np.float32)
    eps = 10e-12  # note: this literal equals 1e-11
    want_dx, want_dg, want_db, mean, var = LayerNormGradReference(x, dy, gamma, eps, norm_axis, params_axis)

    # The network takes (x, dy, var, mean, gamma); mean/var come from the reference.
    grad_net = LayerNormGradNet(norm_axis, params_axis)
    got_dx, got_dg, got_db = grad_net(Tensor(x), Tensor(dy), Tensor(var), Tensor(mean), Tensor(gamma))

    assert np.allclose(got_dx.asnumpy(), want_dx, rtol=1e-4, atol=1e-4)
    assert np.allclose(got_dg.asnumpy(), want_dg, rtol=1e-4, atol=1e-3)
    assert np.allclose(got_db.asnumpy(), want_db, rtol=1e-4, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad3():
    """Compare LayerNormGradNet with LayerNormGradReference on (32, 64) float32 data.

    begin_norm_axis and begin_params_axis are both -1 (the last axis).
    """
    norm_axis, params_axis = -1, -1
    shape = (32, 64)
    x = np.random.randn(*shape).astype(np.float32)
    dy = np.random.randn(*shape).astype(np.float32)
    gamma = np.random.randn(*shape[params_axis:]).astype(np.float32)
    eps = 10e-12  # note: this literal equals 1e-11
    want_dx, want_dg, want_db, mean, var = LayerNormGradReference(x, dy, gamma, eps, norm_axis, params_axis)

    # The network takes (x, dy, var, mean, gamma); mean/var come from the reference.
    grad_net = LayerNormGradNet(norm_axis, params_axis)
    got_dx, got_dg, got_db = grad_net(Tensor(x), Tensor(dy), Tensor(var), Tensor(mean), Tensor(gamma))

    assert np.allclose(got_dx.asnumpy(), want_dx, rtol=1e-4, atol=1e-4)
    assert np.allclose(got_dg.asnumpy(), want_dg, rtol=1e-4, atol=1e-3)
    assert np.allclose(got_db.asnumpy(), want_db, rtol=1e-4, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad4():
    """Compare LayerNormGradNet with LayerNormGradReference on (32, 64) float32 data.

    begin_norm_axis and begin_params_axis are both -1 (the last axis).

    NOTE(review): the shapes, axes and tolerances here are the same as in
    test_layernormgrad3; confirm whether a different configuration was intended.
    """
    norm_axis, params_axis = -1, -1
    shape = (32, 64)
    x = np.random.randn(*shape).astype(np.float32)
    dy = np.random.randn(*shape).astype(np.float32)
    gamma = np.random.randn(*shape[params_axis:]).astype(np.float32)
    eps = 10e-12  # note: this literal equals 1e-11
    want_dx, want_dg, want_db, mean, var = LayerNormGradReference(x, dy, gamma, eps, norm_axis, params_axis)

    # The network takes (x, dy, var, mean, gamma); mean/var come from the reference.
    grad_net = LayerNormGradNet(norm_axis, params_axis)
    got_dx, got_dg, got_db = grad_net(Tensor(x), Tensor(dy), Tensor(var), Tensor(mean), Tensor(gamma))

    assert np.allclose(got_dx.asnumpy(), want_dx, rtol=1e-4, atol=1e-4)
    assert np.allclose(got_dg.asnumpy(), want_dg, rtol=1e-4, atol=1e-3)
    assert np.allclose(got_db.asnumpy(), want_db, rtol=1e-4, atol=1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernormgrad5():
    """Compare LayerNormGradNet with LayerNormGradReference on (128, 2, 16, 32) float32 data.

    begin_norm_axis is 2 while begin_params_axis is 1, so gamma has shape (2, 16, 32).
    """
    norm_axis, params_axis = 2, 1
    shape = (128, 2, 16, 32)
    x = np.random.randn(*shape).astype(np.float32)
    dy = np.random.randn(*shape).astype(np.float32)
    gamma = np.random.randn(*shape[params_axis:]).astype(np.float32)
    eps = 10e-12  # note: this literal equals 1e-11
    want_dx, want_dg, want_db, mean, var = LayerNormGradReference(x, dy, gamma, eps, norm_axis, params_axis)

    # The network takes (x, dy, var, mean, gamma); mean/var come from the reference.
    grad_net = LayerNormGradNet(norm_axis, params_axis)
    got_dx, got_dg, got_db = grad_net(Tensor(x), Tensor(dy), Tensor(var), Tensor(mean), Tensor(gamma))

    # db is checked before dg here, matching the original assertion order.
    assert np.allclose(got_dx.asnumpy(), want_dx, rtol=1e-4, atol=1e-4)
    assert np.allclose(got_db.asnumpy(), want_db, rtol=1e-4, atol=1e-3)
    assert np.allclose(got_dg.asnumpy(), want_dg, rtol=1e-4, atol=1e-3)
| @@ -0,0 +1,199 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
# Module-level setup: select graph mode and the CPU device target at import time.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="CPU",
)
class LayerNormNet(nn.Cell):
    """Thin nn.Cell wrapper around a single P.LayerNorm primitive."""

    def __init__(self, begin_norm_axis, begin_params_axis):
        """Build the P.LayerNorm op from the two axis arguments, passed positionally."""
        super().__init__()
        self.norm = P.LayerNorm(begin_norm_axis, begin_params_axis)

    def construct(self, x, gamma, beta):
        """Apply self.norm to (x, gamma, beta) and return its result unchanged."""
        result = self.norm(x, gamma, beta)
        return result
def LayerNormReference(begin_norm_axis, begin_params_axis, x, gamma, beta, epsilon=1e-12):
    """NumPy reference implementation of layer normalization.

    Args:
        begin_norm_axis: First axis of `x` to normalize over; every axis from it
            to the last one is reduced. Negative values count from the end.
        begin_params_axis: First axis of `x` that `gamma` and `beta` span.
            Negative values count from the end.
        x: Input array.
        gamma: Scale; reshaped to `x.shape[begin_params_axis:]` with leading
            size-1 axes so it broadcasts against `x`.
        beta: Offset; reshaped the same way as `gamma`.
        epsilon: Value added to the variance before the square root. Defaults
            to 1e-12, the constant that used to be hard-coded here.

    Returns:
        Tuple `(y, mean, var)`. `mean` and `var` keep the reduced axes with
        size 1 (`keepdims=True`).
    """
    rank = len(x.shape)
    if begin_norm_axis < 0:
        begin_norm_axis += rank
    if begin_params_axis < 0:
        begin_params_axis += rank

    reduce_axes = tuple(range(begin_norm_axis, rank))
    mean = np.mean(x, axis=reduce_axes, keepdims=True)
    var = np.var(x, axis=reduce_axes, keepdims=True)

    # Pad the parameter shape with leading 1s so gamma/beta broadcast over x.
    broadcast_shape = (1,) * begin_params_axis + tuple(x.shape[begin_params_axis:])
    gamma = gamma.reshape(broadcast_shape)
    beta = beta.reshape(broadcast_shape)

    y = np.subtract(x, mean) / np.sqrt(var + epsilon) * gamma + beta
    return y, mean, var
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm0():
    """Check LayerNormNet against LayerNormReference on (4096, 3072) float32 data.

    begin_norm_axis and begin_params_axis are both 1.
    """
    norm_axis = params_axis = 1
    # Fixed seed (same value test_layernorm2d_4 uses) so that a tolerance
    # failure can be reproduced with identical input data.
    np.random.seed(42)
    data = np.random.randn(4096, 3072).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    # No explicit rtol here, so np.allclose uses its default relative tolerance.
    assert np.allclose(got_y.asnumpy(), want_y, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm1():
    """Check LayerNormNet against LayerNormReference on (640, 768) float32 data.

    begin_norm_axis and begin_params_axis are both 1.
    """
    norm_axis = params_axis = 1
    # Fixed seed (same value test_layernorm2d_4 uses) so that a tolerance
    # failure can be reproduced with identical input data.
    np.random.seed(42)
    data = np.random.randn(640, 768).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    assert np.allclose(got_y.asnumpy(), want_y, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm3d_1():
    """Check LayerNormNet against LayerNormReference on (32, 128, 768) float32 data.

    begin_norm_axis and begin_params_axis are both -1 (the last axis).
    """
    norm_axis = params_axis = -1
    # Fixed seed (same value test_layernorm2d_4 uses) so that a tolerance
    # failure can be reproduced with identical input data.
    np.random.seed(42)
    data = np.random.randn(32, 128, 768).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    assert np.allclose(got_y.asnumpy(), want_y, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm3d_2():
    """Check LayerNormNet against LayerNormReference on (32, 128, 768) float32 data.

    begin_norm_axis is -1 while begin_params_axis is 1, so gamma and beta have
    shape (128, 768).
    """
    norm_axis, params_axis = -1, 1
    # Fixed seed (same value test_layernorm2d_4 uses) so that a tolerance
    # failure can be reproduced with identical input data.
    np.random.seed(42)
    data = np.random.randn(32, 128, 768).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    assert np.allclose(got_y.asnumpy(), want_y, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm2d_2():
    """Check LayerNormNet against LayerNormReference on (64, 32) float32 data.

    begin_norm_axis is -1 and begin_params_axis is 1; for a 2-D input both
    select the last axis.
    """
    norm_axis, params_axis = -1, 1
    # Fixed seed (same value test_layernorm2d_4 uses) so that a tolerance
    # failure can be reproduced with identical input data.
    np.random.seed(42)
    data = np.random.randn(64, 32).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    assert np.allclose(got_y.asnumpy(), want_y, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm2d_3():
    """Check LayerNormNet against LayerNormReference on (128, 128) float32 data.

    begin_norm_axis is -1 and begin_params_axis is 1; for a 2-D input both
    select the last axis.
    """
    norm_axis, params_axis = -1, 1
    # Fixed seed (same value test_layernorm2d_4 uses) so that a tolerance
    # failure can be reproduced with identical input data.
    np.random.seed(42)
    data = np.random.randn(128, 128).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    assert np.allclose(got_y.asnumpy(), want_y, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, rtol=1e-6, atol=1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_layernorm2d_4():
    """Check LayerNormNet against LayerNormReference on (128, 2, 16, 32) float32 data.

    Despite the "2d" in the name, the input here is 4-D. begin_norm_axis is 2
    and begin_params_axis is 1, so gamma and beta have shape (2, 16, 32).
    """
    norm_axis, params_axis = 2, 1
    np.random.seed(42)  # seeded before any data is drawn
    data = np.random.randn(128, 2, 16, 32).astype(np.float32)
    param_shape = data.shape[params_axis:]
    scale = np.random.randn(*param_shape).astype(np.float32)
    shift = np.random.randn(*param_shape).astype(np.float32)
    want_y, want_mean, want_var = LayerNormReference(norm_axis, params_axis, data, scale, shift)

    layer_norm = LayerNormNet(norm_axis, params_axis)
    got_y, got_mean, got_var = layer_norm(Tensor(data), Tensor(scale), Tensor(shift))

    assert np.allclose(got_y.asnumpy(), want_y, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_mean.asnumpy(), want_mean, rtol=1e-6, atol=1e-4)
    assert np.allclose(got_var.asnumpy(), want_var, rtol=1e-6, atol=1e-4)