From: @wenfangpei Reviewed-by: @gaoxiong1, @ckey_dou Signed-off-by: @ckey_dou (pull/15412/MERGE)
| @@ -143,3 +143,61 @@ class ExpanderInfoValidator: | |||
| return cls | |||
| return wrapper | |||
def to_frac_z_axis(ori_shape, ori_axis):
    """
    Map reduction axes on an original shape to axes on the fractal_NZ layout.

    Parameters
    ----------
    ori_shape: list or tuple
        original shape of input
    ori_axis: list or tuple
        original axis of original shape to operate

    Returns
    -------
    output: list
        axis of the fractal Nz shape
    """
    frac_z_axis = list(ori_axis)
    shape_len = len(ori_shape)
    axis_count = len(frac_z_axis)
    axis_negative_1 = shape_len - 1
    axis_negative_2 = shape_len - 2
    for i in range(axis_count):
        # Normalize negative axis indices into [0, shape_len).
        axis_index = (frac_z_axis[i] + shape_len) % shape_len
        if axis_index == axis_negative_1:
            # Last original axis maps to two fractal axes (idx-1, idx+2).
            # The original code had identical bodies in both halves of an
            # if/else here (akg:[2,3] [1,4] tbe:[2,4] [1,3]); merged.
            frac_z_axis[i] = axis_index - 1
            frac_z_axis.append(axis_index + 2)
        elif axis_index == axis_negative_2:
            # Second-to-last original axis maps to (idx+1, idx+2).
            frac_z_axis[i] = axis_index + 1
            frac_z_axis.append(axis_index + 2)
        else:
            # Batch axes keep their (normalized) position.
            frac_z_axis[i] = axis_index
    return frac_z_axis
def infer_shape_from_fractalNz(fractal):
    """Recover the original (default-format) shape from a fractal_NZ shape.

    The trailing four dims of a fractal_NZ shape are (n1, m1, m0, n0);
    everything before them is batch.  Original m = m1 * m0, n = n1 * n0.
    """
    dims = len(fractal)
    shape = [fractal[i] for i in range(dims - 4)]
    shape.append(fractal[dims - 3] * fractal[dims - 2])
    shape.append(fractal[dims - 4] * fractal[dims - 1])
    return shape
def get_reduced_ori_shape(shape, axis):
    """Return `shape` with every axis listed in `axis` collapsed to 1
    (i.e. the original shape after a keep-dims reduction)."""
    return [1 if i in axis else dim for i, dim in enumerate(shape)]
