
!8233 [MS] SGD op extending PrimitiveWithCheck

From: @tom__chen
Reviewed-by: @mikef, @robingrosman
Signed-off-by: @robingrosman
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit c324fec6d4
5 changed files with 15 additions and 5 deletions
  1. mindspore/core/abstract/infer_functions.h (+2, -0)
  2. mindspore/core/abstract/prim_nn.cc (+8, -0)
  3. mindspore/core/abstract/primitive_infer_map.cc (+1, -0)
  4. mindspore/core/base/core_ops.h (+1, -0)
  5. mindspore/ops/operations/nn_ops.py (+3, -5)
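
The overall change moves SGD's output shape/dtype inference out of Python and into the C++ InferImplSGD below, leaving the Python operator with validation-only hooks. A minimal sketch of the two patterns, using simplified stand-in classes rather than the real MindSpore base classes:

# Illustrative sketch only -- simplified stand-ins, not MindSpore source.
# A PrimitiveWithInfer-style op must produce the output shape/dtype in Python:
class SgdInferStyle:
    def infer_shape(self, parameters_shape, *other_shapes):
        return parameters_shape           # Python computes the output shape

    def infer_dtype(self, parameters_dtype, *other_dtypes):
        return parameters_dtype           # Python computes the output dtype

# A PrimitiveWithCheck-style op only validates inputs and returns nothing;
# the output abstract comes from the C++ infer function registered in
# primitive_infer_map.cc (here, InferImplSGD).
class SgdCheckStyle:
    def check_shape(self, parameters_shape, *other_shapes):
        if not parameters_shape:
            raise ValueError("parameters rank must be positive")

    def check_dtype(self, parameters_dtype, *other_dtypes):
        pass                              # validation only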

mindspore/core/abstract/infer_functions.h (+2, -0)

@@ -220,6 +220,8 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);


mindspore/core/abstract/prim_nn.cc (+8, -0)

@@ -462,5 +462,13 @@ AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, c
   }
   return std::make_shared<AbstractTuple>(elements);
 }
+
+AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const AbstractBasePtrList &args_spec_list) {
+  CheckArgsSize(primitive->name(), args_spec_list, 6);
+  AbstractBasePtrList elements;
+  elements.push_back(args_spec_list[0]->Clone()->Broaden());
+  return std::make_shared<AbstractTuple>(elements);
+}
 } // namespace abstract
 } // namespace mindspore
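
InferImplSGD only verifies that six inputs are supplied and builds the output abstract from the first one: the parameters abstract is cloned, its value broadened (shape and dtype are kept), and the result is wrapped in a one-element tuple. A rough Python paraphrase of that contract, with hypothetical names that are not MindSpore APIs:

# Hypothetical paraphrase of InferImplSGD's contract, not MindSpore code.
def infer_sgd_output(args_abstract):
    if len(args_abstract) != 6:
        raise ValueError("SGD expects 6 inputs")
    parameters = args_abstract[0]
    # Output: a 1-tuple whose element mirrors the parameters' shape/dtype.
    return ({"shape": parameters["shape"], "dtype": parameters["dtype"]},)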

mindspore/core/abstract/primitive_infer_map.cc (+1, -0)

@@ -97,6 +97,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
     {prim::kPrimSparseApplyFtrl, {InferImplSparseApplyFtrl, true}},
     {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, true}},
+    {prim::kPrimSGD, {InferImplSGD, true}},
     // Others
     {prim::kPrimIdentity, {InferImplIdentity, true}},
     // Set impl to null as it will use PartialEvaluator;


mindspore/core/base/core_ops.h (+1, -0)

@@ -176,6 +176,7 @@ inline const PrimitivePtr kPrimSparseApplyFtrl = std::make_shared<Primitive>("Sp
 inline const PrimitivePtr kPrimSparseApplyProximalAdagrad = std::make_shared<Primitive>("SparseApplyProximalAdagrad");
 inline const PrimitivePtr kPrimFusedAdam = std::make_shared<Primitive>("FusedAdam");
 inline const PrimitivePtr kPrimFusedAdamWeightDecay = std::make_shared<Primitive>("FusedAdamWeightDecay");
+inline const PrimitivePtr kPrimSGD = std::make_shared<Primitive>("SGD");
 
 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");


mindspore/ops/operations/nn_ops.py (+3, -5)

@@ -2111,7 +2111,7 @@ class RNNTLoss(PrimitiveWithInfer):
         return (acts_type, acts_type)
 
 
-class SGD(PrimitiveWithInfer):
+class SGD(PrimitiveWithCheck):
     """
     Computes stochastic gradient descent (optionally with momentum).
 
@@ -2158,7 +2158,7 @@ class SGD(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
                                 outputs=['output'])
 
-    def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
+    def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
                     accum_shape, momentum_shape, stat_shape):
         validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
         validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
@@ -2167,15 +2167,13 @@ class SGD(PrimitiveWithInfer):
         validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
         validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
         validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
-        return parameters_shape
 
-    def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
+    def check_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype,
                     accum_dtype, momentum_dtype, stat_dtype):
         tuple(map(partial(validator.check_tensor_dtype_valid,
                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                   ("parameters", "gradient", "learning_rate", "accum", "momentum", "stat"),
                   (parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype)))
-        return parameters_dtype
 
 
 class ApplyRMSProp(PrimitiveWithInfer):
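
Call-site usage of the operator is unchanged by this refactor; only where the output shape/dtype are derived differs. A minimal usage sketch, assuming the constructor's default hyper-parameters and purely illustrative tensor values:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

# Inputs follow init_prim_io_names above:
# parameters, gradient, learning_rate, accum, momentum, stat.
sgd = P.SGD()
parameters = Tensor(np.array([2.0, -0.5, 1.7, 4.0]), mindspore.float32)
gradient = Tensor(np.array([1.0, -1.0, 0.5, 2.0]), mindspore.float32)
learning_rate = Tensor(0.01, mindspore.float32)
accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mindspore.float32)
momentum = Tensor(0.1, mindspore.float32)
stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mindspore.float32)
output = sgd(parameters, gradient, learning_rate, accum, momentum, stat)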

