
!9529 fix gnmt doc & code bugs

From: @zhaojichen
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit 26f4fc23e2
12 changed files with 173 additions and 204 deletions
  1. +73 -41   model_zoo/official/nlp/gnmt_v2/README.md
  2. +1 -6     model_zoo/official/nlp/gnmt_v2/config/config.json
  3. +55 -48   model_zoo/official/nlp/gnmt_v2/config/config.py
  4. +3 -7     model_zoo/official/nlp/gnmt_v2/config/config_test.json
  5. +8 -6     model_zoo/official/nlp/gnmt_v2/create_dataset.py
  6. +7 -40    model_zoo/official/nlp/gnmt_v2/src/dataset/base.py
  7. +14 -20   model_zoo/official/nlp/gnmt_v2/src/dataset/load_dataset.py
  8. +5 -0     model_zoo/official/nlp/gnmt_v2/src/dataset/schema.py
  9. +0 -1     model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py
  10. +0 -13   model_zoo/official/nlp/gnmt_v2/src/utils/loss_monitor.py
  11. +1 -4    model_zoo/official/nlp/gnmt_v2/src/utils/optimizer.py
  12. +6 -18   model_zoo/official/nlp/gnmt_v2/train.py

+73 -41   model_zoo/official/nlp/gnmt_v2/README.md

@@ -1,6 +1,7 @@
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)


<!-- TOC -->

- [GNMT v2 For MindSpore](#gnmt-v2-for-mindspore)
- [Model Structure](#model-structure)
- [Dataset](#dataset)
@@ -15,41 +16,46 @@
- [Inference Process](#inference-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Result](#result)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo](#modelzoo)
- [ModelZoo HomePage](#modelzoo-homepage)

<!-- /TOC -->


# [GNMT v2 For MindSpore](#contents)


# GNMT v2 For MindSpore
The GNMT v2 model is similar to the model described in [Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144), which is mainly used for corpus translation.


# Model Structure
# [Model Structure](#contents)

The GNMTv2 model mainly consists of an encoder, a decoder, and an attention mechanism, where the encoder and the decoder use a shared word embedding vector.
Encoder: consists of four long short-term memory (LSTM) layers. The first LSTM layer is bidirectional, while the other three layers are unidirectional.
Decoder: consists of four unidirectional LSTM layers and a fully connected classifier. The output embedding dimension of LSTM is 1024.
Attention mechanism: uses the standardized Bahdanau attention mechanism. First, the first layer output of the decoder is used as the input of the attention mechanism. Then, the computing result of the attention mechanism is connected to the input of the decoder LSTM, which is used as the input of the subsequent LSTM layer.


# Dataset
# [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.


- *WMT Englis-German* for training.
- *WMT newstest2014* for evaluation.
- WMT Englis-German for training.
- WMT newstest2014 for evaluation.

# [Environment Requirements](#contents)


# Environment Requirements
## Platform

- Hardware (Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you could get the resources for trial.
- Framework
- Install [MindSpore](https://www.mindspore.cn/install/en).
- For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore API](https://www.mindspore.cn/doc/api_python/en/master/index.html)


## Software

```txt
numpy
sacrebleu==1.2.10
@@ -58,13 +64,16 @@ subword_nmt==0.3.7
```


# [Quick Start](#contents)

The process of GNMTv2 performing the text translation task is as follows:

1. Download the wmt16 data corpus and extract the dataset. For details, see the chapter "_Dataset_" above.
2. Dataset preparation and configuration.
3. Training.
4. Inference.

After dataset preparation, you can start training and evaluation as follows:

```bash
# run training example
cd ./scripts
@@ -80,8 +89,10 @@ sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_P
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```


# Script Description
# [Script Description](#contents)

The GNMT network script and code result are as follows:

```text
├── gnmt
├── README.md // Introduction of GNMTv2 model.
@@ -92,9 +103,9 @@ The GNMT network script and code result are as follows:
│ ├──config_test.json // Configuration file for test.
├── src
│ ├──__init__.py // User interface.
│ ├──dataset
│ ├──__init__.py // User interface.
│ ├──base.py // Base class of data loader.
│ ├──bi_data_loader.py // Bilingual data loader.
│ ├──load_dataset.py // Dataset loader to feed into model.
│ ├──schema.py // Define schema of mindrecord.
@@ -134,24 +145,29 @@ The GNMT network script and code result are as follows:
```


## Dataset Preparation

You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/Translation/GNMT/scripts/wmt16_en_de.sh) to download and preprocess WMT English-German dataset. Assuming you get the following files:
- train.tok.clean.bpe.32000.en
- train.tok.clean.bpe.32000.de
- vocab.bpe.32000
- bpe.32000
- newstest2014.en
- newstest2014.de

- Convert the original data to tfrecord for training and evaluation:
- Convert the original data to mindrecord for training and evaluation:


``` bash
python create_dataset.py --src_folder /home/workspace/wmt16_de_en --output_folder /home/workspace/dataset_menu
```


## Configuration File

The JSON file in the `config/` directory is the template configuration file.
Almost all required options and parameters can be easily assigned, including the training platform, model configuration, and optimizer parameters.

- config for GNMTv2

```python
'random_seed': 50 # global random seed
'epochs':6 # total training epochs
@@ -159,63 +175,74 @@ Almost all required options and parameters can be easily assigned, including the
'dataset_sink_mode': true # whether use dataset sink mode
'seq_length': 51 # max length of source sentences
'vocab_size': 32320 # vocabulary size
'hidden_size': 125 # the output's last dimension of dynamicRNN
'hidden_size': 1024 # the output's last dimension of dynamicRNN
'initializer_range': 0.1 # initializer range
'max_decode_length': 125 # max length of decoder
'lr': 0.1 # initial learning rate
'max_decode_length': 50 # max length of decoder
'lr': 2e-1 # initial learning rate
'lr_scheduler': 'WarmupMultiStepLR' # learning rate scheduler
'existed_ckpt': '' # the absolute full path to save the checkpoint file
'existed_ckpt': "" # the absolute full path to save the checkpoint file
```

For more configuration details, please refer the script `config/config.py` file.


## Training Process

For a pre-trained model, configure the following options in the `scripts/run_standalone_train_ascend.json` file:

- Select an optimizer ('momentum/adam/lamb' is available).
- Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file.
- Set other parameters, including dataset configuration and network configuration.
- If a pre-trained model exists, assign `existed_ckpt` to the path of the existing model during fine-tuning.


Start task training on a single device and run the shell script `scripts/run_standalone_train_ascend.sh`:

```bash
cd ./scripts
sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
```

In this script, the `DATASET_SCHEMA_TRAIN` and `PRE_TRAIN_DATASET` are the dataset schema and dataset address.


Run `scripts/run_distributed_train_ascend.sh` for distributed training of GNMTv2 model.
Task training on multiple devices and run the following command in bash to be executed in `scripts/`.:

```bash
cd ./scripts
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
```
Note: the `RANK_TABLE_ADDR` is the hccl_json file assigned when distributed training is running.
Currently, inconsecutive device IDs are not supported in `scripts/run_distributed_train_ascend.sh`. The device ID must start from 0 in the `RANK_TABLE_ADDR` file.


## Inference Process

For inference using a trained model on multiple hardware platforms, such as Ascend 910.
Set options in `config/config_test.json`.
Run the shell script `scripts/run_standalone_eval_ascend.sh` to process the output token ids to get the BLEU scores.

```bash
cd ./scripts
sh run_standalone_eval_ascend.sh
sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```

The `DATASET_SCHEMA_TEST` and the `TEST_DATASET` are the schema and address of inference dataset respectively, and `EXISTED_CKPT_PATH` is the path of the model file generated during training process.
The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the bpe code address and the `TEST_TARGET` are the path of answers.


# Model Description
# [Model Description](#contents)

## Performance
### Result
#### Training Performance
### Training Performance


| Parameters | Ascend |
| -------------------------- | -------------------------------------------------------------- |
| Resource | Ascend 910 |
| uploaded Date | 11/06/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | WMT Englis-German |
| Dataset | WMT Englis-German for training |
| Training Parameters | epoch=6, batch_size=128 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
@@ -227,7 +254,7 @@ The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the bpe code addr
| Checkpoint for inference | 1.8G (.ckpt file) |
| Scripts | [gnmt_v2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2) |


#### Inference Performance
### Inference Performance


| Parameters | Ascend |
| ------------------- | --------------------------- |
@@ -241,15 +268,20 @@ The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the bpe code addr
| Accuracy | BLEU Score= 24.05 |
| Model for inference | 1.8G (.ckpt file) |


# Random Situation Description
# [Random Situation Description](#contents)

There are three random situations:

- Shuffle of the dataset.
- Initialization of some model weights.
- Dropout operations.

Some seeds have already been set in train.py to avoid the randomness of dataset shuffle and weight initialization. If you want to disable dropout, please set the corresponding dropout_prob parameter to 0 in config/config.json.


# Others
# [Others](#contents)

This model has been validated in the Ascend environment and is not validated on the CPU and GPU.


# ModelZoo 主页
[链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)
# [ModelZoo HomePage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)

+1 -6   model_zoo/official/nlp/gnmt_v2/config/config.json

@@ -1,7 +1,4 @@
{
"training_platform": {
"modelarts": false
},
"dataset_config": {
"random_seed": 50,
"epochs": 6,
@@ -9,10 +6,8 @@
"dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json", "dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
"pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001", "pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001",
"fine_tune_dataset": null, "fine_tune_dataset": null,
"test_dataset": null,
"valid_dataset": null, "valid_dataset": null,
"dataset_sink_mode": true,
"dataset_sink_step": 2
"dataset_sink_mode": true
}, },
"model_config": { "model_config": {
"seq_length": 51, "seq_length": 51,


+55 -48   model_zoo/official/nlp/gnmt_v2/config/config.py

@@ -53,7 +53,6 @@ def get_source_list(folder: str) -> List:




PARAM_NODES = {"dataset_config",
"training_platform",
"model_config",
"loss_scale_config",
"learn_rate_config",
@@ -65,88 +64,99 @@ class GNMTConfig:
Configuration for `GNMT`.


Args:
random_seed (int): Random seed.
batch_size (int): Batch size of input dataset.
random_seed (int): Random seed, it can be changed.
epochs (int): Epoch number.
dataset_sink_mode (bool): Whether enable dataset sink mode.
dataset_sink_step (int): Dataset sink step.
lr_scheduler (str): Whether use lr_scheduler, only support "ISR" now.
lr (float): Initial learning rate.
min_lr (float): Minimum learning rate.
decay_start_step (int): Step to decay.
warmup_steps (int): Warm up steps.
batch_size (int): Batch size of input dataset.
dataset_schema (str): Path of dataset schema file.
pre_train_dataset (str): Path of pre-training dataset file or folder.
fine_tune_dataset (str): Path of fine-tune dataset file or folder.
test_dataset (str): Path of test dataset file or folder.
valid_dataset (str): Path of validation dataset file or folder.
ckpt_path (str): Checkpoints save path.
save_ckpt_steps (int): Interval of saving ckpt.
ckpt_prefix (str): Prefix of ckpt file.
keep_ckpt_max (int): Max ckpt files number.
seq_length (int): Length of input sequence. Default: 64.
vocab_size (int): The shape of each embedding vector. Default: 46192.
hidden_size (int): Size of embedding, attention, dim. Default: 512.
dataset_sink_mode (bool): Whether enable dataset sink mode.
seq_length (int): Length of input sequence.
vocab_size (int): The shape of each embedding vector.
hidden_size (int): Size of embedding, attention, dim.
num_hidden_layers (int): Encoder, Decoder layers.

intermediate_size (int): Size of intermediate layer in the Transformer
encoder/decoder cell. Default: 4096.
encoder/decoder cell.
hidden_act (str): Activation function used in the Transformer encoder/decoder
cell. Default: "relu".
cell.
hidden_dropout_prob (float): The dropout probability for hidden outputs.
attention_dropout_prob (float): The dropout probability for Attention module.
initializer_range (float): Initialization value of TruncatedNormal.
label_smoothing (float): Label smoothing setting.
beam_width (int): Beam width for beam search in inferring.
length_penalty_weight (float): Penalty for sentence length.
max_decode_length (int): Max decode length for inferring.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset.
init_loss_scale (int): Initialized loss scale.
loss_scale_factor (int): Loss scale factor.
scale_window (int): Window size of loss scale.
beam_width (int): Beam width for beam search in inferring. Default: 4.
length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
label_smoothing (float): Label smoothing setting. Default: 0.1.
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
dataset. Default: True.
lr_scheduler (str): Whether use lr_scheduler, only support "ISR" now.
optimizer (str): Optimizer for training, e.g. Adam, Lamb, momentum. Default: Adam.
lr (float): Initial learning rate.
min_lr (float): Minimum learning rate.
decay_steps (int): Decay steps.
lr_scheduler_power(float): A value used to calculate decayed learning rate.
warmup_lr_remain_steps (int or float): Start decay at 'remain_steps' iteration.
warmup_lr_decay_interval (int):interval between LR decay steps.
decay_start_step (int): Step to decay.
warmup_steps (int): Warm up steps.
existed_ckpt (str): Using existed checkpoint to keep training or not.
save_ckpt_steps (int): Interval of saving ckpt.
keep_ckpt_max (int): Max ckpt files number.
ckpt_prefix (str): Prefix of ckpt file.
ckpt_path (str): Checkpoints save path.
save_graphs (bool): Whether to save graphs, please set to True if mindinsight
is wanted.
dtype (mstype): Data type of the input. Default: mstype.float32.
max_decode_length (int): Max decode length for inferring. Default: 64.
hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
attention_dropout_prob (float): The dropout probability for
Multi-head Self-Attention. Default: 0.1.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
dtype (mstype): Data type of the input.

Note:
There are three types of learning rate scheduler, square root scheduler, polynomial
decay scheduler and warmup multistep learning rate scheduler.
In square root scheduler, the following parameters can be used, lr, decay_start_step,
warmup_steps and min_lr.
In polynomial decay scheduler, the following parameters can be used, lr, min_lr, decay_steps,
warmup_steps, lr_scheduler_power.
In warmmup multistep learning rate scheduler, the following parameters can be used, lr, warmup_steps,
warmup_lr_remain_steps, warmup_lr_decay_interval, decay_steps, lr_scheduler_power.
""" """


def __init__(self,
modelarts=False, random_seed=74,
epochs=6, batch_size=64,
random_seed=50,
epochs=6, batch_size=128,
dataset_schema: str = None,
pre_train_dataset: str = None,
fine_tune_dataset: str = None,
test_dataset: str = None,
valid_dataset: str = None,
dataset_sink_mode=True, dataset_sink_step=1,
dataset_sink_mode=True,
seq_length=51, vocab_size=32320, hidden_size=1024,
num_hidden_layers=4, intermediate_size=4096,
hidden_act="tanh",
hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
initializer_range=0.1,
label_smoothing=0.1,
beam_width=5,
length_penalty_weight=1.0,
beam_width=2,
length_penalty_weight=0.6,
max_decode_length=50,
input_mask_from_dataset=False,
init_loss_scale=2 ** 10,
loss_scale_factor=2, scale_window=128,
lr_scheduler="", optimizer="adam",
lr=1e-4, min_lr=1e-6,
decay_steps=4, lr_scheduler_power=1,
init_loss_scale=65536,
loss_scale_factor=2, scale_window=1000,
lr_scheduler="WarmupMultiStepLR",
optimizer="adam",
lr=2e-3, min_lr=1e-6,
decay_steps=4, lr_scheduler_power=0.5,
warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1, warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
decay_start_step=-1, warmup_steps=200, decay_start_step=-1, warmup_steps=200,
existed_ckpt="", save_ckpt_steps=2000, keep_ckpt_max=20,
existed_ckpt="", save_ckpt_steps=3452, keep_ckpt_max=6,
ckpt_prefix="gnmt", ckpt_path: str = None, ckpt_prefix="gnmt", ckpt_path: str = None,
save_step=10000,
save_graphs=False, save_graphs=False,
dtype=mstype.float32): dtype=mstype.float32):


self.save_graphs = save_graphs
self.random_seed = random_seed
self.modelarts = modelarts
self.save_step = save_step
self.dataset_schema = dataset_schema
self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str]
self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str]
@@ -158,7 +168,6 @@ class GNMTConfig:


self.epochs = epochs
self.dataset_sink_mode = dataset_sink_mode
self.dataset_sink_step = dataset_sink_step


self.ckpt_path = ckpt_path
self.keep_ckpt_max = keep_ckpt_max
@@ -201,8 +210,6 @@ class GNMTConfig:
self.decay_start_step = decay_start_step
self.warmup_steps = warmup_steps


self.train_url = ""

@classmethod
def from_dict(cls, json_object: dict):
"""Constructs a `TransformerConfig` from a Python dictionary of parameters."""


+3 -7   model_zoo/official/nlp/gnmt_v2/config/config_test.json

@@ -1,7 +1,4 @@
{
"training_platform": {
"modelarts": false
},
"dataset_config": {
"random_seed": 50,
"epochs": 6,
@@ -11,8 +8,7 @@
"fine_tune_dataset": null, "fine_tune_dataset": null,
"test_dataset": "/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001", "test_dataset": "/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001",
"valid_dataset": null, "valid_dataset": null,
"dataset_sink_mode": true,
"dataset_sink_step": 2
"dataset_sink_mode": true
}, },
"model_config": { "model_config": {
"seq_length": 107, "seq_length": 107,
@@ -29,9 +25,9 @@
"max_decode_length": 80 "max_decode_length": 80
}, },
"loss_scale_config": { "loss_scale_config": {
"init_loss_scale": 8192,
"init_loss_scale": 65536,
"loss_scale_factor": 2, "loss_scale_factor": 2,
"scale_window": 128
"scale_window": 1000
}, },
"learn_rate_config": { "learn_rate_config": {
"optimizer": "adam", "optimizer": "adam",


+8 -6   model_zoo/official/nlp/gnmt_v2/create_dataset.py

@@ -49,11 +49,12 @@ if __name__ == '__main__':
schema_address=args.output_folder + "/" + test_src_file + ".json"
)
print(f" | It's writing, please wait a moment.")
test.write_to_tfrecord(
test.write_to_mindrecord(
path=os.path.join(
args.output_folder,
os.path.basename(test_src_file) + ".tfrecord"
)
os.path.basename(test_src_file) + ".mindrecord"
),
train_mode=False
)
train = BiLingualDataLoader(
@@ -65,11 +66,12 @@ if __name__ == '__main__':
schema_address=args.output_folder + "/" + train_src_file + ".json"
)
print(f" | It's writing, please wait a moment.")
train.write_to_tfrecord(
train.write_to_mindrecord(
path=os.path.join(
args.output_folder,
os.path.basename(train_src_file) + ".tfrecord"
)
os.path.basename(train_src_file) + ".mindrecord"
),
train_mode=True
)
print(f" | Vocabulary size: {tokenizer.vocab_size}.")

+7 -40   model_zoo/official/nlp/gnmt_v2/src/dataset/base.py

@@ -14,16 +14,16 @@
# ============================================================================
"""Base class of data loader."""
import os
import collections
import numpy as np

from mindspore.mindrecord import FileWriter
from .schema import SCHEMA
from .schema import SCHEMA, TEST_SCHEMA


class DataLoader:
"""Data loader for dataset."""
_SCHEMA = SCHEMA
_TEST_SCHEMA = TEST_SCHEMA

def __init__(self):
self._examples = []
@@ -41,7 +41,7 @@ class DataLoader:
new_sen[:sen.shape[0]] = sen[:]
return new_sen

def write_to_mindrecord(self, path, shard_num=1, desc=""):
def write_to_mindrecord(self, path, train_mode, shard_num=1, desc="gnmt"):
"""
Write mindrecord file.

@@ -54,7 +54,10 @@ class DataLoader:
path = os.path.abspath(path)

writer = FileWriter(file_name=path, shard_num=shard_num)
writer.add_schema(self._SCHEMA, desc)
if train_mode:
writer.add_schema(self._SCHEMA, desc)
else:
writer.add_schema(self._TEST_SCHEMA, desc)
if not self._examples:
self._load()

@@ -62,41 +65,5 @@ class DataLoader:
writer.commit()
print(f"| Wrote to {path}.")


def write_to_tfrecord(self, path, shard_num=1):
"""
Write to tfrecord.

Args:
path (str): Output file path.
shard_num (int): Shard num.
"""
import tensorflow as tf
if not os.path.isabs(path):
path = os.path.abspath(path)
output_files = []
for i in range(shard_num):
output_file = path + "-%03d-of-%03d" % (i + 1, shard_num)
output_files.append(output_file)
# create writers
writers = []
for output_file in output_files:
writers.append(tf.io.TFRecordWriter(output_file))

if not self._examples:
self._load()

# create feature
features = collections.OrderedDict()
for example in self._examples:
for key in example:
features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=example[key].tolist()))
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
for writer in writers:
writer.write(tf_example.SerializeToString())
for writer in writers:
writer.close()
for p in output_files:
print(f" | Write to {p}.")

def _add_example(self, example):
self._examples.append(example)

+14 -20   model_zoo/official/nlp/gnmt_v2/src/dataset/load_dataset.py

@@ -19,9 +19,9 @@ import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC




def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True,
drop_remainder=True, is_translate=False):
def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
rank_size=1, rank_id=0, shuffle=True, drop_remainder=True,
is_translate=False):
""" """
Load dataset according to passed in params. Load dataset according to passed in params.


@@ -29,9 +29,7 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
input_files (list): Data files.
schema_file (str): Schema file path.
batch_size (int): Batch size.
epoch_count (int): Epoch count.
sink_mode (bool): Whether enable sink mode.
sink_step (int): Step to sink.
rank_size (int): Rank size.
rank_id (int): Rank id.
shuffle (bool): Whether shuffle dataset.
@@ -57,15 +55,14 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
print(f" | Loading {datafile}.") print(f" | Loading {datafile}.")


if not is_translate: if not is_translate:
ds = de.TFRecordDataset(
input_files, schema_file,
columns_list=[
ds = de.MindDataset(
input_files, columns_list=[
"src", "src_padding", "src", "src_padding",
"prev_opt", "prev_opt",
"target", "tgt_padding" "target", "tgt_padding"
],
shuffle=False, num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_parallel_workers=8)
], shuffle=False, num_shards=rank_size, shard_id=rank_id,
num_parallel_workers=8
)


ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")
@@ -92,13 +89,13 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
)
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
else:
ds = de.TFRecordDataset(
input_files, schema_file,
columns_list=[
ds = de.MindDataset(
input_files, columns_list=[
"src", "src_padding" "src", "src_padding"
], ],
shuffle=False, num_shards=rank_size, shard_id=rank_id, shuffle=False, num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_parallel_workers=8)
num_parallel_workers=8
)


ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")
@@ -119,7 +116,7 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
return ds




def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: int, sink_mode: bool, sink_step: int = 1,
def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool,
rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
"""
Load dataset.
@@ -128,9 +125,7 @@ def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: in
data_files (list): Data files.
schema (str): Schema file path.
batch_size (int): Batch size.
epoch_count (int): Epoch count.
sink_mode (bool): Whether enable sink mode.
sink_step (int): Step to sink.
rank_size (int): Rank size.
rank_id (int): Rank id.
shuffle (bool): Whether shuffle dataset.
@@ -138,6 +133,5 @@ def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: in
Returns:
Dataset, dataset instance.
"""
return _load_dataset(data_files, schema, batch_size, epoch_count, sink_mode,
sink_step, rank_size, rank_id, shuffle=shuffle,
return _load_dataset(data_files, schema, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,
drop_remainder=drop_remainder, is_translate=is_translate)
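For orientation, a standalone call with the trimmed signature looks roughly like the following; the import path and file paths are assumptions for illustration only (the real call sites are in train.py and gnmt_model/gnmt_for_infer.py):

```python
# Sketch only: epoch_count and sink_step are gone from load_dataset, and the data
# files are now MindRecord files produced by create_dataset.py.
from src.dataset import load_dataset  # assumed package export

ds = load_dataset(
    data_files=["/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"],  # placeholder
    schema="/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",              # placeholder
    batch_size=128,
    sink_mode=True,
    rank_size=1,
    rank_id=0,
)
print(f" | Batches per epoch: {ds.get_dataset_size()}.")
```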

+5 -0   model_zoo/official/nlp/gnmt_v2/src/dataset/schema.py

@@ -21,3 +21,8 @@ SCHEMA = {
"target": {"type": "int64", "shape": [-1]}, "target": {"type": "int64", "shape": [-1]},
"tgt_padding": {"type": "int64", "shape": [-1]}, "tgt_padding": {"type": "int64", "shape": [-1]},
} }

TEST_SCHEMA = {
"src": {"type": "int64", "shape": [-1]},
"src_padding": {"type": "int64", "shape": [-1]},
}
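To make the new evaluation schema concrete, here is a small self-contained sketch of writing a MindRecord file with it; the file name, helper function and toy record are made up for illustration and only mirror what `DataLoader.write_to_mindrecord(..., train_mode=False)` does with `TEST_SCHEMA`:

```python
# Sketch only: evaluation records carry just the source side, so TEST_SCHEMA is enough.
import numpy as np
from mindspore.mindrecord import FileWriter

TEST_SCHEMA = {
    "src": {"type": "int64", "shape": [-1]},
    "src_padding": {"type": "int64", "shape": [-1]},
}

def write_eval_records(path, examples):
    # Mirrors the train_mode=False branch added in base.py.
    writer = FileWriter(file_name=path, shard_num=1)
    writer.add_schema(TEST_SCHEMA, "gnmt")
    writer.write_raw_data(examples)
    writer.commit()

write_eval_records("demo_eval.mindrecord", [{
    "src": np.array([2, 17, 35, 3], dtype=np.int64),        # toy token ids
    "src_padding": np.array([1, 1, 1, 1], dtype=np.int64),  # toy padding mask
}])
```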

+0 -1   model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_infer.py

@@ -189,7 +189,6 @@ def infer(config):
eval_dataset = load_dataset(data_files=config.test_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=1,
sink_mode=config.dataset_sink_mode,
drop_remainder=False,
is_translate=True,


+0 -13   model_zoo/official/nlp/gnmt_v2/src/utils/loss_monitor.py

@@ -16,8 +16,6 @@
import time

from mindspore.train.callback import Callback
from mindspore.communication.management import get_rank

from config import GNMTConfig




@@ -51,9 +49,6 @@ class LossCallBack(Callback):
"""step end.""" """step end."""
cb_params = run_context.original_args() cb_params = run_context.original_args()
file_name = "./loss.log" file_name = "./loss.log"
if self.config.modelarts:
import os
file_name = "/home/work/workspace/loss/loss_{}.log".format(os.getenv('DEVICE_ID'))
with open(file_name, "a+") as f: with open(file_name, "a+") as f:
time_stamp_current = self._get_ms_timestamp() time_stamp_current = self._get_ms_timestamp()
f.write("time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\n".format( f.write("time: {}, epoch: {}, step: {}, outputs: [loss: {}, overflow: {}, loss scale value: {} ].\n".format(
@@ -65,14 +60,6 @@ class LossCallBack(Callback):
str(cb_params.net_outputs[2].asnumpy())
))


if self.config.modelarts:
from modelarts.data_util import upload_output
rank_id = get_rank()
if cb_params.cur_step_num % self.config.save_step == 1 \
and cb_params.cur_step_num != 1 and rank_id in [0, 8]:
upload_output("/home/work/workspace/loss", self.config.train_url)
upload_output("/cache/ckpt_0", self.config.train_url)

@staticmethod
def _get_ms_timestamp():
t = time.time()


+1 -4   model_zoo/official/nlp/gnmt_v2/src/utils/optimizer.py

@@ -87,10 +87,7 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
# validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
# validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
# validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
# validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)





def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):


+6 -18   model_zoo/official/nlp/gnmt_v2/train.py

@@ -169,9 +169,7 @@ def _get_optimizer(config, network, lr):
if config.optimizer.lower() == "adam":
optimizer = Adam(network.trainable_params(), lr, beta1=0.9, beta2=0.98)
elif config.optimizer.lower() == "lamb":
optimizer = Lamb(network.trainable_params(), decay_steps=12000,
start_learning_rate=config.lr, end_learning_rate=config.min_lr,
power=10.0, warmup_steps=config.warmup_steps, weight_decay=0.01,
optimizer = Lamb(network.trainable_params(), learning_rate=lr,
eps=1e-6)
elif config.optimizer.lower() == "momentum":
optimizer = Momentum(network.trainable_params(), lr, momentum=0.9)
@@ -277,25 +275,21 @@ def train_parallel(config: GNMTConfig):
data_files=config.pre_train_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank()
) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(
data_files=config.fine_tune_dataset, schema=config.dataset_schema,
batch_size=config.batch_size, epoch_count=config.epochs,
batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank()
) if config.fine_tune_dataset else None
test_dataset = load_dataset(
data_files=config.test_dataset, schema=config.dataset_schema,
batch_size=config.batch_size, epoch_count=config.epochs,
batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank()
) if config.test_dataset else None
@@ -318,21 +312,15 @@ def train_single(config: GNMTConfig):
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None
test_dataset = load_dataset(data_files=config.test_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size,
epoch_count=config.epochs,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.test_dataset else None
sink_mode=config.dataset_sink_mode) if config.test_dataset else None

_build_training_pipeline(config=config,
pre_training_dataset=pre_train_dataset,

