diff --git a/example/vgg16_cifar10/train.py b/example/vgg16_cifar10/train.py index 52ba0ecdf4..8cfcc5fd9c 100644 --- a/example/vgg16_cifar10/train.py +++ b/example/vgg16_cifar10/train.py @@ -71,7 +71,6 @@ if __name__ == '__main__': device_num = int(os.environ.get("DEVICE_NUM", 1)) if device_num > 1: context.reset_auto_parallel_context() - context.set_context(enable_hccl=True) context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) init()