@@ -3760,12 +3760,12 @@ class ApplyAdagradV2(PrimitiveWithInfer):
         update_slots (bool): If `True`, `accum` will be updated. Default: True.
 
     Inputs:
-        - **var** (Parameter) - Variable to be updated. With float32 or float16 data type.
+        - **var** (Parameter) - Variable to be updated. With float32 data type.
         - **accum** (Parameter) - Accum to be updated. The shape and dtype should be the same as `var`.
-          With float32 or float16 data type.
-        - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 or float16 data type.
+          With float32 data type.
+        - **lr** (Union[Number, Tensor]) - The learning rate value, should be scalar. With float32 data type.
         - **grad** (Tensor) - A tensor for gradient. The shape and dtype should be the same as `var`.
-          With float32 or float16 data type.
+          With float32 data type.
 
     Outputs:
         Tuple of 2 Tensor, the updated parameters.
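For reference, a minimal usage sketch of the operator with the float32 inputs now required by this patch. The `epsilon` constructor argument, the shapes, and the `P.ApplyAdagradV2` import path are assumptions not shown in this diff; ApplyAdagradV2 follows the Adagrad scheme of accumulating squared gradients and scaling each step by the inverse square root of the accumulator.

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class AdagradV2Net(nn.Cell):
    """Wraps ApplyAdagradV2 so `var` and `accum` live as float32 Parameters."""
    def __init__(self):
        super(AdagradV2Net, self).__init__()
        self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=1e-6)  # epsilon value assumed, not part of this diff
        self.var = Parameter(Tensor(np.ones((2, 2), np.float32)), name="var")
        self.accum = Parameter(Tensor(np.ones((2, 2), np.float32)), name="accum")

    def construct(self, lr, grad):
        # Input order matches the docstring above: var, accum, lr, grad.
        return self.apply_adagrad_v2(self.var, self.accum, lr, grad)

net = AdagradV2Net()
lr = Tensor(0.001, mstype.float32)          # scalar learning rate; float32 only after this change
grad = Tensor(np.ones((2, 2), np.float32))  # same shape and dtype as `var`
var_out, accum_out = net(lr, grad)          # the returned tuple holds the updated (var, accum)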
@@ -3817,9 +3817,8 @@ class ApplyAdagradV2(PrimitiveWithInfer):
     def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, grad_dtype):
         args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
-        valid_types = [mstype.float16, mstype.float32]
-        validator.check_tensor_type_same(args, valid_types, self.name)
-        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, valid_types, self.name)
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_scalar_or_tensor_type_same({'lr': lr_dtype}, [mstype.float32], self.name)
         return var_dtype, accum_dtype
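The net effect of the tightened check: float16 is rejected for `var`, `accum`, `grad`, and `lr`, where it previously passed. A rough sketch of exercising infer_dtype directly under that assumption (the `mstype.tensor_type` wrappers and the `epsilon` constructor argument are not shown in this diff):

from mindspore.ops import operations as P
import mindspore.common.dtype as mstype

op = P.ApplyAdagradV2(epsilon=1e-6)            # epsilon assumed, as in the sketch above
f32 = mstype.tensor_type(mstype.float32)

# float32 everywhere passes both validator checks and returns (var_dtype, accum_dtype):
out_dtypes = op.infer_dtype(f32, f32, mstype.float32, f32)

# A float16 tensor type would now be rejected by check_tensor_type_same,
# whereas before this patch float16 was listed among the valid types:
# op.infer_dtype(mstype.tensor_type(mstype.float16), f32, mstype.float32, f32)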