From: @shibeiji Reviewed-by: @c_34,@liangchenghui Signed-off-by: @liangchenghuitags/v1.2.0-rc1
| @@ -202,6 +202,8 @@ constexpr const char kNameCase[] = "Case"; | |||
| constexpr const char kNameAssert[] = "Assert"; | |||
| constexpr const char kNameCTCGreedyDecoder[] = "CTCGreedyDecoder"; | |||
| constexpr const char kNameReverseV2[] = "ReverseV2"; | |||
| constexpr const char kNameLambApplyWeightAssign[] = "LambApplyWeightAssign"; | |||
| constexpr const char kNameLambApplyOptimizerAssign[] = "LambApplyOptimizerAssign"; | |||
| class OpAdapterMap { | |||
| public: | |||
| @@ -362,4 +362,24 @@ INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; | |||
| ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; | |||
| REG_ADPT_DESC(Atan2, kNameAtan2, ADPT_DESC(Atan2)) | |||
| // LambApplyOptimizerAssign | |||
| INPUT_MAP(LambApplyOptimizerAssign) = { | |||
| {1, INPUT_DESC(grad)}, {2, INPUT_DESC(inputv)}, {3, INPUT_DESC(inputm)}, | |||
| {4, INPUT_DESC(input3)}, {5, INPUT_DESC(mul0_x)}, {6, INPUT_DESC(mul1_x)}, | |||
| {7, INPUT_DESC(mul2_x)}, {8, INPUT_DESC(mul3_x)}, {9, INPUT_DESC(add2_y)}, | |||
| {10, INPUT_DESC(steps)}, {11, INPUT_DESC(do_use_weight)}, {12, INPUT_DESC(weight_decay_rate)}}; | |||
| ATTR_MAP(LambApplyOptimizerAssign) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(LambApplyOptimizerAssign) = {{0, OUTPUT_DESC(output0)}, {1, OUTPUT_DESC(inputv)}, {2, OUTPUT_DESC(inputm)}}; | |||
| REG_ADPT_DESC(LambApplyOptimizerAssign, kNameLambApplyOptimizerAssign, ADPT_DESC(LambApplyOptimizerAssign)) | |||
| // LambApplyWeightAssign | |||
| INPUT_MAP(LambApplyWeightAssign) = {{1, INPUT_DESC(input0)}, | |||
| {2, INPUT_DESC(input1)}, | |||
| {3, INPUT_DESC(input2)}, | |||
| {4, INPUT_DESC(input3)}, | |||
| {5, INPUT_DESC(input_param)}}; | |||
| ATTR_MAP(LambApplyWeightAssign) = EMPTY_ATTR_MAP; | |||
| OUTPUT_MAP(LambApplyWeightAssign) = {{0, OUTPUT_DESC(input_param)}}; | |||
| REG_ADPT_DESC(LambApplyWeightAssign, kNameLambApplyWeightAssign, ADPT_DESC(LambApplyWeightAssign)) | |||
| } // namespace mindspore::transform | |||
| @@ -189,5 +189,11 @@ DECLARE_OP_USE_OUTPUT(Round) | |||
| DECLARE_OP_ADAPTER(Atan2) | |||
| DECLARE_OP_USE_OUTPUT(Atan2) | |||
| DECLARE_OP_ADAPTER(LambApplyOptimizerAssign) | |||
| DECLARE_OP_USE_OUTPUT(LambApplyOptimizerAssign) | |||
| DECLARE_OP_ADAPTER(LambApplyWeightAssign) | |||
| DECLARE_OP_USE_OUTPUT(LambApplyWeightAssign) | |||
| } // namespace mindspore::transform | |||
| #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ELEWISE_CALCULATION_OPS_DECLARE_H_ | |||
| @@ -111,6 +111,52 @@ def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v | |||
| return op_cast(next_param, F.dtype(param)) | |||
| return gradient | |||
| _lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend") | |||
| @_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Bool", "Bool") | |||
| def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, | |||
| optim_filter): | |||
| """ | |||
| Update parameters function when device target is ascend. | |||
| Args: | |||
| beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). | |||
| beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). | |||
| eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. | |||
| lr (Tensor): Learning rate. | |||
| weight_decay (Number): Weight decay. Should be equal to or greater than 0. | |||
| global_step (Tensor): Global step. | |||
| param (Tensor): Parameters. | |||
| m (Tensor): m value of parameters. | |||
| v (Tensor): v value of parameters. | |||
| gradient (Tensor): Gradient of parameters. | |||
| decay_flag (bool): Specifies whether param update with weight decay. | |||
| optim_filter(bool): Applies parameter update or not. | |||
| Returns: | |||
| Tensor, the new value of v after updating. | |||
| """ | |||
| if optim_filter: | |||
| op_cast = P.Cast() | |||
| op_norm = layer.Norm() | |||
| op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign() | |||
| op_lamb_apply_weight_assign = P.LambApplyWeightAssign() | |||
| param_fp32 = op_cast(param, mstype.float32) | |||
| gradient_fp32 = op_cast(gradient, mstype.float32) | |||
| new_global_step = op_cast(global_step + num_one, mstype.float32) | |||
| weight_decay_flag = op_cast(decay_flag, mstype.float32) | |||
| update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32, | |||
| beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps, | |||
| new_global_step, weight_decay_flag, weight_decay) | |||
| w_norm = op_norm(param_fp32) | |||
| g_norm = op_norm(update) | |||
| update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param)) | |||
| return update | |||
| return gradient | |||
| lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") | |||
| @@ -279,6 +325,7 @@ class Lamb(Optimizer): | |||
| self.hyper_map = C.HyperMap() | |||
| self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \ | |||
| context.get_context("enable_graph_kernel") | |||
| self.device_ascend = context.get_context("device_target") == "Ascend" | |||
| def construct(self, gradients): | |||
| lr = self.get_lr() | |||
| @@ -299,19 +346,20 @@ class Lamb(Optimizer): | |||
| self.global_step, lr, self.weight_decay), | |||
| self.params, self.moments1, self.moments2, gradients, self.decay_flags) | |||
| else: | |||
| lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt | |||
| if self.is_group: | |||
| if self.is_group_lr: | |||
| optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, | |||
| optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, | |||
| self.global_step), | |||
| lr, self.weight_decay, self.params, self.moments1, self.moments2, | |||
| gradients, self.decay_flags, self.optim_filter) | |||
| else: | |||
| optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, | |||
| optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, | |||
| self.global_step, lr), | |||
| self.weight_decay, self.params, self.moments1, self.moments2, | |||
| gradients, self.decay_flags, self.optim_filter) | |||
| else: | |||
| optim_result = self.hyper_map(F.partial(_lamb_opt, self.beta1, self.beta2, self.eps, | |||
| optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, | |||
| self.global_step, lr, self.weight_decay), | |||
| self.params, self.moments1, self.moments2, gradients, | |||
| self.decay_flags, self.optim_filter) | |||
| @@ -351,3 +351,5 @@ from .conv3d import _conv3d_tbe | |||
| from .conv3d_backprop_input import _conv3d_backprop_input_tbe | |||
| from .conv3d_backprop_filter import _conv3d_backprop_filter_tbe | |||
| from .conv3d_transpose import _conv3d_transpose_tbe | |||
| from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe | |||
| from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe | |||
| @@ -0,0 +1,55 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LambApplyOptimizerAssign op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lamb_apply_optimizer_assign_op_info = TBERegOp("LambApplyOptimizerAssign") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lamb_apply_optimizer_assign.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lamb_apply_optimizer_assign") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "grad", False, "required", "all") \ | |||
| .input(1, "inputv", False, "required", "all") \ | |||
| .input(2, "inputm", False, "required", "all") \ | |||
| .input(3, "input3", False, "required", "all") \ | |||
| .input(4, "mul0_x", False, "required", "all") \ | |||
| .input(5, "mul1_x", False, "required", "all") \ | |||
| .input(6, "mul2_x", False, "required", "all") \ | |||
| .input(7, "mul3_x", False, "required", "all") \ | |||
| .input(8, "add2_y", False, "required", "all") \ | |||
| .input(9, "steps", False, "required", "all") \ | |||
| .input(10, "do_use_weight", False, "required", "all") \ | |||
| .input(11, "weight_decay_rate", False, "required", "all") \ | |||
| .output(0, "output0", False, "required", "all") \ | |||
| .output(0, "inputv", False, "required", "all") \ | |||
| .output(0, "inputm", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(lamb_apply_optimizer_assign_op_info) | |||
| def _lamb_apply_optimizer_assign_tbe(): | |||
| """LambApplyOptimizerAssign TBE register""" | |||
| return | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LambApplyWeightAssign op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lamb_apply_weight_assign_op_info = TBERegOp("LambApplyWeightAssign") \ | |||
| .fusion_type("ELEMWISE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lamb_apply_weight_assign.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lamb_apply_weight_assign") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "input0", False, "required", "all") \ | |||
| .input(1, "input1", False, "required", "all") \ | |||
| .input(2, "input2", False, "required", "all") \ | |||
| .input(3, "input3", False, "required", "all") \ | |||
| .input(4, "input_param", False, "required", "all") \ | |||
| .output(0, "input_param", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(lamb_apply_weight_assign_op_info) | |||
| def _lamb_apply_weight_assign_tbe(): | |||
| """LambApplyWeightAssign TBE register""" | |||
| return | |||
| @@ -112,3 +112,15 @@ class LambUpdateWithLR: | |||
| class LambNextMV: | |||
| def __call__(self, *args): | |||
| pass | |||
| @op_selector | |||
| class LambApplyOptimizerAssign: | |||
| def __call__(self, *args): | |||
| pass | |||
| @op_selector | |||
| class LambApplyWeightAssign: | |||
| def __call__(self, *args): | |||
| pass | |||
| @@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, | |||
| from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, | |||
| TensorSummary, HistogramSummary, Print, Assert) | |||
| from .control_ops import ControlDepend, GeSwitch, Merge | |||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram | |||
| from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign | |||
| from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, | |||
| BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, | |||
| @@ -172,3 +172,132 @@ class NoRepeatNGram(PrimitiveWithInfer): | |||
| valid_values = (mstype.float16, mstype.float32, mstype.float64) | |||
| validator.check_type_name("log_type", log_type, valid_values, self.name) | |||
| return log_type | |||
| class LambApplyOptimizerAssign(PrimitiveWithInfer): | |||
| r""" | |||
| Updates gradients by LAMB optimizer algorithm. Get the compute ratio. | |||
| The Lamb optimzier is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes | |||
| <https://arxiv.org/abs/1904.00962>`_. | |||
| The updating formulas are as follows, | |||
| .. math:: | |||
| \begin{array}{ll} \\ | |||
| m = \beta_1 * m + (1 - \beta_1) * g \\ | |||
| v = \beta_2 * v + (1 - \beta_2) * g * g \\ | |||
| m = \frac{m}{1 - \beta_1^t} \\ | |||
| v = \frac{v}{1 - \beta_2^t} \\ | |||
| r = \frac{m}{\sqrt{v} + \epsilon} \\ | |||
| w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)) | |||
| \end{array} | |||
| :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents | |||
| `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, | |||
| :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and | |||
| `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents | |||
| `epsilon`. | |||
| Inputs: | |||
| - **gradient** (Tensor) - Gradient of parameters, float32/float16. | |||
| - **v** (Tensor) - the 2nd moment vector in the updating formula, has the same type as `gradient`. | |||
| - **m** (Tensor) - The 1st moment vector in the updating formula, has the same type as `gradient`. | |||
| - **var** (Tensor) - Weights to be updated, has the same type as `gradient`. | |||
| - **beta1** (Tensor) - :math:`beta_1` in the updating formula, float32/float16. | |||
| - **sub1** (Tensor) - :math:`1-beta_1` in the updating formula, has the same type as `beta1`. | |||
| - **beta2** (Tensor) - :math:`beta_2` in the updating formula, has the same type as `beta1`. | |||
| - **sub2** (Tensor) - :math:`1-beta_2` in the updating formula, has the same type as `beta1`. | |||
| - **epsilon** (Tensor) - Term added to the denominator, has the same type as `beta1`. | |||
| - **steps** (Tensor) - :math:`t` in the updating formula, global step, has the same type as `beta1`. | |||
| - **lr** (Tensor) - :math:`l` in the updating formula, learning rate, has the same type as `beta1`. | |||
| - **decay_flag** (Tensor) -Specify whether param upadte with weight decay, has the same type as `beta1`. | |||
| - **weight_decay** (Tensor) - :math:`\lambda` in the updating formula, has the same type as `beta1`. | |||
| Outputs: | |||
| Tensor, the compute ratio r. | |||
| - **update** (Tensor) - :math:`r + \lambda * w` in the updating formula. The same shape and data type as `m`. | |||
| - **v** (Tensor) - the 2nd moment vector in the updating formula after updated inplace, | |||
| has the same type as `gradient`. | |||
| - **m** (Tensor) - The 1st moment vector in the updating formula after updated inplace, | |||
| has the same type as `gradient`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize LambApplyOptimizerAssign""" | |||
| def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape, | |||
| beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape): | |||
| validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) | |||
| validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) | |||
| return m_shape, v_shape, m_shape | |||
| def infer_dtype(self, grad_dtype, v_dtype, m_dtype, var_dtype, beta1_dtype, sub1_dtype, | |||
| beta2_dtype, sub2_dtype, eps_dtype, steps_dtype, use_weight_dtype, weight_decay_dtype): | |||
| args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| args = {"beta1": beta1_dtype, "sub1": sub1_dtype, "beta2": beta2_dtype, "sub2": sub2_dtype, | |||
| "eps": eps_dtype, "steps": steps_dtype, "use_weight": use_weight_dtype, | |||
| "weight_decay": weight_decay_dtype} | |||
| validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) | |||
| return m_dtype, v_dtype, v_dtype | |||
| class LambApplyWeightAssign(PrimitiveWithInfer): | |||
| r""" | |||
| Updates gradients by LAMB optimizer algorithm. The weight update part. | |||
| The Lamb optimzier is proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes | |||
| <https://arxiv.org/abs/1904.00962>`_. | |||
| The updating formulas are as follows, | |||
| .. math:: | |||
| \begin{array}{ll} \\ | |||
| m = \beta_1 * m + (1 - \beta_1) * g \\ | |||
| v = \beta_2 * v + (1 - \beta_2) * g * g \\ | |||
| m = \frac{m}{1 - \beta_1^t} \\ | |||
| v = \frac{v}{1 - \beta_2^t} \\ | |||
| r = \frac{m}{\sqrt{v} + \epsilon} \\ | |||
| w = w - l * \frac{\left \| w \right \|}{\left \| r \right \|} * (r + \lambda * w)) | |||
| \end{array} | |||
| :math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents | |||
| `gradient`, :math:`l` represents learning rate `lr`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, | |||
| :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent `beta1_power` and | |||
| `beta2_power`, :math:`\lambda` represents `weight_decay`, :math:`w` represents `var`, :math:`\epsilon` represents | |||
| `epsilon`. | |||
| Inputs: | |||
| - **w_norm** (Tensor) - :math:`\left \| w \right \|` in the updating formula, float32/float16. | |||
| - **g_norm** (Tensor) - :math:`\left \| r \right \|` in the updating formula, has the same type as `w_norm`. | |||
| - **lr** (Tensor) - :math:`l` in the updating formula, the learning rate, float32/float16. | |||
| - **update** (Tensor) -:math:`r + \lambda * w`in the updating formula, float32/float16. | |||
| - **var** (Tensor) - Weights to be updated, the same shape and type as `update`. | |||
| Outputs: | |||
| - **var** (Tensor) - Weights to be updated in place, the same shape and type as `var` in inputs. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| """Initialize LambApplyWeightAssign""" | |||
| def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape): | |||
| validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name) | |||
| return var_shape | |||
| def infer_dtype(self, w_norm_dtype, g_norm_dtype, lr_dtype, update_dtype, var_dtype): | |||
| args = {"var": var_dtype, "update": update_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) | |||
| args = {"w_norm": w_norm_dtype, "g_norm": g_norm_dtype, "lr": lr_dtype} | |||
| validator.check_scalar_or_tensor_types_same(args, [mstype.float16, mstype.float32], self.name, True) | |||
| return var_dtype | |||
| @@ -229,7 +229,7 @@ def test_bert_performance(): | |||
| # assertion occurs while the loss value, overflow state or loss_scale value is wrong | |||
| loss_value = np.array(callback.loss_list) | |||
| expect_loss_value = [10.235566, 10.207392, 10.206976] | |||
| expect_loss_value = [11.325791, 11.285011, 11.284766] | |||
| print("loss value: {}".format(loss_value)) | |||
| assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) | |||
| @@ -239,7 +239,7 @@ def test_bert_performance(): | |||
| assert (overflow == expect_overflow).all() | |||
| loss_scale = np.array(callback.lossscale_list) | |||
| expect_loss_scale = [262144.0, 262144.0, 262144.0] | |||
| expect_loss_scale = [65536.0, 65536.0, 65536.0] | |||
| print("loss scale: {}".format(loss_scale)) | |||
| assert np.allclose(loss_scale, expect_loss_scale, 0, 0) | |||
| @@ -225,8 +225,12 @@ def test_bert_percision(enable_graph_kernel=False): | |||
| loss_value = np.array(callback.loss_list) | |||
| assert np.allclose(loss_value[0], 12.2065868, 0, 0.000001) | |||
| expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466, | |||
| 12.6212320, 12.2229223, 12.4272099] | |||
| if enable_graph_kernel: | |||
| expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466, | |||
| 12.6212320, 12.2229223, 12.4272099] | |||
| else: | |||
| expect_loss_value = [12.2065868, 11.94102, 11.931558, 11.938105, 11.932648, 12.556579, 12.130686, 12.783716, | |||
| 12.360179, 12.578461] | |||
| print("loss value: {}".format(loss_value)) | |||
| assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) | |||