|
|
|
@@ -51,9 +51,9 @@ def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps |
|
|
|
return lr_each_step |
|
|
|
|
|
|
|
|
|
|
|
def init_net_param(net, init_value='ones'): |
|
|
|
"""Init:wq the parameters in net.""" |
|
|
|
params = net.trainable_params() |
|
|
|
def init_net_param(network, init_value='ones'): |
|
|
|
"""Init:wq the parameters in network.""" |
|
|
|
params = network.trainable_params() |
|
|
|
for p in params: |
|
|
|
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: |
|
|
|
p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype())) |
|
|
|
|