From: @chenlei_autodiff (pull/14085/MERGE)
@@ -21,6 +21,8 @@ from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
 from .dropout_grad import DropoutGrad
 from .fused_adam import FusedAdam
 from .fused_adam_weight_decay import FusedAdamWeightDecay
+from .batchnorm import BatchNorm
+from .batchnorm_grad import BatchNormGrad
 from .gelu import GeLU
 from .gelu_grad import GeLUGrad
 from .gkdropout import GkDropout
@@ -31,6 +33,8 @@ from .logsoftmax_grad import LogSoftmaxGrad
 from .maximum_grad import MaximumGrad
 from .minimum_grad import MinimumGrad
 from .reduce_mean import ReduceMean
+from .relu import ReLU
+from .relu_grad import ReluGrad
 from .softmax import Softmax
 from .sigmoid import Sigmoid
 from .sigmoid_grad import SigmoidGrad
@@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchNorm"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
    """BatchNorm expander"""

    def _expand(self, graph_builder):
        # get op info
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'], input_scale.data_format)

        if self.attrs['is_training']:
            reduce_axis = ()
            shape_x = input_x.shape
            if input_x.data_format == "NHWC":
                reduce_axis = (0, 1, 2)
                num = shape_x[0] * shape_x[1] * shape_x[2]
            else:
                reduce_axis = (0, 2, 3)
                num = shape_x[0] * shape_x[2] * shape_x[3]
            num_rec = 1.0 / num
            num_rec_v = graph_builder.value(input_scale.dtype, num_rec, input_scale.data_format)

            # compute the mean of input_x
            mean_sum = graph_builder.emit(
                'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
            mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

            # compute the variance of input_x
            if not input_x.data_format == "NHWC":
                mean_muls_expand = graph_builder.emit('ExpandDims', [mean_muls], attrs={'axis': 1})
                mean_muls_expand = graph_builder.emit('ExpandDims', [mean_muls_expand], attrs={'axis': 2})
            else:
                mean_muls_expand = mean_muls
            var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
            var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
            var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
            var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])

            # y_sqrt_rec is 1 / sqrt(variance + epsilon), which is saved for reuse in the backward pass
            scalar_one = 1.0
            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one, input_scale.data_format)
            y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
            y_sqrt = graph_builder.emit('Sqrt', [y_add])
            y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

            # compute res_y
            tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
            if not input_x.data_format == "NHWC":
                y_sqrt_rec_expand = graph_builder.emit('ExpandDims', [y_sqrt_rec], attrs={'axis': 1})
                y_sqrt_rec_expand = graph_builder.emit('ExpandDims', [y_sqrt_rec_expand], attrs={'axis': 2})
            else:
                y_sqrt_rec_expand = y_sqrt_rec
            y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
            if not input_x.data_format == "NHWC":
                input_scale_expand = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
                input_scale_expand = graph_builder.emit('ExpandDims', [input_scale_expand], attrs={'axis': 2})
            else:
                input_scale_expand = input_scale
            res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
            if not input_x.data_format == "NHWC":
                input_offset_expand = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 1})
                input_offset_expand = graph_builder.emit('ExpandDims', [input_offset_expand], attrs={'axis': 2})
            else:
                input_offset_expand = input_offset
            res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

            # compute mean_res
            momentum_sub = scalar_one - self.attrs['momentum']
            momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub, input_scale.data_format)
            new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
            momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'], input_scale.data_format)
            current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
            updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
            mean_res = graph_builder.emit(
                'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})

            # variance_res uses the unbiased sample variance, so the biased batch variance is scaled by num / (num - 1)
            var_num = float(num) / (num - 1)
            var_num_v = graph_builder.value(input_scale.dtype, var_num, input_scale.data_format)
            var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
            new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
            current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
            updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
            variance_res = graph_builder.emit(
                'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
                attrs={'fake_output': True})

            # compute reserve: just return a tensor with shape C
            reserve = graph_builder.emit('Add', [input_offset, scalar_one_v])
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec, reserve

        # inference mode
        if not input_x.data_format == "NHWC":
            input_mean = graph_builder.emit('ExpandDims', [input_mean], attrs={'axis': 1})
            input_mean = graph_builder.emit('ExpandDims', [input_mean], attrs={'axis': 2})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 2})
            input_offset = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 1})
            input_offset = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 2})
        x_sub = graph_builder.emit('Sub', [input_x, input_mean])
        x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
        var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
        var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
        if not input_x.data_format == "NHWC":
            var_add_sqrt = graph_builder.emit('ExpandDims', [var_add_sqrt], attrs={'axis': 1})
            var_add_sqrt = graph_builder.emit('ExpandDims', [var_add_sqrt], attrs={'axis': 2})
        x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
        res_y = graph_builder.emit('Add', [input_offset, x_div])
        return res_y, var_add, var_add, var_add, var_add
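For reference, a minimal NumPy sketch of what the training branch above computes, assuming an NCHW input; the function and variable names here are illustrative only and are not part of the patch.

import numpy as np

def batchnorm_train_reference(x, scale, offset, moving_mean, moving_var, momentum=0.9, epsilon=1e-5):
    """Reference for the training-mode expansion (x is NCHW)."""
    axis = (0, 2, 3)
    num = x.shape[0] * x.shape[2] * x.shape[3]
    mean = x.mean(axis=axis)                                    # mean_muls
    x_center = x - mean.reshape(1, -1, 1, 1)
    var = (x_center * x_center).mean(axis=axis)                 # var_mul (biased batch variance)
    inv_std = 1.0 / np.sqrt(var + epsilon)                      # y_sqrt_rec
    y = scale.reshape(1, -1, 1, 1) * x_center * inv_std.reshape(1, -1, 1, 1) + offset.reshape(1, -1, 1, 1)
    new_moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
    unbiased_var = var * num / (num - 1)                        # sample-variance correction
    new_moving_var = (1.0 - momentum) * moving_var + momentum * unbiased_var
    return y, new_moving_mean, new_moving_var, mean, inv_std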
@@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchNormGrad"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander):
    """BatchNormGrad expander"""

    def _expand(self, graph_builder):
        # get op info
        input_dy = self.inputs[0]
        input_x = self.inputs[1]
        input_scale = self.inputs[2]
        input_save_mean = self.inputs[3]
        input_save_inv_variance = self.inputs[4]

        reduce_axis = ()
        shape_x = input_x.shape
        if input_x.data_format == "NHWC":
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        ori_type = input_x.dtype
        if ori_type == 'float16':
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
        if input_dy.dtype == 'float16':
            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
        num_rec = -1.0 / num
        num_rec_v = graph_builder.value(input_scale.dtype, num_rec, input_scale.data_format)
        dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})

        # in training mode, input_save_inv_variance is 1 / sqrt(variance + epsilon), computed in the forward pass
        if self.attrs['is_training']:
            inv_variance = input_save_inv_variance
        else:
            epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'], input_scale.data_format)
            var_add = graph_builder.emit('Add', [input_save_inv_variance, epsilon_v])
            sqrt_var_eps = graph_builder.emit('Sqrt', [var_add])
            scalar_one = 1.0
            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one, input_scale.data_format)
            inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])

        # compute dgamma
        if not input_x.data_format == "NHWC":
            input_save_mean = graph_builder.emit('ExpandDims', [input_save_mean], attrs={'axis': 1})
            input_save_mean = graph_builder.emit('ExpandDims', [input_save_mean], attrs={'axis': 2})
            inv_variance = graph_builder.emit('ExpandDims', [inv_variance], attrs={'axis': 1})
            inv_variance = graph_builder.emit('ExpandDims', [inv_variance], attrs={'axis': 2})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 2})
        x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
        x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
        dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
        dgamma = graph_builder.emit(
            'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})

        # compute dx
        if self.attrs['is_training']:
            tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
            if not input_x.data_format == "NHWC":
                dgamma_expand = graph_builder.emit('ExpandDims', [dgamma], attrs={'axis': 1})
                dgamma_expand = graph_builder.emit('ExpandDims', [dgamma_expand], attrs={'axis': 2})
                tmp_b = graph_builder.emit('ExpandDims', [tmp_b], attrs={'axis': 1})
                tmp_b = graph_builder.emit('ExpandDims', [tmp_b], attrs={'axis': 2})
            else:
                dgamma_expand = dgamma
            x_sub_mean_dgamma_mul = graph_builder.emit('Mul', [x_div, dgamma_expand])
            tmp_c = graph_builder.emit('Mul', [num_rec_v, x_sub_mean_dgamma_mul])
            tmp_ab_add = graph_builder.emit('Add', [input_dy, tmp_b])
            tmp_abc_add = graph_builder.emit('Add', [tmp_ab_add, tmp_c])
            gamma_mul = graph_builder.emit('Mul', [input_scale, tmp_abc_add])
            dx = graph_builder.emit('Mul', [inv_variance, gamma_mul])
        else:
            y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
            dx = graph_builder.emit('Mul', [inv_variance, y_scale])
        if ori_type == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})

        # set the output tensors' data_format
        dx.data_format = self.outputs[0]['format']
        dgamma.data_format = self.outputs[1]['format']
        dbeta.data_format = self.outputs[2]['format']
        return dx, dgamma, dbeta
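For reference, a minimal NumPy sketch of the training-mode gradient formulas the expansion above builds, again assuming NCHW; names are illustrative, not part of the patch.

import numpy as np

def batchnorm_grad_train_reference(dy, x, scale, save_mean, save_inv_std):
    """Training-mode BatchNorm gradients (x and dy are NCHW)."""
    axis = (0, 2, 3)
    num = x.shape[0] * x.shape[2] * x.shape[3]
    x_hat = (x - save_mean.reshape(1, -1, 1, 1)) * save_inv_std.reshape(1, -1, 1, 1)   # x_div above
    dbeta = dy.sum(axis=axis)
    dgamma = (dy * x_hat).sum(axis=axis)
    # dx = scale * inv_std * (dy - dbeta / num - x_hat * dgamma / num)
    dx = scale.reshape(1, -1, 1, 1) * save_inv_std.reshape(1, -1, 1, 1) * (
        dy - dbeta.reshape(1, -1, 1, 1) / num - x_hat * dgamma.reshape(1, -1, 1, 1) / num)
    return dx, dgamma, dbeta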
@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for relu"""
from ._utils import Expander


class ReLU(Expander):
    """ReLU expander"""

    def _expand(self, graph_builder):
        input_x = self.inputs[0]
        const_zero = graph_builder.value(input_x.dtype, 0)
        ge_result = graph_builder.emit('Greater', [input_x, const_zero])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        result = graph_builder.emit('Mul', [ge_result, input_x])
        return result
@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for relu_grad"""
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.check_all_formats_same
class ReluGrad(Expander):
    """ReluGrad expander"""

    def _expand(self, graph_builder):
        input_x = self.inputs[0]
        input_y = self.inputs[1]
        const_zero = graph_builder.value(input_y.dtype, 0)
        ge_result = graph_builder.emit('Greater', [input_y, const_zero])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        result = graph_builder.emit('Mul', [ge_result, input_x])
        return result
@@ -176,7 +176,9 @@ class Reshape(_Reshape):
 class ExpandDims(_Reshape):
     def _infer_shape(self):
-        return list(self.inputs[0].shape).insert(self.attrs["axis"], 1)
+        shape = list(self.inputs[0].shape)
+        shape.insert(self.attrs["axis"], 1)
+        return shape


 class Cast(_Elemwise):
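The old `_infer_shape` one-liner returned the result of `list.insert`, which mutates in place and always returns None, so shape inference produced None instead of the expanded shape. A quick illustration of why the fix mutates first and then returns:

shape = [2, 3]
assert shape.insert(1, 1) is None   # what the old one-liner returned
shape = list((2, 3))
shape.insert(1, 1)                  # mutate first ...
assert shape == [2, 1, 3]           # ... then return the list, as the fixed code does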
@@ -52,6 +52,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
   prim::kPrimGeLU,
   prim::kPrimFusedAdam,
   prim::kPrimFusedAdamWeightDecay,
+  prim::kPrimBatchNorm,
+  prim::kPrimBatchNormGrad,
   prim::kPrimReduceMean,
   prim::kPrimMaximumGrad,
   prim::kPrimMinimumGrad,
@@ -60,6 +62,8 @@ std::vector<PrimitivePtr> GetExpandOps() {
   prim::kPrimSoftmax,
   prim::kPrimLayerNorm,
   prim::kPrimLayerNormGrad,
+  prim::kPrimRelu,
+  prim::kPrimReluGrad,
   prim::kPrimSigmoid,
   prim::kPrimSigmoidGrad,
   prim::kPrimSigmoidCrossEntropyWithLogits,
@@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, input_scale, input_bias, input_mean, input_variance, is_training):
        super(Net, self).__init__()
        self.fused_bn_ex = P.BatchNorm(is_training=is_training, epsilon=1e-5, momentum=0.9)
        self.scale = Parameter(input_scale, name='scale')
        self.bias = Parameter(input_bias, name='b')
        self.mean = Parameter(input_mean, name='mean')
        self.variance = Parameter(input_variance, name='variance')

    def construct(self, input_x):
        return self.fused_bn_ex(input_x, self.scale, self.bias, self.mean, self.variance)


def get_output(x, weight, bias, moving_mean, moving_var, is_training, enable_graph_kernel=False):
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)
    net = Net(Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var), is_training)
    output = net(Tensor(x))
    return output, net.mean, net.variance


def test_bn_train():
    x = np.random.normal(0, 1, [1, 2, 4, 4]).astype(np.float32)
    weight = np.random.normal(0, 1, [2,]).astype(np.float32)
    bias = np.random.normal(0, 1, [2,]).astype(np.float32)
    moving_mean = np.random.normal(0, 1, [2,]).astype(np.float32)
    moving_var = np.random.normal(0, 1, [2,]).astype(np.float32)
    train_expect = get_output(x, weight, bias, moving_mean, moving_var, True, False)
    train_output = get_output(x, weight, bias, moving_mean, moving_var, True, True)
    assert np.allclose(train_expect[0][0].asnumpy(), train_output[0][0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[0][3].asnumpy(), train_output[0][3].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[0][4].asnumpy(), train_output[0][4].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[1].data.asnumpy(), train_output[1].data.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[2].data.asnumpy(), train_output[2].data.asnumpy(), 0.0001, 0.0001)


def test_bn_infer():
    x = np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32)
    weight = np.random.normal(5, 1, [2,]).astype(np.float32)
    bias = np.random.normal(5, 1, [2,]).astype(np.float32)
    moving_mean = np.random.normal(5, 1, [2,]).astype(np.float32)
    moving_var = np.random.normal(5, 1, [2,]).astype(np.float32)
    infer_expect = get_output(x, weight, bias, moving_mean, moving_var, False, False)
    infer_output = get_output(x, weight, bias, moving_mean, moving_var, False, True)
    assert np.allclose(infer_expect[0][0].asnumpy(), infer_output[0][0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_train_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_train()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_infer_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_infer()
@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops.operations import _grad_ops as G


class Net(nn.Cell):
    def __init__(self, is_training):
        super(Net, self).__init__()
        self.fused_bn_grad_ex = G.BatchNormGrad(is_training=is_training, epsilon=1e-5)

    def construct(self, input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse):
        return self.fused_bn_grad_ex(
            input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse)


def get_output(input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse,
               is_training, enable_graph_kernel=False):
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)
    net = Net(is_training)
    output = net(input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse)
    return output


def test_bn_grad_train():
    input_dy = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_x = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_scale = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_mean = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_inv_variance = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_reverse = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    expect = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, True, False)
    output = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, True, True)
    assert np.allclose(expect[0].asnumpy(), output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[1].asnumpy(), output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[2].asnumpy(), output[2].asnumpy(), 0.0001, 0.0001)


def test_bn_grad_infer():
    input_dy = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_x = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_scale = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_mean = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_inv_variance = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_reverse = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    expect = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, False, False)
    output = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, False, True)
    assert np.allclose(expect[0].asnumpy(), output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[1].asnumpy(), output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[2].asnumpy(), output[2].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_grad_train_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_grad_train()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_grad_infer_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_grad_infer()
@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)


def get_output(x, enable_graph_kernel=False):
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)
    net = Net()
    output = net(x)
    return output


def test_relu(shape, dtype):
    x = Tensor(np.random.normal(0, 10, shape).astype(dtype))
    expect = get_output(x, False)
    output = get_output(x, True)
    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()
    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_relu((4, 3), np.int32)
    test_relu((12, 1), np.float16)


def test_relu_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_relu((4, 3), np.int32)
    test_relu((12, 1), np.float16)
@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu_grad = G.ReluGrad()

    def construct(self, y_backprop, x):
        return self.relu_grad(y_backprop, x)


def get_output(y_backprop, x, enable_graph_kernel=False):
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True)
    net = Net()
    output = net(y_backprop, x)
    return output


def test_relu_grad(shape1, shape2, dtype):
    x = Tensor(np.random.normal(0, 10, shape1).astype(dtype))
    y_backprop = Tensor(np.random.normal(0, 10, shape2).astype(dtype))
    expect = get_output(y_backprop, x, False)
    output = get_output(y_backprop, x, True)
    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()
    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_relu_grad((4, 3), (4, 3), np.int32)
    test_relu_grad((12, 1), (12, 1), np.float16)


def test_relu_grad_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_relu_grad((4, 3), (4, 3), np.int32)
    test_relu_grad((12, 1), (12, 1), np.float16)