|
|
|
@@ -138,6 +138,7 @@ class TrainOneStepCell(nn.Cell): |
|
|
|
def __init__(self, network, optimizer, sens=1.0): |
|
|
|
super(TrainOneStepCell, self).__init__(auto_prefix=True) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
self.network.add_flags(defer_inline=True) |
|
|
|
self.weights = ParameterTuple(network.trainable_params()) |
|
|
|
self.optimizer = optimizer |
|
|
|
@@ -167,7 +168,6 @@ class TrainGAT(nn.Cell): |
|
|
|
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff): |
|
|
|
super(TrainGAT, self).__init__(auto_prefix=False) |
|
|
|
self.network = network |
|
|
|
self.network.set_grad() |
|
|
|
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff) |
|
|
|
optimizer = nn.Adam(loss_net.trainable_params(), |
|
|
|
learning_rate=learning_rate) |
|
|
|
|