|
|
|
@@ -14,7 +14,6 @@ |
|
|
|
# ============================================================================ |
|
|
|
"""train_imagenet.""" |
|
|
|
import os |
|
|
|
import math |
|
|
|
import argparse |
|
|
|
import random |
|
|
|
import numpy as np |
|
|
|
@@ -64,7 +63,6 @@ if __name__ == '__main__': |
|
|
|
epoch_size = config.epoch_size |
|
|
|
net = resnet101(class_num=config.class_num) |
|
|
|
# weight init |
|
|
|
default_recurisive_init(net) |
|
|
|
for _, cell in net.cells_and_names(): |
|
|
|
if isinstance(cell, nn.Conv2d): |
|
|
|
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), |
|
|
|
|