Browse Source

add valid check

tags/v1.1.0
Jiaqi 5 years ago
parent
commit
c4ace78d1c
5 changed files with 12 additions and 1 deletions
  1. +2
    -1
      mindspore/common/parameter.py
  2. +3
    -0
      mindspore/nn/optim/adam.py
  3. +2
    -0
      mindspore/nn/optim/ftrl.py
  4. +3
    -0
      mindspore/nn/optim/lazyadam.py
  5. +2
    -0
      mindspore/nn/optim/proximal_ada_grad.py

+ 2
- 1
mindspore/common/parameter.py View File

@@ -316,7 +316,8 @@ class Parameter(MetaTensor_):

Args:
data (Union[Tensor, MetaTensor, int, float]): new data.
slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False.
slice_shape (bool): If slice the parameter is set to true, the shape is not checked for consistency.
Default: False.

Returns:
Parameter, the parameter after set data.


+ 3
- 0
mindspore/nn/optim/adam.py View File

@@ -336,6 +336,9 @@ class Adam(Optimizer):
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str):
raise ValueError("The value must be str type, but got value type is {}".format(type(value)))

if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))



+ 2
- 0
mindspore/nn/optim/ftrl.py View File

@@ -189,6 +189,8 @@ class FTRL(Optimizer):
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str):
raise ValueError("The value must be str type, but got value type is {}".format(type(value)))

if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))


+ 3
- 0
mindspore/nn/optim/lazyadam.py View File

@@ -254,6 +254,9 @@ class LazyAdam(Optimizer):
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str):
raise ValueError("The value must be str type, but got value type is {}".format(type(value)))

if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))



+ 2
- 0
mindspore/nn/optim/proximal_ada_grad.py View File

@@ -158,6 +158,8 @@ class ProximalAdagrad(Optimizer):
def target(self, value):
"""If the input value is set to "CPU", the parameters will be updated on the host using the Fused
optimizer operation."""
if not isinstance(value, str):
raise ValueError("The value must be str type, but got value type is {}".format(type(value)))
if value not in ('CPU', 'Ascend'):
raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))


Loading…
Cancel
Save