| @@ -15,78 +15,13 @@ | |||
| """generate json desc for LayerNorm""" | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis | |||
| @VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) | |||
| @VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon') | |||
| class LayerNorm(Expander): | |||
| """LayerNorm expander""" | |||
| def to_frac_z_axis(self, ori_shape, ori_axis): | |||
| """ | |||
| judge the format is fractal NZ | |||
| Parameters | |||
| ---------- | |||
| ori_shape: list or tuple | |||
| original shape of input | |||
| ori_axis: list or tuple | |||
| original axis of original shape to operate | |||
| Returns | |||
| ------- | |||
| output: list | |||
| axis of the fractal Nz shape | |||
| """ | |||
| frac_z_axis = list(ori_axis) | |||
| shape_len = len(ori_shape) | |||
| axis_count = len(frac_z_axis) | |||
| axis_negative_1 = shape_len - 1 | |||
| axis_negative_2 = shape_len - 2 | |||
| for i in range(axis_count): | |||
| axis_index = (frac_z_axis[i] + shape_len) % shape_len | |||
| if axis_index == axis_negative_1: | |||
| if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3] | |||
| frac_z_axis[i] = axis_index - 1 | |||
| frac_z_axis.append(axis_index + 2) | |||
| else: # no case cover this branch now | |||
| frac_z_axis[i] = axis_index - 1 | |||
| frac_z_axis.append(axis_index + 2) | |||
| elif axis_index == axis_negative_2: | |||
| frac_z_axis[i] = axis_index + 1 | |||
| frac_z_axis.append(axis_index + 2) | |||
| else: | |||
| frac_z_axis[i] = axis_index | |||
| return frac_z_axis | |||
| def infer_shape_from_fractalNz(self, fractal): | |||
| "get original shape from fractalNz shape" | |||
| shape = [] | |||
| dims = len(fractal) | |||
| batch = dims - 4 | |||
| for i in range(batch): | |||
| shape.append(fractal[i]) | |||
| m = fractal[dims - 3] * fractal[dims - 2] | |||
| n = fractal[dims - 4] * fractal[dims - 1] | |||
| shape.append(m) | |||
| shape.append(n) | |||
| return shape | |||
| def get_reduced_ori_shape(self, shape, axis): | |||
| "get shape after reduced which is based on original shape" | |||
| reduced_ori_shape = [] | |||
| for i, value in enumerate(shape): | |||
| if i in axis: | |||
| reduced_ori_shape.append(1) | |||
| else: | |||
| reduced_ori_shape.append(value) | |||
| return reduced_ori_shape | |||
| def _expand(self, graph_builder): | |||
| input_x, input_gamma, input_beta = self.inputs | |||
| processor = self.processor | |||
| @@ -101,7 +36,7 @@ class LayerNorm(Expander): | |||
| ori_shape_x = input_x.shape | |||
| if input_x.data_format == DF.FRAC_NZ: | |||
| ori_shape_x = self.infer_shape_from_fractalNz(input_x.shape) | |||
| ori_shape_x = infer_shape_from_fractalNz(ori_shape_x) | |||
| # Calculate the scaling ratio of the average | |||
| if begin_norm_axis < 0: | |||
| @@ -116,10 +51,10 @@ class LayerNorm(Expander): | |||
| for i in reduce_axis: | |||
| reduce_elts *= ori_shape_x[i] | |||
| # after reduced | |||
| ori_reduced_shape_x = self.get_reduced_ori_shape(ori_shape_x, reduce_axis) | |||
| ori_reduced_shape_x = get_reduced_ori_shape(ori_shape_x, reduce_axis) | |||
| if input_x.data_format == DF.FRAC_NZ: | |||
| reduce_axis = self.to_frac_z_axis(ori_shape_x, reduce_axis) | |||
| reduce_axis = to_frac_z_axis(ori_shape_x, reduce_axis) | |||
| mean_cof = 1.0 / reduce_elts | |||
| mean_cof_v = graph_builder.value(input_x.dtype, mean_cof) | |||
| @@ -15,8 +15,9 @@ | |||
| """generate json desc for softmax""" | |||
| from mindspore._extends.graph_kernel.model.model import DataFormat as DF | |||
| from ._utils import Expander, ExpanderInfoValidator as VLD | |||
| from ._utils import infer_shape_from_fractalNz, get_reduced_ori_shape, to_frac_z_axis | |||
| @VLD.add_format(DF.FRAC_NZ) | |||
| @VLD.add_format(DF.DEFAULT) | |||
| @VLD.check_attrs('axis') | |||
| class Softmax(Expander): | |||
| @@ -24,12 +25,43 @@ class Softmax(Expander): | |||
| def _expand(self, graph_builder): | |||
| input_x = self.inputs[0] | |||
| processor = self.processor | |||
| axis = self.attrs['axis'] | |||
| max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| ori_shape = input_x.shape | |||
| if input_x.data_format == DF.FRAC_NZ: | |||
| ori_shape = infer_shape_from_fractalNz(input_x.shape) | |||
| for i, _ in enumerate(list(axis)): | |||
| if axis[i] < 0: | |||
| axis[i] += len(ori_shape) | |||
| ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis) | |||
| if input_x.data_format == DF.FRAC_NZ: | |||
| axis = to_frac_z_axis(ori_shape, axis) | |||
| ori_dtype = input_x.dtype | |||
| if ori_dtype != "float16" and processor == "aicore": | |||
| input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'}) | |||
| max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype}) | |||
| else: | |||
| max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| if ori_dtype == "float16" and processor == "aicore": | |||
| max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': "float32"}) | |||
| input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"}) | |||
| if input_x.data_format == DF.FRAC_NZ: | |||
| max_x = graph_builder.emit('Reshape', [max_x], attrs={'shape': ori_reduced_shape}) | |||
| data_sub = graph_builder.emit('Sub', [input_x, max_x]) | |||
| data_exp = graph_builder.emit('Exp', [data_sub]) | |||
| data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) | |||
| if input_x.data_format == DF.FRAC_NZ: | |||
| data_expsum = graph_builder.emit('Reshape', [data_expsum], attrs={'shape': ori_reduced_shape}) | |||
| result = graph_builder.emit('RealDiv', [data_exp, data_expsum]) | |||
| if ori_dtype == "float16" and processor == "aicore": | |||
| result = graph_builder.emit('Cast', [result], attrs={'dst_type': ori_dtype}) | |||
| return result | |||
| @@ -106,6 +106,8 @@ class _Elemwise(OpInfer): | |||
| shape = (1,) | |||
| max_flatten_size = 1 | |||
| for t in self.inputs: | |||
| if t.data_format != DF.DEFAULT: | |||
| return t.shape | |||
| flatten_size = reduce(lambda x, y: x * y, t.shape) | |||
| if flatten_size >= max_flatten_size: | |||
| max_flatten_size = flatten_size | |||
| @@ -50,6 +50,9 @@ std::vector<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimLayerNorm, | |||
| prim::kPrimLayerNormGrad, | |||
| prim::kPrimExpandDims, | |||
| prim::kPrimBiasAddGrad, | |||
| prim::kPrimGeLU, | |||
| prim::kPrimSoftmax, | |||
| prim::kPrimTile, | |||
| #if ENABLE_D | |||
| prim::kPrimSqrtGrad, | |||
| @@ -58,8 +61,6 @@ std::vector<PrimitivePtr> GetExpandOps() { | |||
| prim::kLambApplyWeightAssign, | |||
| #elif ENABLE_GPU | |||
| prim::kPrimBiasAdd, | |||
| prim::kPrimBiasAddGrad, | |||
| prim::kPrimGeLU, | |||
| prim::kPrimFusedAdam, | |||
| prim::kPrimFusedAdamWeightDecay, | |||
| prim::kPrimBatchNorm, | |||
| @@ -69,7 +70,6 @@ std::vector<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimMinimumGrad, | |||
| prim::kPrimDropout, | |||
| prim::kPrimDropoutGrad, | |||
| prim::kPrimSoftmax, | |||
| prim::kPrimRelu, | |||
| prim::kPrimReluGrad, | |||
| prim::kPrimSigmoid, | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import ms_function | |||
| from mindspore.ops.operations import _grad_ops as G | |||
class Net(nn.Cell):
    """Minimal cell wrapping BiasAddGrad for standalone execution."""

    def __init__(self):
        super(Net, self).__init__()
        self.grad_op = G.BiasAddGrad()

    @ms_function
    def construct(self, dout):
        return self.grad_op(dout)
def get_output(dout, enable_graph_kernel=False):
    """Run Net on `dout`, with graph-kernel fusion toggled by the flag."""
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    return Net()(Tensor(dout))
def test_bias_add_grad(shape, dtype):
    """Compare BiasAddGrad output with graph-kernel fusion off vs on.

    shape: input shape; dtype: a numpy type (np.float32 / np.float16).
    """
    np.random.seed(0)
    dout = np.random.normal(0, 1, shape).astype(dtype)
    expect = get_output(dout, False)
    output = get_output(dout, True)
    # float16 carries less precision, so loosen tolerances for it.
    # BUG FIX: `dtype` is a numpy type object, so `dtype == "float16"` was
    # always False and the fp16 tolerances were never applied; compare via
    # np.dtype instead.
    if np.dtype(dtype) == np.float16:
        rtol = 1.e-3
        atol = 1.e-3
    else:
        rtol = 1.e-4
        atol = 1.e-4
    assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bias_add_grad_ascend():
    """Entry point: run the BiasAddGrad comparison on Ascend in fp32 and fp16."""
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
    for dt in (np.float32, np.float16):
        test_bias_add_grad([2, 32, 48, 64], dt)
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -40,40 +40,47 @@ class GeluGradNet(Cell): | |||
| return self.gelu_grad(dy, x, y) | |||
| def CalGelu(x): | |||
def cal_gelu(x):
    """NumPy reference implementation of GeLU via the tanh approximation."""
    cube_term = x + 0.044715 * x * x * x
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * cube_term))
def gelu(x, enable_graph_kernel=False):
    """Run GeluNet on `x`, with graph-kernel fusion toggled by the flag."""
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    return GeluNet()(Tensor(x))
| def test_gelu(): | |||
| np.random.seed(0) | |||
| input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) | |||
| net = GeluNet() | |||
| result = net(Tensor(input_x)) | |||
| expect = CalGelu(input_x) | |||
| res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True) | |||
| expect = gelu(input_x, False) | |||
| result = gelu(input_x, True) | |||
| res = np.allclose(expect.asnumpy(), result.asnumpy(), rtol=1.e-4, atol=1.e-4, equal_nan=True) | |||
| assert res | |||
def cal_gelu_grad(input_dy, input_x):
    """NumPy reference of the GeLU gradient (tanh approximation).

    BUG FIX: the original body referenced `input_x` and `input_dy` without
    taking them as parameters, so any call raised NameError; they are now
    explicit arguments.  Returns input_dy * dGeLU/dx evaluated at input_x.
    """
    tanh_res = np.tanh(0.7978845608 * (input_x + 0.044715 * input_x * input_x * input_x))
    mul_right = 0.7978845608 + 0.1070322244 * input_x * input_x
    dx = 0.5 * (1.0 + tanh_res) + 0.5 * input_x * (1.0 - tanh_res * tanh_res) * mul_right
    return input_dy * dx
