|
|
|
@@ -377,21 +377,16 @@ class BertTrainWithLossScaleCell(nn.Cell): |
|
|
|
sens=None): |
|
|
|
"""Defines the computation performed.""" |
|
|
|
weights = self.weights |
|
|
|
saved = () |
|
|
|
for i in range(self.length): |
|
|
|
saved = saved + (F.assign(self.saved_params[i], weights[i]),) |
|
|
|
F.assign(self.saved_params[i], weights[i]) |
|
|
|
|
|
|
|
for i in range(self.quant_embedding_list_length): |
|
|
|
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]]) |
|
|
|
quant_embedding = F.depend(quant_embedding, saved) |
|
|
|
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding) |
|
|
|
input_ids = F.depend(input_ids, assign_embedding) |
|
|
|
F.assign(weights[self.quant_embedding_list[i]], quant_embedding) |
|
|
|
|
|
|
|
for i in range(self.quant_weight_list_length): |
|
|
|
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]]) |
|
|
|
quant_weight = F.depend(quant_weight, saved) |
|
|
|
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight) |
|
|
|
input_ids = F.depend(input_ids, assign_weight) |
|
|
|
F.assign(weights[self.quant_weight_list[i]], quant_weight) |
|
|
|
|
|
|
|
if sens is None: |
|
|
|
scaling_sens = self.loss_scale |
|
|
|
@@ -411,10 +406,10 @@ class BertTrainWithLossScaleCell(nn.Cell): |
|
|
|
grads = self.grad_reducer(grads) |
|
|
|
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) |
|
|
|
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads) |
|
|
|
restore = () |
|
|
|
|
|
|
|
for i in range(self.length): |
|
|
|
weights[i] = F.depend(weights[i], grads) |
|
|
|
restore = restore + (F.assign(weights[i], self.saved_params[i]),) |
|
|
|
param = F.depend(self.saved_params[i], grads) |
|
|
|
F.assign(weights[i], param) |
|
|
|
|
|
|
|
self.get_status(init) |
|
|
|
flag_sum = self.reduce_sum(init, (0,)) |
|
|
|
@@ -431,7 +426,6 @@ class BertTrainWithLossScaleCell(nn.Cell): |
|
|
|
succ = False |
|
|
|
else: |
|
|
|
succ = self.optimizer(grads) |
|
|
|
succ = F.depend(succ, restore) |
|
|
|
return succ |
|
|
|
|
|
|
|
|
|
|
|
@@ -490,21 +484,16 @@ class BertTrainCell(nn.Cell): |
|
|
|
label_ids): |
|
|
|
"""Defines the computation performed.""" |
|
|
|
weights = self.weights |
|
|
|
saved = () |
|
|
|
for i in range(self.length): |
|
|
|
saved = saved + (F.assign(self.saved_params[i], weights[i]),) |
|
|
|
F.assign(self.saved_params[i], weights[i]) |
|
|
|
|
|
|
|
for i in range(self.quant_embedding_list_length): |
|
|
|
quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]]) |
|
|
|
quant_embedding = F.depend(quant_embedding, saved) |
|
|
|
assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding) |
|
|
|
input_ids = F.depend(input_ids, assign_embedding) |
|
|
|
F.assign(weights[self.quant_embedding_list[i]], quant_embedding) |
|
|
|
|
|
|
|
for i in range(self.quant_weight_list_length): |
|
|
|
quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]]) |
|
|
|
quant_weight = F.depend(quant_weight, saved) |
|
|
|
assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight) |
|
|
|
input_ids = F.depend(input_ids, assign_weight) |
|
|
|
F.assign(weights[self.quant_weight_list[i]], quant_weight) |
|
|
|
|
|
|
|
grads = self.grad(self.network, weights)(input_ids, |
|
|
|
input_mask, |
|
|
|
@@ -515,11 +504,10 @@ class BertTrainCell(nn.Cell): |
|
|
|
# apply grad reducer on grads |
|
|
|
grads = self.grad_reducer(grads) |
|
|
|
grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads) |
|
|
|
restore = () |
|
|
|
|
|
|
|
for i in range(self.length): |
|
|
|
weights[i] = F.depend(weights[i], grads) |
|
|
|
restore = restore + (F.assign(weights[i], self.saved_params[i]),) |
|
|
|
param = F.depend(self.saved_params[i], grads) |
|
|
|
F.assign(weights[i], param) |
|
|
|
|
|
|
|
succ = self.optimizer(grads) |
|
|
|
succ = F.depend(succ, restore) |
|
|
|
return succ |