|
|
|
@@ -2111,7 +2111,7 @@ class RNNTLoss(PrimitiveWithInfer): |
|
|
|
return (acts_type, acts_type) |
|
|
|
|
|
|
|
|
|
|
|
class SGD(PrimitiveWithInfer): |
|
|
|
class SGD(PrimitiveWithCheck): |
|
|
|
""" |
|
|
|
Computes stochastic gradient descent (optionally with momentum). |
|
|
|
|
|
|
|
@@ -2158,7 +2158,7 @@ class SGD(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'], |
|
|
|
outputs=['output']) |
|
|
|
|
|
|
|
def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, |
|
|
|
def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape, |
|
|
|
accum_shape, momentum_shape, stat_shape): |
|
|
|
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name) |
|
|
|
validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name) |
|
|
|
@@ -2167,15 +2167,13 @@ class SGD(PrimitiveWithInfer): |
|
|
|
validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name) |
|
|
|
validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name) |
|
|
|
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) |
|
|
|
return parameters_shape |
|
|
|
|
|
|
|
def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, |
|
|
|
def check_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, |
|
|
|
accum_dtype, momentum_dtype, stat_dtype): |
|
|
|
tuple(map(partial(validator.check_tensor_dtype_valid, |
|
|
|
valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name), |
|
|
|
("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"), |
|
|
|
(parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype))) |
|
|
|
return parameters_dtype |
|
|
|
|
|
|
|
|
|
|
|
class ApplyRMSProp(PrimitiveWithInfer): |
|
|
|
|