@@ -17,33 +17,118 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
    """LayerNorm expander"""

    def to_frac_z_axis(self, ori_shape, ori_axis):
| """ | |||
| judge the format is fractal NZ | |||
| Parameters | |||
| ---------- | |||
| ori_shape: list or tuple | |||
| original shape of input | |||
| ori_axis: list or tuple | |||
| original axis of original shape to operate | |||
| Returns | |||
| ------- | |||
| output: list | |||
| axis of the fractal Nz shape | |||
| """ | |||
        frac_z_axis = list(ori_axis)
        shape_len = len(ori_shape)
        axis_count = len(frac_z_axis)
        axis_negative_1 = shape_len - 1
        axis_negative_2 = shape_len - 2
        for i in range(axis_count):
            axis_index = (frac_z_axis[i] + shape_len) % shape_len
            if axis_index == axis_negative_1:
                if frac_z_axis[i] > shape_len - 2:  # akg:[2,3] [1,4] tbe:[2,4] [1,3]
                    frac_z_axis[i] = axis_index - 1
                    frac_z_axis.append(axis_index + 2)
                else:  # no case covers this branch yet
                    frac_z_axis[i] = axis_index - 1
                    frac_z_axis.append(axis_index + 2)
            elif axis_index == axis_negative_2:
                frac_z_axis[i] = axis_index + 1
                frac_z_axis.append(axis_index + 2)
            else:
                frac_z_axis[i] = axis_index
        return frac_z_axis
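The mapping above can be illustrated with a small standalone sketch (not part of the patch; the helper name frac_nz_axes and the concrete shapes are made up for illustration, and FRAC_NZ is taken to be the (..., n1, m1, m0, n0) layout that infer_shape_from_fractalNz below assumes):

def frac_nz_axes(ori_shape, ori_axis):
    """Sketch of the default-format -> FRAC_NZ axis mapping performed by to_frac_z_axis."""
    rank = len(ori_shape)
    axes, extra = list(ori_axis), []
    for i, a in enumerate(axes):
        a = (a + rank) % rank
        if a == rank - 1:        # last axis (n) maps to the n1 and n0 axes
            axes[i], extra = a - 1, extra + [a + 2]
        elif a == rank - 2:      # second-to-last axis (m) maps to the m1 and m0 axes
            axes[i], extra = a + 1, extra + [a + 2]
        else:                    # leading batch axes keep their position
            axes[i] = a
    return axes + extra

assert frac_nz_axes([4096, 1024], (-1,)) == [0, 3]     # reduce n -> reduce (n1, n0)
assert frac_nz_axes([8, 128, 1024], (2,)) == [1, 4]    # the "akg: [1,4]" case noted in the comment above
assert frac_nz_axes([8, 128, 1024], (1,)) == [2, 3]    # the "akg: [2,3]" case noted in the comment above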
    def infer_shape_from_fractalNz(self, fractal):
        """get the original shape from a FRAC_NZ shape"""
        shape = []
        dims = len(fractal)
        batch = dims - 4
        for i in range(batch):
            shape.append(fractal[i])
        m = fractal[dims - 3] * fractal[dims - 2]
        n = fractal[dims - 4] * fractal[dims - 1]
        shape.append(m)
        shape.append(n)
        return shape
    def get_reduced_ori_shape(self, shape, axis):
        """get the reduced shape based on the original shape"""
        reduced_ori_shape = []
        for i, value in enumerate(shape):
            if i in axis:
                reduced_ori_shape.append(1)
            else:
                reduced_ori_shape.append(value)
        return reduced_ori_shape
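With hypothetical shapes for illustration (not part of the patch): a (4096, 1024) matrix stored as FRAC_NZ has the 4-D shape (n1, m1, m0, n0) = (64, 256, 16, 16), and the two helpers above behave as follows:

fractal = [64, 256, 16, 16]            # (n1, m1, m0, n0), 16 x 16 blocks
m = fractal[-3] * fractal[-2]          # 256 * 16 = 4096
n = fractal[-4] * fractal[-1]          # 64 * 16 = 1024
assert [m, n] == [4096, 1024]          # what infer_shape_from_fractalNz recovers

# get_reduced_ori_shape simply writes 1 at every index listed in `axis` and keeps the rest,
# e.g. with axis (1,) the reduced, broadcast-compatible shape is:
assert [1 if i in (1,) else v for i, v in enumerate([m, n])] == [4096, 1]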
    def _expand(self, graph_builder):
        input_x, input_gamma, input_beta = self.inputs
        processor = self.processor
        begin_norm_axis = self.attrs['begin_norm_axis']
        epsilon = self.attrs['epsilon']

        ori_dtype = input_x.dtype
        if processor == 'aicore' and ori_dtype == 'float16':
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
            input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
            input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})

        ori_shape_x = input_x.shape
        if input_x.data_format == DF.FRAC_NZ:
            ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape)

        # Calculate the scaling ratio of the average
        if begin_norm_axis < 0:
            begin_norm_axis += len(input_x.shape)
            begin_norm_axis += len(ori_shape_x)

        reduce_axis = ()
        for i, _ in enumerate(input_x.shape):
        for i, _ in enumerate(ori_shape_x):
            if i > begin_norm_axis or i == begin_norm_axis:
                reduce_axis = reduce_axis + (i,)

        reduce_elts = 1.0
        for i in reduce_axis:
            reduce_elts *= input_x.shape[i]
            reduce_elts *= ori_shape_x[i]

        if input_x.data_format == DF.FRAC_NZ:
            reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis)
            ori_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis)  # shape after reduction

        mean_cof = 1.0 / reduce_elts
        mean_cof_v = graph_builder.value(input_x.dtype, mean_cof)

        # Calculate mean
        mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        mean_red = graph_builder.emit('ReduceSum', [input_x],
                                      attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
        if input_x.data_format == DF.FRAC_NZ:
            mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_shape_x})

        # Calculate variance
        variance_sub = graph_builder.emit('Sub', [input_x, mean])
@@ -51,6 +136,8 @@ class LayerNorm(Expander):
        variance_red = graph_builder.emit('ReduceSum', [variance_mul],
                                          attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
        variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
        if input_x.data_format == DF.FRAC_NZ:
            variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_shape_x})

        # Calculate normalize
        normalize_sub = graph_builder.emit('Sub', [input_x, mean])
@@ -60,7 +147,11 @@ class LayerNorm(Expander):
        normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])

        # Calculate scale and translate
        scale_mul = graph_builder.emit('Mul', [input_gamma, normalize_mul])
        scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
        res = graph_builder.emit('Add', [scale_mul, input_beta])

        if processor == 'aicore' and ori_dtype == 'float16':
            res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
            variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float16'})
        return res, mean, variance
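For reference (illustration only, not part of the patch; the function name is hypothetical), the expanded forward computation can be sketched in NumPy, with rsqrt written as exp(-0.5 * log(var + eps)) the way the original GPU test's NumPy reference does:

import numpy as np

def layernorm_reference(x, gamma, beta, begin_norm_axis, eps):
    """NumPy sketch of the LayerNorm forward computed by the expanded graph (default format)."""
    axes = tuple(range(begin_norm_axis, x.ndim))
    mean = np.mean(x, axis=axes, keepdims=True)
    var = np.mean(np.square(x - mean), axis=axes, keepdims=True)
    rstd = np.exp(np.log(var + eps) * -0.5)      # equals (var + eps) ** -0.5
    res = (x - mean) * rstd * gamma + beta
    return res, mean, var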
@@ -24,23 +24,33 @@ class LayerNormGrad(Expander):
    def _expand(self, graph_builder):
        x, dy, variance, mean, gamma = self.inputs
        processor = self.processor
        begin_norm_axis = self.attrs['begin_norm_axis']
        begin_params_axis = self.attrs['begin_params_axis']
        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-11
        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12

        ori_dtype = x.dtype
        if processor == 'aicore' and ori_dtype == 'float16':
            x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
            dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
            variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
            mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
            gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})

        if begin_norm_axis < 0:
            begin_norm_axis += len(x.shape)
        if begin_params_axis < 0:
            begin_params_axis += len(x.shape)

        norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
        param_axis = tuple(range(0, begin_params_axis))

        reduce_size = 1.0
        for i in norm_axis:
            reduce_size *= x.shape[i]

        # set some constant values
        eps = graph_builder.value(x.dtype, epsilon)
        const_one = graph_builder.value(x.dtype, 1.0)
        const_neg_half = graph_builder.value(x.dtype, -0.5)
        const_neg_two = graph_builder.value(x.dtype, -2.0)
        const_two = graph_builder.value(x.dtype, 2.0)

@@ -49,42 +59,55 @@ class LayerNormGrad(Expander):
        # calculate dg and db
        var_eps = graph_builder.emit('Add', [variance, eps])
        sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
        rsqrt_var_eps = graph_builder.emit('RealDiv', [const_one, sqrt_var_eps])
        var_eps_log = graph_builder.emit('Log', [var_eps])
        var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
        rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
        x_sub_mean = graph_builder.emit('Sub', [x, mean])
        x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
        dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
        dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
        db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})

        # calculate sum_1
        tmp_var_eps = graph_builder.emit('Mul', [sqrt_var_eps, var_eps])
        r_tmp_var_eps = graph_builder.emit('RealDiv', [const_one, tmp_var_eps])
        x_sub_mean_mul_r_tmp_var_eps = graph_builder.emit('Mul', [x_sub_mean, r_tmp_var_eps])

        # pd_var
        tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
        r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
        dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean_mul_r_tmp_var_eps])
        sum_1_mul = graph_builder.emit('Mul', [const_neg_half, tmp_mul])
        sum_1 = graph_builder.emit('ReduceSum', [sum_1_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})

        # calculate sum_2
        sum_2 = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True})

        # calculate sum_3
        sum_3_mul = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
        sum_3 = graph_builder.emit('ReduceSum', [sum_3_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})

        # calculate dx = dx_1 + dx_2 + dx_3
        dx_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
        sum_1_mul_two = graph_builder.emit('Mul', [sum_1, const_two])
        sum_1_mul_two_tmp = graph_builder.emit('Mul', [sum_1_mul_two, mean_cof])
        dx_2 = graph_builder.emit('Mul', [sum_1_mul_two_tmp, x_sub_mean])
        neg_rsqrt_var_eps = graph_builder.emit('Mul', [const_neg_one, rsqrt_var_eps])
        neg_rsqrt_var_eps_mul_sum_2 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, sum_2])
        sum_1_mul_sum_3 = graph_builder.emit('Mul', [sum_1, sum_3])
        mean_cof_mul_sum_1_mul_sum_3 = graph_builder.emit('Mul', [mean_cof, sum_1_mul_sum_3])
        add_tmp = graph_builder.emit('Add', [neg_rsqrt_var_eps_mul_sum_2, mean_cof_mul_sum_1_mul_sum_3])
        dx_3 = graph_builder.emit('Mul', [add_tmp, mean_cof])
        dx_tmp = graph_builder.emit('Add', [dx_1, dx_2])
        dx = graph_builder.emit('Add', [dx_tmp, dx_3])
        tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
        padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
        pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])

        # pd_mean
        pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
        pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
        pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
        pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
                                         attrs={'reduce_axis': norm_axis, 'keep_dims': True})
        pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
        pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
        pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])

        # calculate dx
        pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
        pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
        pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
        pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
        pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
        dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
        dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])

        if processor == 'aicore' and ori_dtype == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
            dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
            db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
        return dx, dg, db
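For reference (illustration only, not part of the patch; the function name is hypothetical, gamma is assumed broadcastable against x, and mean_cof is taken as 1 / reduce_size), the pd_var / pd_mean / pd_x terms emitted above correspond to the closed-form LayerNorm gradient:

import numpy as np

def layernorm_grad_reference(x, dy, gamma, mean, var, norm_axis, param_axis, eps=1e-12):
    """NumPy sketch of the pd_var / pd_mean / pd_x decomposition; norm_axis and param_axis are tuples."""
    n = np.prod([x.shape[i] for i in norm_axis])
    rstd = np.power(var + eps, -0.5)                               # rsqrt_var_eps
    dy_g = dy * gamma                                              # dy_mul_gamma
    pd_var = -0.5 * np.sum(dy_g * (x - mean), axis=norm_axis, keepdims=True) * np.power(var + eps, -1.5)
    pd_mean = (-rstd * np.sum(dy_g, axis=norm_axis, keepdims=True)
               + pd_var * np.sum(-2.0 * (x - mean), axis=norm_axis, keepdims=True) / n)
    dx = dy_g * rstd + pd_var * 2.0 * (x - mean) / n + pd_mean / n
    dg = np.sum(dy * (x - mean) * rstd, axis=param_axis)           # dg_mul reduced over param_axis
    db = np.sum(dy, axis=param_axis)
    return dx, dg, db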
@@ -47,6 +47,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
    prim::kPrimSquare,
    prim::kPrimGeLUGrad,
    prim::kPrimAssignAdd,
    prim::kPrimLayerNorm,
    prim::kPrimLayerNormGrad,
#if ENABLE_D
    prim::kPrimTile,
    prim::kPrimSqrtGrad,
@@ -67,8 +69,6 @@ std::vector<PrimitivePtr> GetExpandOps() {
    prim::kPrimDropout,
    prim::kPrimDropoutGrad,
    prim::kPrimSoftmax,
    prim::kPrimLayerNorm,
    prim::kPrimLayerNormGrad,
    prim::kPrimRelu,
    prim::kPrimReluGrad,
    prim::kPrimSigmoid,
@@ -54,7 +54,6 @@
#include "debug/data_dump/dump_json_parser.h"
#include "debug/tensor_load.h"
#include "debug/anf_ir_utils.h"
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include "backend/optimizer/graph_kernel/graph_kernel_optimization.h"
#include "backend/session/ascend_auto_monad.h"
#include "debug/data_dump/e2e_dump_util.h"
@@ -223,12 +223,12 @@ def test_bert_precision(enable_graph_kernel=False):
    # an assertion error is raised if the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
    assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
    if enable_graph_kernel:
        expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                             12.185522, 12.386192]
    else:
        assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
                             12.407923, 12.631133]
    print("loss value: {}".format(loss_value))
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,150 +13,145 @@
# limitations under the License.
# ============================================================================
import copy
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P


class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.layernorm = P.LayerNorm(1, 1)


class LayerNormNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormNet, self).__init__()
        self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)

    def construct(self, x, y, z):
        return self.layernorm(x, y, z)

    def construct(self, x, gamma, beta):
        return self.layernorm(x, gamma, beta)


class LayerNormGradNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormGradNet, self).__init__()
        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
        self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, dy, x, var, mean, gamma):
        return self.norm(dy, x, var, mean, gamma)
def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)
    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]

    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
    var = np.var(x, axis=tuple(norm_axis), keepdims=True)
    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))

    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)

    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)

    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
    dx2 = sum1 * 2.0 / num * (x - mean)
    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    dx = dx1 + dx2 + dx3
    return dx, dg, db, mean, var


def test_basic():
    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
    beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
    shape_x = [2, 3, 4, 3]
    begin_norm_axis = 1

    in_rank = len(shape_x)
    if begin_norm_axis < 0:
        norm_axis = begin_norm_axis + in_rank
    else:
        norm_axis = begin_norm_axis
    norm_axes = tuple(range(norm_axis, in_rank))
    mean = np.mean(input_x, axis=norm_axes, keepdims=True)
    mean_b = np.broadcast_to(mean, shape_x)
    diff = input_x - mean_b
    square = np.square(diff)
    smean = np.mean(square, axis=norm_axes, keepdims=True)
    smean_b = np.broadcast_to(smean, shape_x)
    meps = smean_b + 1e-5
    logs = np.log(meps)
    mul = logs * (-0.5)
    rsqrt = np.exp(mul)
    out = diff * rsqrt
    bn = out * gamma + beta
    expect = (bn, mean, smean)

    net = Net()
    net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
    if isinstance(net_result, tuple) and len(net_result) == 3:
        result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
        res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
        assert res0
        res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
        assert res1
        res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
        assert res2
        return self.layernorm_grad(dy, x, var, mean, gamma)
def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = LayerNormNet(begin_norm_axis, begin_params_axis)
    output = net(x, gamma, beta)
    return output


def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    output = net(x, dy, var, mean, gamma)
    return output


def get_rtol_atol(dtype):
    if dtype == np.float16:
        return 1.e-3, 1.e-3
    return 1.e-4, 1.e-4


def compare_result(expect, output, dtype):
    rtol, atol = get_rtol_atol(dtype)
    if isinstance(expect, (list, tuple)):
        assert isinstance(output, (list, tuple)) and len(expect) == len(output)
        expect_list = list(expect)
        output_list = list(output)
        for e, o in zip(expect_list, output_list):
            assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True)
    else:
        assert False
        assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)


def test_layernormgrad():
def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
    assert 0 <= begin_norm_axis < len(shape)
    assert 0 <= begin_params_axis < len(shape)
    normalized_shape = shape[begin_params_axis:]
    np.random.seed(0)
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(4096, 3072).astype(np.float32)
    dy_np = np.random.randn(4096, 3072).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-11
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)
    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)

    # input tensors
    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
    beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))

    expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
    output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)
    compare_result(expect, output, dtype)

    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)

    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)


def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
    assert 0 <= begin_norm_axis < len(shape)
    assert 0 <= begin_params_axis < len(shape)
    norm_axis = [i for i in range(begin_norm_axis, len(shape))]
    norm_shape = copy.deepcopy(shape)
    for i, _ in enumerate(norm_shape):
        if i in norm_axis:
            norm_shape[i] = 1
    params_shape = shape[begin_params_axis:]
    np.random.seed(0)

    # input tensors
    dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
    var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
    mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
    gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))

    expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
    output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)
    compare_result(expect, output, dtype)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
    test_basic()
def test_layernorm_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_layernorm([4, 32, 32], np.float32, -1, -1)


def test_basic_ascend():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
    test_basic()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_layernorm([4, 32, 32], np.float16, -1, -1)
    test_layernorm([4, 32, 32], np.float32, -1, -1)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad_gpu():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
    test_layernormgrad()
def test_layernorm_grad_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_layernorm_grad([4, 32, 32], np.float32, -1, -1)


def test_layernormgrad_ascend():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
    test_layernormgrad()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_layernorm_grad_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_layernorm_grad([2, 16, 32], np.float16, -1, -1)
    test_layernorm_grad([4, 32, 32], np.float32, -1, -1)