def gelu_grad(input_dy, input_x, input_y, enable_graph_kernel=False):
    """Run GeluGradNet on the given tensors, toggling graph-kernel fusion."""
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = GeluGradNet()
    return net(Tensor(input_dy), Tensor(input_x), Tensor(input_y))
| def test_gelu_grad(): | |||
| np.random.seed(0) | |||
| input_dy = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) | |||
| input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32) | |||
| input_y = CalGelu(input_x) | |||
| input_y = cal_gelu(input_x) | |||
| net = GeluGradNet() | |||
| result = net(Tensor(input_dy), Tensor(input_x), Tensor(input_y)) | |||
| tanh_res = np.tanh(0.7978845608 * (input_x + 0.044715 * input_x * input_x * input_x)) | |||
| mul_right = 0.7978845608 + 0.1070322244 * input_x * input_x | |||
| dx = 0.5 * (1.0 + tanh_res) + 0.5 * input_x * (1.0 - tanh_res * tanh_res) * mul_right | |||
| expect = input_dy * dx | |||
| res = np.allclose(expect, result.asnumpy(), rtol=1.e-4, atol=1.e-7, equal_nan=True) | |||
| expect = gelu_grad(input_dy, input_x, input_y, False) | |||
| result = gelu_grad(input_dy, input_x, input_y, True) | |||
| res = np.allclose(expect.asnumpy(), result.asnumpy(), rtol=1.e-4, atol=1.e-4, equal_nan=True) | |||
| assert res | |||
| @@ -84,7 +91,10 @@ def test_gelu_gpu(): | |||
| context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") | |||
| test_gelu() | |||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gelu_ascend():
    """Run the GeLU fusion comparison on Ascend."""
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE,
                        enable_graph_kernel=True)
    test_gelu()
