@@ -19,6 +19,7 @@ python run_pretrain.py
 import os
 import argparse
+import numpy
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore.train.model import Model
@@ -142,4 +143,5 @@ def run_pretrain():
     model = Model(netwithgrads)
     model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
 if __name__ == '__main__':
+    numpy.random.seed(0)
     run_pretrain()
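Note: seeding NumPy's global RNG once at the entry point makes numpy-driven shuffling and initialization reproducible across runs. A minimal sketch of the effect (the seed value 0 is the one from the diff; nothing else here is part of the change):

    import numpy

    numpy.random.seed(0)
    a = numpy.random.rand(3)   # first draw after seeding
    numpy.random.seed(0)
    b = numpy.random.rand(3)   # reseeding reproduces the same draw
    assert (a == b).all()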
@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                              shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                              shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
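Note: with data sinking enabled, the dataset is shrunk to data_sink_steps * batch_size samples and new_repeat_count is scaled up to compensate, so repeating by max(new_repeat_count, repeat_count) guarantees the pipeline supplies at least as many passes as the caller requested. A rough sketch of the arithmetic with made-up numbers (the scaling formula is an assumption based on the surrounding code, which this diff does not show):

    ori_dataset_size = 10000              # samples reported by get_dataset_size()
    batch_size = 32
    data_sink_steps = 100
    repeat_count = 40                     # epochs requested by the caller

    new_size = data_sink_steps * batch_size                           # 3200
    new_repeat_count = repeat_count * ori_dataset_size // new_size    # 125

    # The diff repeats by whichever count is larger, so the iterator
    # cannot run dry before training consumes its repeat budget.
    effective_repeat = max(new_repeat_count, repeat_count)            # 125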
@@ -32,7 +32,6 @@ from .bert_model import BertModel
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
-_nn_clip_by_norm = nn.ClipByNorm()
 clip_grad = C.MultitypeFuncGraph("clip_grad")
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
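Note: this constructs nn.ClipByNorm() inside _clip_grad instead of sharing a single module-level instance, presumably so the Cell is created within each graph that compiles this function; the clipping math is unchanged. A plain-NumPy sketch of what clip-by-norm does (illustrative only, not MindSpore code):

    import numpy as np

    def clip_by_norm(grad, clip_norm):
        # Rescale grad so its L2 norm is at most clip_norm.
        norm = np.sqrt((grad ** 2).sum())
        return grad * (clip_norm / norm) if norm > clip_norm else grad

    g = np.array([3.0, 4.0])       # L2 norm is 5.0
    print(clip_by_norm(g, 1.0))    # [0.6 0.8] -- rescaled to norm 1.0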
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
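Note: these corrections align vocab_size with the actual vocab.txt sizes of the pretrained models: 21128 tokens for the Chinese BERT/NEZHA vocabulary and 30522 for the English BERT-large vocabulary. A mismatch breaks checkpoint loading (embedding-table shape mismatch) or yields out-of-range token ids. A quick sanity check one could run (the helper and path are illustrative, not part of the diff):

    def check_vocab(vocab_path, configured_vocab_size):
        # vocab.txt holds one token per line; the line count must equal
        # the vocab_size baked into BertConfig.
        with open(vocab_path, encoding='utf-8') as f:
            actual = sum(1 for _ in f)
        assert actual == configured_vocab_size, (
            f"{vocab_path}: {actual} tokens, config expects {configured_vocab_size}")

    # check_vocab('vocab.txt', 21128)   # Chinese BERT / NEZHA
    # check_vocab('vocab.txt', 30522)   # English BERT-large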