|
|
|
@@ -481,7 +481,7 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer): |
|
|
|
|
|
|
|
:math:`m` represents the 1st moment vector, :math:`v` represents the 2nd moment vector, :math:`g` represents |
|
|
|
`gradient`, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, |
|
|
|
:math:`lr` represents `learning_rate`, :math:`w` represents `var`, :math:`decay` represents `weight_decay`, |
|
|
|
:math:`\epsilon` represents `epsilon`. |
|
|
|
|
|
|
|
Args: |
|
|
|
@@ -547,12 +547,13 @@ class FusedCastAdamWeightDecay(PrimitiveWithInfer): |
|
|
|
return var_shape, m_shape, v_shape |
|
|
|
|
|
|
|
def infer_dtype(self, var_dtype, m_dtype, v_dtype, lr_dtype, beta1_dtype, beta2_dtype,
                epsilon_dtype, decay_dtype, grad_dtype):
    """Check the dtypes of all inputs and return the dtypes of the outputs.

    The checks are delegated to `validator`; each call receives `self.name`,
    presumably so that failures are reported against this primitive.

    Args:
        var_dtype: dtype of `var`; checked against float16 and float32.
        m_dtype: dtype of `m`; checked together with `v_dtype` against
            `mstype.number_type`.
        v_dtype: dtype of `v`; checked together with `m_dtype`.
        lr_dtype: dtype of `lr`; checked against float32.
        beta1_dtype: dtype of `beta1`; checked against float32.
        beta2_dtype: dtype of `beta2`; checked against float32.
        epsilon_dtype: dtype of `epsilon`; checked against float32.
        decay_dtype: dtype of `decay`; checked against float32.
        grad_dtype: dtype of `grad`; checked against float16 only.

    Returns:
        The tuple `(var_dtype, m_dtype, v_dtype)`, passed through unchanged.
    """
    # The two moment inputs are validated as a pair.
    args = {"m": m_dtype, "v": v_dtype}
    validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
    validator.check_scalar_or_tensor_types_same({"var": var_dtype}, [mstype.float16, mstype.float32], self.name)
    validator.check_scalar_or_tensor_types_same({"grad": grad_dtype}, [mstype.float16], self.name)

    # Hyper-parameters are validated as one group against float32.
    # NOTE(review): the meaning of the trailing positional `True` is defined by
    # the validator helper, which is not visible here - confirm before changing.
    args = {"lr": lr_dtype, "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype,
            "decay": decay_dtype}
    validator.check_scalar_or_tensor_types_same(args, [mstype.float32], self.name, True)
    return var_dtype, m_dtype, v_dtype