| @@ -193,10 +193,10 @@ Parameters for learning rate: | |||||
| - Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset. | - Set options in `config.py`, including loss_scale, learning rate and network hyperparameters. Click [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about the dataset. | ||||
| - Run `run_standalone_train_ascend.sh` for non-distributed training of Transformer model. | |||||
| - Run `run_standalone_train.sh` for non-distributed training of Transformer model. | |||||
| ``` bash | ``` bash | ||||
| sh scripts/run_standalone_train_ascend.sh DEVICE_ID EPOCH_SIZE DATA_PATH | |||||
| sh scripts/run_standalone_train.sh DEVICE_TARGET DEVICE_ID EPOCH_SIZE DATA_PATH | |||||
| ``` | ``` | ||||
| - Run `run_distribute_train_ascend.sh` for distributed training of Transformer model. | - Run `run_distribute_train_ascend.sh` for distributed training of Transformer model. | ||||
| @@ -100,7 +100,7 @@ def run_transformer_eval(): | |||||
| parser = argparse.ArgumentParser(description='tranformer') | parser = argparse.ArgumentParser(description='tranformer') | ||||
| parser.add_argument("--device_target", type=str, default="Ascend", | parser.add_argument("--device_target", type=str, default="Ascend", | ||||
| help="device where the code will be implemented, default is Ascend") | help="device where the code will be implemented, default is Ascend") | ||||
| parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | |||||
| parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend, default is 0') | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, | context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, | ||||