Browse Source

adapt to weight initializer modification

tags/v0.3.0-alpha
gengdongjie 5 years ago
parent
commit
ae9ce1629b
2 changed files with 4 additions and 4 deletions
  1. +2
    -2
      example/resnet101_imagenet2012/train.py
  2. +2
    -2
      example/resnet50_imagenet2012/train.py

+ 2
- 2
example/resnet101_imagenet2012/train.py View File

@@ -64,11 +64,11 @@ if __name__ == '__main__':
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
cell.weight.default_input.dtype()).to_tensor()
if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
cell.weight.default_input.dtype()).to_tensor()
if not config.label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)


+ 2
- 2
example/resnet50_imagenet2012/train.py View File

@@ -61,11 +61,11 @@ if __name__ == '__main__':
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
cell.weight.default_input.dtype()).to_tensor()
if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
cell.weight.default_input.dtype()).to_tensor()
if not config.use_label_smooth:
config.label_smooth_factor = 0.0



Loading…
Cancel
Save