|
|
|
@@ -216,6 +216,7 @@ class BertTrainWithLossScaleCell(nn.Cell): |
|
|
|
def __init__(self, network, optimizer, scale_update_cell=None): |
|
|
|
super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
self.weights = optimizer.parameters |
|
|
|
self.optimizer = optimizer |
|
|
|
self.grad = C.GradOperation(get_by_list=True, |
|
|
|
@@ -306,6 +307,7 @@ class BertTrainCell(nn.Cell): |
|
|
|
def __init__(self, network, optimizer, sens=1.0): |
|
|
|
super(BertTrainCell, self).__init__(auto_prefix=False) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
self.weights = optimizer.parameters |
|
|
|
self.optimizer = optimizer |
|
|
|
self.sens = sens |
|
|
|
@@ -470,6 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell): |
|
|
|
def __init__(self, network, optimizer, scale_update_cell=None): |
|
|
|
super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
self.weights = optimizer.parameters |
|
|
|
self.optimizer = optimizer |
|
|
|
self.grad = C.GradOperation(get_by_list=True, |
|
|
|
@@ -556,6 +559,7 @@ class BertEvaluationCell(nn.Cell): |
|
|
|
def __init__(self, network, optimizer, sens=1.0): |
|
|
|
super(BertEvaluationCell, self).__init__(auto_prefix=False) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
self.weights = optimizer.parameters |
|
|
|
self.optimizer = optimizer |
|
|
|
self.sens = sens |
|
|
|
|