|
|
|
@@ -107,7 +107,7 @@ class TrainStepWrapForAdamDynamicLr(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
class TempC2Wrap(nn.Cell): |
|
|
|
def __init__(self, op, c1=None, c2=None, ): |
|
|
|
def __init__(self, op, c1=None, c2=None,): |
|
|
|
super(TempC2Wrap, self).__init__() |
|
|
|
self.op = op |
|
|
|
self.c1 = c1 |
|
|
|
@@ -387,7 +387,7 @@ test_case_cell_ops = [ |
|
|
|
'block': set_train(nn.Dense(in_channels=768, |
|
|
|
out_channels=3072, |
|
|
|
activation='gelu', |
|
|
|
weight_init=TruncatedNormal(0.02), )), |
|
|
|
weight_init=TruncatedNormal(0.02),)), |
|
|
|
'desc_inputs': [[3, 768]], |
|
|
|
'desc_bprop': [[3, 3072]]}), |
|
|
|
('GetNextSentenceOutput', { |
|
|
|
|