|
|
|
@@ -48,6 +48,6 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
load_param_into_net(net, param_dict_new) |
|
|
|
|
|
|
|
img_data = Tensor(np.zeros([config.test_batch_size, 3, config.img_height, config.img_width]), ms.float16) |
|
|
|
img_data = Tensor(np.zeros([config.test_batch_size, 3, config.img_height, config.img_width]), ms.float32) |
|
|
|
|
|
|
|
export(net, img_data, file_name=args.file_name, file_format=args.file_format) |