From: @tom__chen
Reviewed-by: @mikef, @robingrosman
Signed-off-by: @robingrosman
Tag: tags/v1.1.0
@@ -220,6 +220,8 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
@@ -462,5 +462,13 @@ AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, c
   }
   return std::make_shared<AbstractTuple>(elements);
 }
+
+AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 6);
+  AbstractBasePtrList elements;
+  elements.push_back(args_spec_list[0]->Clone()->Broaden());
+  return std::make_shared<AbstractTuple>(elements);
+}
 }  // namespace abstract
 }  // namespace mindspore
@@ -97,6 +97,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
     {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
     {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
+    {prim::kPrimSGD, {InferImplSGD, true}},
     // Others
     {prim::kPrimIdentity, {InferImplIdentity, true}},
     // Set impl to null as it will use PartialEvaluator;
@@ -176,6 +176,7 @@ inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("Sp
 inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
 inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
 inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
+inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -2111,7 +2111,7 @@ class RNNTLoss(PrimitiveWithInfer):
         return (acts_type, acts_type)

-class SGD(PrimitiveWithInfer):
+class SGD(PrimitiveWithCheck):
     """
     Computes stochastic gradient descent (optionally with momentum).
@@ -2158,7 +2158,7 @@ class SGD(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
                                 outputs=['output'])

-    def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
+    def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
                     accum_shape, momentum_shape, stat_shape):
         validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
         validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
@@ -2167,15 +2167,13 @@ class SGD(PrimitiveWithInfer):
         validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
         validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
         validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
-        return parameters_shape

-    def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
+    def check_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
                     accum_dtype, momentum_dtype, stat_dtype):
         tuple(map(partial(validator.check_tensor_dtype_valid,
                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                   ("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"),
                   (parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype)))
-        return parameters_dtype

 class ApplyRMSProp(PrimitiveWithInfer):
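Not part of the diff, but a quick way to sanity-check the change end to end: with shape and dtype inference moved into the C++ InferImplSGD, the Python check_shape/check_dtype methods only validate their inputs and no longer return anything, and the op's compile-time output is a one-element tuple broadened from the parameters abstract. A minimal sketch, assuming MindSpore v1.1-style import paths, the SGD keyword arguments shown in its docs (dampening, weight_decay, nesterov), and illustrative input values:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

# Six inputs, matching init_prim_io_names above; all values are illustrative.
sgd = P.SGD(dampening=0.0, weight_decay=0.0, nesterov=False)
parameters = Tensor(np.array([2.0, -0.5, 1.7, 4.0]), mindspore.float32)
gradient = Tensor(np.array([1.0, -1.0, 0.5, 2.0]), mindspore.float32)
learning_rate = Tensor(0.01, mindspore.float32)
accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mindspore.float32)
momentum = Tensor(0.1, mindspore.float32)
stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mindspore.float32)

# Per InferImplSGD, the result should be a one-element tuple whose tensor
# keeps the shape and dtype of `parameters`; the values come from the kernel.
output = sgd(parameters, gradient, learning_rate, accum, momentum, stat)
print(output)

If inputs are malformed (wrong count, mismatched gradient/stat shapes, or a dtype outside float16/float32), the check_* methods still raise as before; the shapes and dtypes themselves now come from the C++ side.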