| @@ -97,7 +107,10 @@ def test_gelu_grad_gpu(): | |||
| context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") | |||
| test_gelu_grad() | |||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gelu_grad_ascend():
    """Run the GeLU-gradient fusion comparison on Ascend."""
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE,
                        enable_graph_kernel=True)
    test_gelu_grad()
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
class Net(nn.Cell):
    """Minimal cell wrapping P.Softmax for standalone execution.

    axis: softmax axis forwarded to the op (default -1, the last axis).
    """

    def __init__(self, axis=-1):
        super(Net, self).__init__()
        # snake_case attribute (PEP 8): the original `self.Softmax` was
        # PascalCase and shadowed the op class name.
        self.softmax = P.Softmax(axis)

    def construct(self, x):
        return self.softmax(x)
def get_output(x, enable_graph_kernel=False):
    """Build a Net and run it on `x` with fusion toggled by the flag."""
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = Net()
    return net(Tensor(x))
def test_softmax(shape, dtype):
    """Compare Softmax output with graph-kernel fusion off vs on.

    shape: input shape; dtype: a numpy type (np.float32 / np.float16).
    """
    np.random.seed(0)
    x = np.random.normal(0, 1, shape).astype(dtype)
    expect = get_output(x, False)
    output = get_output(x, True)
    # BUG FIX: `dtype` is a numpy type object, so `dtype == "float16"` was
    # always False and fp16 never got the looser tolerances; compare via
    # np.dtype instead.
    if np.dtype(dtype) == np.float16:
        rtol = 1.e-3
        atol = 1.e-3
    else:
        rtol = 1.e-4
        atol = 1.e-4
    assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_softmax_gpu():
    """Run the softmax fusion comparison on GPU (fp32)."""
    context.set_context(device_target="GPU", mode=context.GRAPH_MODE)
    test_softmax([4, 32, 48], np.float32)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_softmax_ascend():
    """Run the softmax fusion comparison on Ascend (fp32)."""
    context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
    test_softmax([2, 32, 48, 64], np.float32)