
!8261 GPU support for heterogeneous network

From: @wilfchen
Reviewed-by: @limingqi107, @cristoval
Signed-off-by: @cristoval
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 71a3086fff
5 changed files with 10 additions and 11 deletions
1. mindspore/nn/optim/adam.py (+2, -2)
2. mindspore/nn/optim/ftrl.py (+2, -2)
3. mindspore/nn/optim/lazyadam.py (+2, -2)
4. mindspore/nn/optim/optimizer.py (+2, -3)
5. mindspore/nn/optim/proximal_ada_grad.py (+2, -2)

mindspore/nn/optim/adam.py (+2, -2)

@@ -350,8 +350,8 @@ class Adam(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))

         self._is_device = (value != 'CPU')
         self._target = value
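
With the widened check above (LazyAdam repeats the same change below), an optimizer created on a GPU backend can keep its target on the device. A minimal usage sketch, assuming a MindSpore GPU build; the network and hyperparameters are illustrative only:

    import mindspore.nn as nn
    from mindspore import context

    # Illustrative setup: run under the GPU device target.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    net = nn.Dense(16, 8)
    opt = nn.Adam(net.trainable_params(), learning_rate=0.01)
    opt.target = "GPU"   # accepted after this change; previously only 'CPU' or 'Ascend' passed the check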


mindspore/nn/optim/ftrl.py (+2, -2)

@@ -192,8 +192,8 @@ class FTRL(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))

         if value == 'CPU':
             self.sparse_opt = P.FusedSparseFtrl(self.lr, self.l1, self.l2, self.lr_power, self.use_locking)
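
The target setter is what drives the host/device split for sparse updates: as the hunk above shows, assigning 'CPU' swaps in the fused host kernel (P.FusedSparseFtrl). A hedged sketch, assuming a GPU build and an embedding-style network whose gradients are sparse; the layer and hyperparameters are illustrative:

    import mindspore.nn as nn
    from mindspore import context

    context.set_context(device_target="GPU")

    # EmbeddingLookup with sparse=True yields sparse gradients (illustrative configuration).
    net = nn.EmbeddingLookup(vocab_size=1000, embedding_size=64, sparse=True)
    opt = nn.FTRL(net.trainable_params(), learning_rate=0.1)
    opt.target = "CPU"   # parameters updated on host via the fused sparse FTRL kernel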


mindspore/nn/optim/lazyadam.py (+2, -2)

@@ -257,8 +257,8 @@ class LazyAdam(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))

         self._is_device = (value != 'CPU')
         self._target = value

mindspore/nn/optim/optimizer.py (+2, -3)

@@ -105,7 +105,7 @@ class Optimizer(Cell):
         weight_decay = self._preprocess_weight_decay(weight_decay)

         self._unique = True
-        self._target = 'Ascend'
+        self._target = context.get_context("device_target")
         self.dynamic_lr = False
         self.assignadd = None
         self.global_step = None
@@ -194,8 +194,7 @@
     @property
     def target(self):
         """The method is used to determine whether the parameter is updated on host or device. The input type is str
-        and can only be 'CPU' and 'Ascend'. In GPU environment, users can only configure value as 'CPU'.
-        The method is read-only."""
+        and can only be 'CPU', 'Ascend' or 'GPU'."""
         return self._target

     @target.setter
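
Because the base class now reads its default from the global context instead of hard-coding 'Ascend', an optimizer constructed under a GPU context starts with target == 'GPU'. A small sketch of the expected default behaviour, assuming a GPU build (layer and values are illustrative):

    import mindspore.nn as nn
    from mindspore import context

    context.set_context(device_target="GPU")

    opt = nn.LazyAdam(nn.Dense(4, 2).trainable_params(), learning_rate=0.1)
    print(opt.target)   # expected: 'GPU' (before this change the default was always 'Ascend')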


mindspore/nn/optim/proximal_ada_grad.py (+2, -2)

@@ -161,8 +161,8 @@ class ProximalAdagrad(Optimizer):
         if not isinstance(value, str):
             raise TypeError("The value must be str type, but got value type is {}".format(type(value)))

-        if value not in ('CPU', 'Ascend'):
-            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))
+        if value not in ('CPU', 'Ascend', 'GPU'):
+            raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))

         if value == 'CPU':
             self.sparse_opt = P.FusedSparseProximalAdagrad(self.use_locking).add_prim_attr("primitive_target", "CPU")
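
The 'CPU' branch above pins the fused sparse primitive to the host with add_prim_attr("primitive_target", "CPU"); this is the mechanism behind the heterogeneous network in the PR title, since that single operator runs on CPU while the rest of the graph stays on the device. A sketch of the pinning call in isolation (illustrative, outside the optimizer):

    from mindspore.ops import operations as P

    # add_prim_attr returns the primitive itself, so the call can be chained;
    # the "primitive_target" attribute tells the executor where to place this op.
    sparse_opt = P.FusedSparseProximalAdagrad(use_locking=False).add_prim_attr("primitive_target", "CPU")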

