@@ -277,8 +277,8 @@ class Lamb(Optimizer):
         self.global_step = Parameter(initializer(0, [1]), name='global_step')
         self.assignadd = P.AssignAdd()
         self.hyper_map = C.HyperMap()
-        self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \
-            context.get_context("device_target") == "Ascend"
+        self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \
+            context.get_context("enable_graph_kernel")

     def construct(self, gradients):
         lr = self.get_lr()
@@ -56,7 +56,7 @@ class _OpSelector:
     def __call__(self, *args, **kwargs):
         _op_type = _OpSelector.DEFAULT_OP_TYPE
-        if context.get_context("enable_graph_kernel"):
+        if context.get_context("device_target") in ['Ascend', 'GPU'] and context.get_context("enable_graph_kernel"):
             if _OpSelector.KW_STR in kwargs:
                 _op_type = kwargs.get(_OpSelector.KW_STR)
                 kwargs.pop(_OpSelector.KW_STR, None)