
!10156 support CPU deepfm

From: @zhao_ting_v
Reviewed-by: @guoqi1024,@wuxuejian
Signed-off-by: @guoqi1024
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b741fe46b3
5 changed files with 85 additions and 65 deletions
  1. model_zoo/official/recommend/deepfm/README.md (+65, -50)
  2. model_zoo/official/recommend/deepfm/eval.py (+6, -3)
  3. model_zoo/official/recommend/deepfm/src/config.py (+1, -0)
  4. model_zoo/official/recommend/deepfm/src/deepfm.py (+6, -5)
  5. model_zoo/official/recommend/deepfm/train.py (+7, -7)

+ 65
- 50
model_zoo/official/recommend/deepfm/README.md

@@ -4,7 +4,7 @@
- [Model Architecture](#model-architecture) - [Model Architecture](#model-architecture)
- [Dataset](#dataset) - [Dataset](#dataset)
- [Environment Requirements](#environment-requirements) - [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Quick Start](#quick-start)
- [Script Description](#script-description) - [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code) - [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters) - [Script Parameters](#script-parameters)
@@ -14,16 +14,15 @@
- [Evaluation Process](#evaluation-process) - [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation) - [Evaluation](#evaluation)
- [Model Description](#model-description) - [Model Description](#model-description)
- [Performance](#performance)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance) - [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance) - [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation) - [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)



# [DeepFM Description](#contents) # [DeepFM Description](#contents)


Learning sophisticated feature interactions behind user behaviors is critical in maximizing CTR for recommender systems. Despite great progress, existing methods seem to have a strong bias towards low- or high-order interactions, or require expertise feature engineering. In this paper, we show that it is possible to derive an end-to-end learning model that emphasizes both low- and high-order feature interactions. The proposed model, DeepFM, combines the power of factorization machines for recommendation and deep learning for feature learning in a new neural network architecture.
Learning sophisticated feature interactions behind user behaviors is critical in maximizing CTR for recommender systems. Despite great progress, existing methods seem to have a strong bias towards low- or high-order interactions, or require expertise feature engineering. In this paper, we show that it is possible to derive an end-to-end learning model that emphasizes both low- and high-order feature interactions. The proposed model, DeepFM, combines the power of factorization machines for recommendation and deep learning for feature learning in a new neural network architecture.


[Paper](https://arxiv.org/abs/1703.04247): Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction [Paper](https://arxiv.org/abs/1703.04247): Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction


@@ -35,27 +34,24 @@ The FM and deep component share the same input raw feature vector, which enables
# [Dataset](#contents) # [Dataset](#contents)


- [1] A dataset used in Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction[J]. 2017. - [1] A dataset used in Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction[J]. 2017.


# [Environment Requirements](#contents) # [Environment Requirements](#contents)


- Hardware(Ascend/GPU)
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Hardware(Ascend/GPU/CPU)
- Prepare hardware environment with Ascend, GPU, or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework - Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below: - For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)


- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)


# [Quick Start](#contents) # [Quick Start](#contents)


After installing MindSpore via the official website, you can start training and evaluation as follows:
After installing MindSpore via the official website, you can start training and evaluation as follows:


- running on Ascend - running on Ascend


```
```shell
# run training example # run training example
python train.py \ python train.py \
--dataset_path='dataset/train' \ --dataset_path='dataset/train' \
@@ -64,10 +60,10 @@ After installing MindSpore via the official website, you can start training and
--loss_file_name='loss.log' \ --loss_file_name='loss.log' \
--device_target='Ascend' \ --device_target='Ascend' \
--do_eval=True > ms_log/output.log 2>&1 & --do_eval=True > ms_log/output.log 2>&1 &
# run distributed training example # run distributed training example
sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json
# run evaluation example # run evaluation example
python eval.py \ python eval.py \
--dataset_path='dataset/test' \ --dataset_path='dataset/test' \
@@ -81,13 +77,13 @@ After installing MindSpore via the official website, you can start training and


Please follow the instructions in the link below: Please follow the instructions in the link below:


https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.
[hccl tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).


- running on GPU - running on GPU


For running on GPU, please change `device_target` from `Ascend` to `GPU` in configuration file src/config.py For running on GPU, please change `device_target` from `Ascend` to `GPU` in configuration file src/config.py


```
```shell
# run training example # run training example
python train.py \ python train.py \
--dataset_path='dataset/train' \ --dataset_path='dataset/train' \
@@ -96,10 +92,10 @@ After installing MindSpore via the official website, you can start training and
--loss_file_name='loss.log' \ --loss_file_name='loss.log' \
--device_target='GPU' \ --device_target='GPU' \
--do_eval=True > ms_log/output.log 2>&1 & --do_eval=True > ms_log/output.log 2>&1 &
# run distributed training example # run distributed training example
sh scripts/run_distribute_train.sh 8 /dataset_path sh scripts/run_distribute_train.sh 8 /dataset_path
# run evaluation example # run evaluation example
python eval.py \ python eval.py \
--dataset_path='dataset/test' \ --dataset_path='dataset/test' \
@@ -109,16 +105,35 @@ After installing MindSpore via the official website, you can start training and
sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path/deepfm.ckpt sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path/deepfm.ckpt
``` ```


- running on CPU

```shell
# run training example
python train.py \
--dataset_path='dataset/train' \
--ckpt_path='./checkpoint' \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='CPU' \
--do_eval=True > ms_log/output.log 2>&1 &

# run evaluation example
python eval.py \
--dataset_path='dataset/test' \
--checkpoint_path='./checkpoint/deepfm.ckpt' \
--device_target='CPU' > ms_log/eval_output.log 2>&1 &
```

# [Script Description](#contents) # [Script Description](#contents)


## [Script and Sample Code](#contents) ## [Script and Sample Code](#contents)


```
```path
. .
└─deepfm
└─deepfm
├─README.md ├─README.md
├─mindspore_hub_conf.md # config for mindspore hub ├─mindspore_hub_conf.md # config for mindspore hub
├─scripts
├─scripts
├─run_standalone_train.sh # launch standalone training(1p) in Ascend or GPU ├─run_standalone_train.sh # launch standalone training(1p) in Ascend or GPU
├─run_distribute_train.sh # launch distributed training(8p) in Ascend ├─run_distribute_train.sh # launch distributed training(8p) in Ascend
├─run_distribute_train_gpu.sh # launch distributed training(8p) in GPU ├─run_distribute_train_gpu.sh # launch distributed training(8p) in GPU
@@ -138,7 +153,8 @@ After installing MindSpore via the official website, you can start training and
Parameters for both training and evaluation can be set in config.py Parameters for both training and evaluation can be set in config.py


- train parameters - train parameters
```

```help
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--dataset_path DATASET_PATH --dataset_path DATASET_PATH
@@ -153,8 +169,10 @@ Parameters for both training and evaluation can be set in config.py
--device_target DEVICE_TARGET --device_target DEVICE_TARGET
Ascend or GPU. Default: Ascend Ascend or GPU. Default: Ascend
``` ```

- eval parameters - eval parameters
```

```help
optional arguments: optional arguments:
-h, --help show this help message and exit -h, --help show this help message and exit
--checkpoint_path CHECKPOINT_PATH --checkpoint_path CHECKPOINT_PATH
@@ -165,14 +183,13 @@ Parameters for both training and evaluation can be set in config.py
Ascend or GPU. Default: Ascend Ascend or GPU. Default: Ascend
``` ```



## [Training Process](#contents) ## [Training Process](#contents)


### Training
### Training


- running on Ascend - running on Ascend


```
```shell
python train.py \ python train.py \
--dataset_path='dataset/train' \ --dataset_path='dataset/train' \
--ckpt_path='./checkpoint' \ --ckpt_path='./checkpoint' \
@@ -181,36 +198,36 @@ Parameters for both training and evaluation can be set in config.py
--device_target='Ascend' \ --device_target='Ascend' \
--do_eval=True > ms_log/output.log 2>&1 & --do_eval=True > ms_log/output.log 2>&1 &
``` ```
The python command above will run in the background, you can view the results through the file `ms_log/output.log`. The python command above will run in the background, you can view the results through the file `ms_log/output.log`.
After training, you'll get some checkpoint files under `./checkpoint` folder by default. The loss value are saved in loss.log file. After training, you'll get some checkpoint files under `./checkpoint` folder by default. The loss value are saved in loss.log file.
```
```log
2020-05-27 15:26:29 epoch: 1 step: 41257, loss is 0.498953253030777 2020-05-27 15:26:29 epoch: 1 step: 41257, loss is 0.498953253030777
2020-05-27 15:32:32 epoch: 2 step: 41257, loss is 0.45545706152915955 2020-05-27 15:32:32 epoch: 2 step: 41257, loss is 0.45545706152915955
... ...
``` ```
The model checkpoint will be saved in the current directory.
The model checkpoint will be saved in the current directory.


- running on GPU - running on GPU

To do. To do.


### Distributed Training ### Distributed Training


- running on Ascend - running on Ascend


```
```shell
sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json
``` ```
The above shell script will run distribute training in the background. You can view the results through the file `log[X]/output.log`. The loss value are saved in loss.log file. The above shell script will run distribute training in the background. You can view the results through the file `log[X]/output.log`. The loss value are saved in loss.log file.


- running on GPU - running on GPU
To do.


To do.


## [Evaluation Process](#contents) ## [Evaluation Process](#contents)


@@ -219,8 +236,8 @@ Parameters for both training and evaluation can be set in config.py
- evaluation on dataset when running on Ascend - evaluation on dataset when running on Ascend


Before running the command below, please check the checkpoint path used for evaluation. Before running the command below, please check the checkpoint path used for evaluation.
```
```shell
python eval.py \ python eval.py \
--dataset_path='dataset/test' \ --dataset_path='dataset/test' \
--checkpoint_path='./checkpoint/deepfm.ckpt' \ --checkpoint_path='./checkpoint/deepfm.ckpt' \
@@ -228,22 +245,22 @@ Parameters for both training and evaluation can be set in config.py
OR OR
sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/deepfm.ckpt sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/deepfm.ckpt
``` ```
The above python command will run in the background. You can view the results through the file "eval_output.log". The accuracy is saved in auc.log file. The above python command will run in the background. You can view the results through the file "eval_output.log". The accuracy is saved in auc.log file.
```
```log
{'result': {'AUC': 0.8057789065281104, 'eval_time': 35.64779996871948}} {'result': {'AUC': 0.8057789065281104, 'eval_time': 35.64779996871948}}
``` ```



- evaluation on dataset when running on GPU - evaluation on dataset when running on GPU
To do.


To do.


# [Model Description](#contents) # [Model Description](#contents)

## [Performance](#contents) ## [Performance](#contents)


### Evaluation Performance
### Evaluation Performance


| Parameters | Ascend | GPU | | Parameters | Ascend | GPU |
| -------------------------- | ----------------------------------------------------------- | ---------------------- | | -------------------------- | ----------------------------------------------------------- | ---------------------- |
@@ -263,7 +280,6 @@ Parameters for both training and evaluation can be set in config.py
| Checkpoint for Fine tuning | 190M (.ckpt file) | To do | | Checkpoint for Fine tuning | 190M (.ckpt file) | To do |
| Scripts | [deepfm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm) | To do | | Scripts | [deepfm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm) | To do |



### Inference Performance ### Inference Performance


| Parameters | Ascend | GPU | | Parameters | Ascend | GPU |
@@ -278,11 +294,10 @@ Parameters for both training and evaluation can be set in config.py
| Accuracy | 1pc: 80.55%; | To do | | Accuracy | 1pc: 80.55%; | To do |
| Model for inference | 190M (.ckpt file) | To do | | Model for inference | 190M (.ckpt file) | To do |



# [Description of Random Situation](#contents) # [Description of Random Situation](#contents)


We set the random seed before training in train.py. We set the random seed before training in train.py.


# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
# [ModelZoo Homepage](#contents)


Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

+ 6
- 3
model_zoo/official/recommend/deepfm/eval.py

@@ -30,9 +30,10 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction') parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend or GPU. Default: Ascend')
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
help="device target, support Ascend, GPU and CPU.")
args_opt, _ = parser.parse_known_args() args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)




@@ -49,7 +50,9 @@ if __name__ == '__main__':
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False, ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1, batch_size=train_config.batch_size, epochs=1, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format)) data_type=DataType(data_config.data_format))
model_builder = ModelBuilder(ModelConfig, TrainConfig)
if model_config.convert_dtype:
model_config.convert_dtype = args_opt.device_target != "CPU"
model_builder = ModelBuilder(model_config, train_config)
train_net, eval_net = model_builder.get_train_eval_net() train_net, eval_net = model_builder.get_train_eval_net()
train_net.set_train() train_net.set_train()
eval_net.set_train(False) eval_net.set_train(False)
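
Note on the first eval.py hunk: it restricts `--device_target` with `choices` and gives `DEVICE_ID` a default of `'0'`, so the script no longer fails with `int(None)` when the variable is unset. The snippet below is a minimal standalone sketch of that behaviour using only the standard library. It is not the script itself, and the helper names `build_parser` and `read_device_id` are hypothetical.

```python
import argparse
import os


def build_parser():
    """Build a parser with the --device_target option as written in the diff."""
    parser = argparse.ArgumentParser(description='CTR Prediction')
    parser.add_argument('--device_target', type=str, default="Ascend",
                        choices=("Ascend", "GPU", "CPU"),
                        help="device target, support Ascend, GPU and CPU.")
    return parser


def read_device_id(environ=os.environ):
    """Return DEVICE_ID as an int, falling back to 0 when it is unset."""
    return int(environ.get('DEVICE_ID', '0'))


if __name__ == '__main__':
    # parse_known_args ignores unrelated flags, as eval.py does.
    args, _ = build_parser().parse_known_args(['--device_target', 'CPU'])
    print(args.device_target, read_device_id({}))
```

With `choices` in place, argparse rejects any other value (for example `--device_target=TPU`) with a usage error before the MindSpore context is configured.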


+ 1
- 0
model_zoo/official/recommend/deepfm/src/config.py

@@ -35,6 +35,7 @@ class ModelConfig:
init_args = [-0.01, 0.01] init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal'] weight_bias_init = ['normal', 'normal']
keep_prob = 0.9 keep_prob = 0.9
convert_dtype = True
class TrainConfig: class TrainConfig:
"""train config""" """train config"""


+ 6
- 5
model_zoo/official/recommend/deepfm/src/deepfm.py

@@ -175,6 +175,7 @@ class DeepFMModel(nn.Cell):
self.init_args = config.init_args self.init_args = config.init_args
self.weight_bias_init = config.weight_bias_init self.weight_bias_init = config.weight_bias_init
self.keep_prob = config.keep_prob self.keep_prob = config.keep_prob
convert_dtype = config.convert_dtype
init_acts = [('W_l2', [self.vocab_size, 1], 'normal'), init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
('V_l2', [self.vocab_size, self.emb_dim], 'normal')] ('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
var_map = init_var_dict(self.init_args, init_acts) var_map = init_var_dict(self.init_args, init_acts)
@@ -184,15 +185,15 @@ class DeepFMModel(nn.Cell):
self.deep_input_dims = self.field_size * self.emb_dim self.deep_input_dims = self.field_size * self.emb_dim
self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1] self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init, self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.deep_layer_act, self.keep_prob, convert_dtype=convert_dtype)
self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init, self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.deep_layer_act, self.keep_prob, convert_dtype=convert_dtype)
self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init, self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.deep_layer_act, self.keep_prob, convert_dtype=convert_dtype)
self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init, self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True)
self.deep_layer_act, self.keep_prob, convert_dtype=convert_dtype)
self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init, self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
self.deep_layer_act, self.keep_prob, convert_dtype=convert_dtype, use_act=False)
" FM, linear Layers " " FM, linear Layers "
self.Gatherv2 = P.GatherV2() self.Gatherv2 = P.GatherV2()
self.Mul = P.Mul() self.Mul = P.Mul()
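
Note on `convert_dtype`: the dense layers no longer hard-code `convert_dtype=True`. They read the flag from `ModelConfig`, and train.py and eval.py switch it off when the target is CPU. The rule those scripts apply can be written as a small pure function. This is a sketch only: `resolve_convert_dtype` is a hypothetical name, and what `DenseLayer` does with the flag (presumably a reduced-precision cast) is an assumption, since the class body is not part of this diff.

```python
def resolve_convert_dtype(configured, device_target):
    """Mirror the gating used in train.py and eval.py.

    Original form:
        if model_config.convert_dtype:
            model_config.convert_dtype = args_opt.device_target != "CPU"
    """
    if configured:
        # Keep the conversion on Ascend and GPU, disable it on CPU.
        return device_target != "CPU"
    # A flag that is already False stays False on every target.
    return configured


for target in ("Ascend", "GPU", "CPU"):
    print(target, resolve_convert_dtype(True, target))
```

Because the check only ever turns the flag off, setting `convert_dtype = False` in config.py disables the conversion on all targets.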


+ 7
- 7
model_zoo/official/recommend/deepfm/train.py

@@ -39,7 +39,8 @@ parser.add_argument('--loss_file_name', type=str, default="./loss.log",
help='Loss log file path. Default: "./loss.log"') help='Loss log file path. Default: "./loss.log"')
parser.add_argument('--do_eval', type=str, default='True', parser.add_argument('--do_eval', type=str, default='True',
help='Do evaluation or not, only support "True" or "False". Default: "True"') help='Do evaluation or not, only support "True" or "False". Default: "True"')
parser.add_argument('--device_target', type=str, default="Ascend", help='Ascend or GPU. Default: Ascend')
parser.add_argument('--device_target', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"),
help="device target, support Ascend, GPU and CPU.")
args_opt, _ = parser.parse_known_args() args_opt, _ = parser.parse_known_args()
args_opt.do_eval = args_opt.do_eval == 'True' args_opt.do_eval = args_opt.do_eval == 'True'
rank_size = int(os.environ.get("RANK_SIZE", 1)) rank_size = int(os.environ.get("RANK_SIZE", 1))
@@ -74,11 +75,8 @@ if __name__ == '__main__':
if args_opt.device_target == "Ascend": if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
elif args_opt.device_target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
else: else:
print("Unsupported device_target ", args_opt.device_target)
exit()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
rank_size = None rank_size = None
rank_id = None rank_id = None
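
Note on the hunk above: the explicit GPU branch and the "Unsupported device_target" exit are folded into a single `else`, which is safe because `choices` on `--device_target` already rejects anything other than Ascend, GPU, or CPU. A minimal sketch of the resulting selection logic follows. It does not call MindSpore: `context_kwargs` is a hypothetical helper that only builds the keyword arguments the non-distributed path would pass to `context.set_context`, and the `mode` argument is omitted for brevity.

```python
import os


def context_kwargs(device_target, environ=os.environ):
    """Return the device-related kwargs for the non-distributed path."""
    kwargs = {"device_target": device_target}
    if device_target == "Ascend":
        # Only Ascend passes a device_id. train.py reads DEVICE_ID without a
        # default, so it must be set; this sketch raises KeyError if it is not.
        kwargs["device_id"] = int(environ["DEVICE_ID"])
    # GPU and CPU share one branch and pass no device_id.
    return kwargs


print(context_kwargs("CPU", {}))
print(context_kwargs("Ascend", {"DEVICE_ID": "2"}))
```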


@@ -92,7 +90,9 @@ if __name__ == '__main__':


steps_size = ds_train.get_dataset_size() steps_size = ds_train.get_dataset_size()


model_builder = ModelBuilder(ModelConfig, TrainConfig)
if model_config.convert_dtype:
model_config.convert_dtype = args_opt.device_target != "CPU"
model_builder = ModelBuilder(model_config, train_config)
train_net, eval_net = model_builder.get_train_eval_net() train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric() auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
@@ -105,7 +105,7 @@ if __name__ == '__main__':
if rank_size: if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank()) train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/') args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
if args_opt.device_target == "GPU":
if args_opt.device_target != "Ascend":
config_ck = CheckpointConfig(save_checkpoint_steps=steps_size, config_ck = CheckpointConfig(save_checkpoint_steps=steps_size,
keep_checkpoint_max=train_config.keep_checkpoint_max) keep_checkpoint_max=train_config.keep_checkpoint_max)
else: else:

