
Fixing issues including the class Slice example that cannot run; adding an example for class SigmoidCrossEntropyWithLogits, etc.

tags/v0.2.0-alpha
zhangz0911gm committed 5 years ago
commit 4ba6f7884d
4 changed files with 32 additions and 4 deletions

  1. mindspore/nn/optim/optimizer.py (+14, -0)
  2. mindspore/nn/optim/sgd.py (+11, -1)
  3. mindspore/ops/operations/array_ops.py (+4, -1)
  4. mindspore/ops/operations/nn_ops.py (+3, -2)

mindspore/nn/optim/optimizer.py (+14, -0)

@@ -81,8 +81,22 @@ class Optimizer(Cell):
         else:
             raise TypeError("Learning rate should be float, Tensor or Iterable.")
 
+        if isinstance(weight_decay, int):
+            weight_decay = float(weight_decay)
+
+        if not isinstance(weight_decay, float):
+            raise TypeError("weight_decay should be a float number!")
+
+        if isinstance(loss_scale, int):
+            loss_scale = float(loss_scale)
+
+        if not isinstance(loss_scale, float):
+            raise TypeError("loss_scale should be a float number!")
+
+        if loss_scale <= 0.0:
+            raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale))
         self.loss_scale = loss_scale
 
         if weight_decay < 0.0:
             raise ValueError("Weight decay should be equal or greater than 0, but got {}".format(weight_decay))



mindspore/nn/optim/sgd.py (+11, -1)

@@ -61,7 +61,8 @@ class SGD(Optimizer):
         dampening (float): A floating point value of dampening for momentum. Default: 0.
         weight_decay (float): Weight decay (L2 penalty). Default: 0.
         nesterov (bool): Enables the Nesterov momentum. Default: False.
-        loss_scale (float): A floating point value for the loss scale. Default: 1.0.
+        loss_scale (float): A floating point value for the loss scale, which should be larger
+            than 0.0. Default: 1.0.
 
     Inputs:
         - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
@@ -83,9 +84,18 @@ class SGD(Optimizer):
 
         super(SGD, self).__init__(learning_rate, params, weight_decay, loss_scale)
 
+        if not isinstance(momentum, float):
+            raise TypeError("momentum should be float number!")
+
         if isinstance(momentum, float) and momentum < 0.0:
             raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
 
+        if not isinstance(dampening, float):
+            raise TypeError("dampening should be float number")
+
+        if isinstance(dampening, int):
+            dampening = float(dampening)
+
         if dampening < 0.0:
             raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
         self.dampening = dampening
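A hedged usage sketch of how the constructor behaves after this change (assumes a working MindSpore install; the tiny Net class is ours, only there to provide trainable parameters):

    from mindspore import nn

    class Net(nn.Cell):  # minimal illustrative network, not from the commit
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(3, 2)
        def construct(self, x):
            return self.dense(x)

    opt = nn.SGD(params=Net().trainable_params(), learning_rate=0.1,
                 momentum=0.9, dampening=0.0, weight_decay=0.0,
                 nesterov=False, loss_scale=1.0)
    # momentum=1 (an int)  -> TypeError: momentum should be float number!
    # momentum=-0.1        -> ValueError: momentum should be at least 0.0
    # loss_scale=0         -> ValueError from the Optimizer base class above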


mindspore/ops/operations/array_ops.py (+4, -1)

@@ -1008,6 +1008,7 @@ class Argmax(PrimitiveWithInfer):
 
     def infer_dtype(self, x_dtype):
         validator.check_subclass("input_x", x_dtype, mstype.tensor)
+        validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16])
         return mstype.tensor_type(self.output_type)
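The added line restricts Argmax inputs to float16/float32 at type-inference time. A brief sketch of the effect (assumes a working MindSpore install; the values are illustrative):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    argmax = P.Argmax()
    x = Tensor(np.array([1.0, 20.0, 5.0]).astype(np.float32))  # float32: accepted
    print(argmax(x))  # index of the maximum -> 1
    # An int32 input tensor would now be rejected by check_typename during
    # type inference rather than failing later in the kernel.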

@@ -1500,7 +1501,9 @@ class Slice(PrimitiveWithInfer):
         Tensor.
 
     Examples:
-        >>> data = Tensor(np.array([3,2,3]).astype(np.int32))
+        >>> data = Tensor(np.array([[[1, 1, 1], [2, 2, 2]],
+        >>>                         [[3, 3, 3], [4, 4, 4]],
+        >>>                         [[5, 5, 5], [6, 6, 6]]]).astype(np.int32))
         >>> type = P.Slice()(data, (1, 0, 0), (1, 1, 3))
     """



mindspore/ops/operations/nn_ops.py (+3, -2)

@@ -1436,9 +1436,9 @@ class SGD(PrimitiveWithInfer):
         nesterov (bool): Enable Nesterov momentum. Default: False.
 
     Inputs:
-        - **parameters** (Tensor) - Parameters to be updated.
+        - **parameters** (Tensor) - Parameters to be updated. Their data type can be list or tuple.
         - **gradient** (Tensor) - Gradients.
-        - **learning_rate** (Tensor) - Learning rate. e.g. Tensor(0.1, mindspore.float32).
+        - **learning_rate** (Tensor) - Learning rate. Must be float value. e.g. Tensor(0.1, mindspore.float32).
         - **accum** (Tensor) - Accum(velocity) to be updated.
         - **momentum** (Tensor) - Momentum. e.g. Tensor(0.1, mindspore.float32).
         - **stat** (Tensor) - States to be updated with the same shape as gradient.
@@ -1449,6 +1449,7 @@ class SGD(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False):
+        validator.check_type("nesterov", nesterov, [bool])
         self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'],
                                 outputs=['output'])
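For completeness, a hedged sketch of invoking the raw P.SGD primitive with the input layout documented above (assumes a working MindSpore install; shapes and values are illustrative, not from the commit):

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    sgd = P.SGD(dampening=0.0, weight_decay=0.0, nesterov=False)
    # P.SGD(nesterov=1) would now fail the new check_type: nesterov must be a bool.

    parameters = Tensor(np.array([2.0, -0.5]).astype(np.float32))
    gradient = Tensor(np.array([0.1, 0.2]).astype(np.float32))
    learning_rate = Tensor(0.1, mindspore.float32)
    accum = Tensor(np.zeros(2).astype(np.float32))   # velocity buffer
    momentum = Tensor(0.9, mindspore.float32)
    stat = Tensor(np.ones(2).astype(np.float32))     # same shape as gradient
    out = sgd(parameters, gradient, learning_rate, accum, momentum, stat)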


