
remove ControlDepend from ternarybert

pull/13877/head
huangbingjian, 4 years ago
commit 4ef7251e7f
2 changed files with 15 additions and 27 deletions:

  1. mindspore/ops/operations/other_ops.py (+3, -3)
  2. model_zoo/research/nlp/ternarybert/src/cell_wrapper.py (+12, -24)

mindspore/ops/operations/other_ops.py (+3, -3)

@@ -423,9 +423,9 @@ class Depend(Primitive):
     In order to ensure that operator A is executed before operator B, it is recommended to
     insert the Depend operator between operators A and B. The usage method is as follows::
 
-        out_a = A(in_a)
-        in_b = Depend(in_b, out_a)
-        out_b = B(in_b)
+        a = A(x)                --->        a = A(x)
+        b = B(y)                --->        y = Depend(y, a)
+                                --->        b = B(y)
 
     Inputs:
         - **value** (Tensor) - the real value to return for depend operator.
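For reference, a minimal self-contained sketch of the ordering pattern from the updated docstring, written against the public MindSpore ops API. The Net class, its parameter, and the choice of Assign as the side-effecting operator A are illustrative choices, not part of this patch:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore import Tensor, Parameter

    class Net(nn.Cell):
        """Illustrative only: run the side-effecting Assign (A) before the read in B."""
        def __init__(self):
            super(Net, self).__init__()
            self.assign = P.Assign()    # operator A: pure side effect, its result is otherwise unused
            self.depend = P.Depend()
            self.param = Parameter(Tensor(np.zeros((2,), np.float32)), name="param")

        def construct(self, x, y):
            a = self.assign(self.param, x)   # a = A(x)
            y = self.depend(y, a)            # y = Depend(y, a): y now carries a dependency on A
            return y + self.param            # b = B(y): sees the assigned value

    net = Net()
    print(net(Tensor(np.ones((2,), np.float32)), Tensor(np.ones((2,), np.float32))))  # [2. 2.]

Making y depend on a keeps the Assign from being pruned or reordered past the read of self.param, which is exactly the guarantee the docstring describes.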


model_zoo/research/nlp/ternarybert/src/cell_wrapper.py (+12, -24)

@@ -377,21 +377,16 @@ class BertTrainWithLossScaleCell(nn.Cell):
                   sens=None):
         """Defines the computation performed."""
         weights = self.weights
-        saved = ()
         for i in range(self.length):
-            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
+            F.assign(self.saved_params[i], weights[i])
 
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            quant_embedding = F.depend(quant_embedding, saved)
-            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
-            input_ids = F.depend(input_ids, assign_embedding)
+            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
 
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            quant_weight = F.depend(quant_weight, saved)
-            assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
-            input_ids = F.depend(input_ids, assign_weight)
+            F.assign(weights[self.quant_weight_list[i]], quant_weight)
 
         if sens is None:
             scaling_sens = self.loss_scale
@@ -411,10 +406,10 @@
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
-        restore = ()
         for i in range(self.length):
-            weights[i] = F.depend(weights[i], grads)
-            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
+            param = F.depend(self.saved_params[i], grads)
+            F.assign(weights[i], param)
 
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
@@ -431,7 +426,6 @@
             succ = False
         else:
             succ = self.optimizer(grads)
-        succ = F.depend(succ, restore)
         return succ
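The restore loop above is the one place that still needs an explicit dependency: the saved full-precision parameters must not be written back into weights until grads has actually been computed from the temporarily quantized weights. A stand-alone sketch of just that step, with the cell boilerplate stripped away (the helper name and free-function form are illustrative, not the model_zoo code):

    from mindspore.ops import functional as F

    def restore_full_precision(weights, saved_params, grads):
        """Illustrative helper: write the saved full-precision values back into
        the network weights, but only after `grads` exists."""
        for i in range(len(saved_params)):
            param = F.depend(saved_params[i], grads)  # the read is ordered after the backward pass
            F.assign(weights[i], param)               # then weights[i] is restored in place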


@@ -490,21 +484,16 @@ class BertTrainCell(nn.Cell):
                   label_ids):
         """Defines the computation performed."""
         weights = self.weights
-        saved = ()
         for i in range(self.length):
-            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
+            F.assign(self.saved_params[i], weights[i])
 
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            quant_embedding = F.depend(quant_embedding, saved)
-            assign_embedding = F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
-            input_ids = F.depend(input_ids, assign_embedding)
+            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
 
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            quant_weight = F.depend(quant_weight, saved)
-            assign_weight = F.assign(weights[self.quant_weight_list[i]], quant_weight)
-            input_ids = F.depend(input_ids, assign_weight)
+            F.assign(weights[self.quant_weight_list[i]], quant_weight)
 
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
@@ -515,11 +504,10 @@
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
-        restore = ()
         for i in range(self.length):
-            weights[i] = F.depend(weights[i], grads)
-            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
+            param = F.depend(self.saved_params[i], grads)
+            F.assign(weights[i], param)
 
         succ = self.optimizer(grads)
-        succ = F.depend(succ, restore)
         return succ
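Both cells now lean on MindSpore's automatic side-effect handling (the auto-monad mechanism that replaced ControlDepend) to keep bare F.assign calls in program order, which is why the saved/restore tuples and the trailing succ = F.depend(succ, restore) could be dropped. Below is a self-contained toy version of the save / quantize / forward / restore cycle; the class, the parameter names, and the multiply-by-two "forward pass" are invented for illustration and assume a MindSpore build with auto-monad ordering:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter, context
    from mindspore.ops import functional as F

    context.set_context(mode=context.GRAPH_MODE)

    class SaveQuantRestore(nn.Cell):
        """Snapshot -> overwrite -> restore, ordered only by automatic
        side-effect handling (no Depend between the assigns)."""
        def __init__(self):
            super(SaveQuantRestore, self).__init__()
            self.weight = Parameter(Tensor(np.array([1.0, 2.0], np.float32)), name="weight")
            self.saved = Parameter(Tensor(np.zeros(2, np.float32)), name="saved")

        def construct(self, quant_value):
            F.assign(self.saved, self.weight)      # save the full-precision value
            F.assign(self.weight, quant_value)     # pretend-quantize in place
            out = self.weight * 2                  # "forward pass" on the quantized weight
            restored = F.depend(self.saved, out)   # restore only after the result exists
            F.assign(self.weight, restored)
            return out

    net = SaveQuantRestore()
    y = net(Tensor(np.array([0.5, 0.5], np.float32)))
    print(y)                      # expected: [1. 1.] (computed with the quantized weight)
    print(net.weight.asnumpy())   # expected: [1. 2.] (full precision restored)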
