|
|
|
@@ -64,11 +64,11 @@ if __name__ == '__main__': |
|
|
|
if isinstance(cell, nn.Conv2d): |
|
|
|
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), |
|
|
|
cell.weight.default_input.shape(), |
|
|
|
cell.weight.default_input.dtype()) |
|
|
|
cell.weight.default_input.dtype()).to_tensor() |
|
|
|
if isinstance(cell, nn.Dense): |
|
|
|
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), |
|
|
|
cell.weight.default_input.shape(), |
|
|
|
cell.weight.default_input.dtype()) |
|
|
|
cell.weight.default_input.dtype()).to_tensor() |
|
|
|
if not config.label_smooth: |
|
|
|
config.label_smooth_factor = 0.0 |
|
|
|
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) |
|
|
|
|