Browse Source

fix a bug with add set_grad() in wide_and_deep network

tags/v1.0.0
lvchangquan 5 years ago
parent
commit
381d06549c
3 changed files with 5 additions and 3 deletions
  1. +1
    -1
      model_zoo/official/gnn/gat/src/utils.py
  2. +2
    -1
      model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
  3. +2
    -1
      model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py

+ 1
- 1
model_zoo/official/gnn/gat/src/utils.py View File

@@ -138,6 +138,7 @@ class TrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=True)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
@@ -167,7 +168,6 @@ class TrainGAT(nn.Cell):
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff):
super(TrainGAT, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff)
optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=learning_rate)


+ 2
- 1
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py View File

@@ -328,7 +328,6 @@ class TrainStepWrap(nn.Cell):
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.network = network
self.network.set_grad()
self.network.set_train()
self.trainable_params = network.trainable_params()
weights_w = []
@@ -361,6 +360,8 @@ class TrainStepWrap(nn.Cell):
self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1)
self.loss_net_w.set_grad()
self.loss_net_d.set_grad()
self.reducer_flag = False
self.grad_reducer_w = None


+ 2
- 1
model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_deep.py View File

@@ -510,7 +510,6 @@ class TrainStepWrap(nn.Cell):
def __init__(self, network, config, sens=1000.0):
super(TrainStepWrap, self).__init__()
self.network = network
self.network.set_grad()
self.network.set_train()
self.trainable_params = network.trainable_params()
weights_w = []
@@ -546,6 +545,8 @@ class TrainStepWrap(nn.Cell):
self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0)
self.loss_net_d = IthOutputCell(network, output_index=1)
self.loss_net_w.set_grad()
self.loss_net_w.set_grad()
self.reducer_flag = False
self.grad_reducer_w = None


Loading…
Cancel
Save