|
|
|
@@ -277,8 +277,8 @@ class Lamb(Optimizer): |
|
|
|
self.global_step = Parameter(initializer(0, [1]), name='global_step') |
|
|
|
self.assignadd = P.AssignAdd() |
|
|
|
self.hyper_map = C.HyperMap() |
|
|
|
self.enable_graph_kernel = context.get_context("enable_graph_kernel") and \ |
|
|
|
context.get_context("device_target") == "Ascend" |
|
|
|
self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \ |
|
|
|
context.get_context("enable_graph_kernel") |
|
|
|
|
|
|
|
def construct(self, gradients): |
|
|
|
lr = self.get_lr() |
|
|
|
|