Browse Source

fix warning for enable_graph_kernel context in CPU device

tags/v1.2.0-rc1
dayschan 5 years ago
parent
commit
a661c3dd40
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      mindspore/nn/optim/lamb.py
  2. +1
    -1
      mindspore/ops/op_selector.py

+ 2
- 2
mindspore/nn/optim/lamb.py View File

@@ -277,8 +277,8 @@ class Lamb(Optimizer):
self.global_step = Parameter(initializer(0, [1]), name='global_step')
self.assignadd = P.AssignAdd()
self.hyper_map = C.HyperMap()
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
context.get_context("device_target") == "Ascend"
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
context.get_context("enable_graph_kernel")

def construct(self, gradients):
lr = self.get_lr()


+ 1
- 1
mindspore/ops/op_selector.py View File

@@ -56,7 +56,7 @@ class _OpSelector:

def __call__(self, *args, **kwargs):
_op_type = _OpSelector.DEFAULT_OP_TYPE
if context.get_context("enable_graph_kernel"):
if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
if _OpSelector.KW_STR in kwargs:
_op_type = kwargs.get(_OpSelector.KW_STR)
kwargs.pop(_OpSelector.KW_STR, None)


Loading…
Cancel
Save