Browse Source

add valid check

tags/v1.1.0
Jiaqi 5 years ago
parent
commit
c4ace78d1c
5 changed files with 12 additions and 1 deletions
  1. +2
    -1
      mindspore/common/parameter.py
  2. +3
    -0
      mindspore/nn/optim/adam.py
  3. +2
    -0
      mindspore/nn/optim/ftrl.py
  4. +3
    -0
      mindspore/nn/optim/lazyadam.py
  5. +2
    -0
      mindspore/nn/optim/proximal_ada_grad.py

+ 2
- 1
mindspore/common/parameter.py View File

@@ -316,7 +316,8 @@ class Parameter(MetaTensor_):


    Args:
        data (Union[Tensor, MetaTensor, int, float]): new data.
-       slice_shape (bool): If slice the Parameter, will not check if shape is match. Default: False.
+       slice_shape (bool): If slice the parameter is set to true, the shape is not checked for consistency.
+           Default: False.


    Returns:
        Parameter, the parameter after set data.


+ 3
- 0
mindspore/nn/optim/adam.py View File

@@ -336,6 +336,9 @@ class Adam(Optimizer):
    def target(self, value):
        """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation."""
+       if not isinstance(value, str):
+           raise ValueError("The value must be str type, but got value type is {}".format(type(value)))
+
        if value not in ('CPU', 'Ascend'):
            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))




+ 2
- 0
mindspore/nn/optim/ftrl.py View File

@@ -189,6 +189,8 @@ class FTRL(Optimizer):
    def target(self, value):
        """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation."""
+       if not isinstance(value, str):
+           raise ValueError("The value must be str type, but got value type is {}".format(type(value)))
+
        if value not in ('CPU', 'Ascend'):
            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))


+ 3
- 0
mindspore/nn/optim/lazyadam.py View File

@@ -254,6 +254,9 @@ class LazyAdam(Optimizer):
    def target(self, value):
        """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation."""
+       if not isinstance(value, str):
+           raise ValueError("The value must be str type, but got value type is {}".format(type(value)))
+
        if value not in ('CPU', 'Ascend'):
            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))




+ 2
- 0
mindspore/nn/optim/proximal_ada_grad.py View File

@@ -158,6 +158,8 @@ class ProximalAdagrad(Optimizer):
    def target(self, value):
        """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation."""
+       if not isinstance(value, str):
+           raise ValueError("The value must be str type, but got value type is {}".format(type(value)))
        if value not in ('CPU', 'Ascend'):
            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))


Loading…
Cancel
Save