Browse Source

!5914 edit bert hub config file and tinybert script bugfix

Merge pull request !5914 from yoonlee666/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
994cfcb8a0
2 changed files with 5 additions and 3 deletions
  1. +4
    -2
      model_zoo/official/nlp/bert/mindspore_hub_conf.py
  2. +1
    -1
      model_zoo/official/nlp/tinybert/scripts/run_distributed_gd_gpu.sh

+ 4
- 2
model_zoo/official/nlp/bert/mindspore_hub_conf.py View File

@@ -67,11 +67,13 @@ def create_network(name, *args, **kwargs):
bert_net_cfg_base.batch_size = kwargs["batch_size"]
if "seq_length" in kwargs:
bert_net_cfg_base.seq_length = kwargs["seq_length"]
return BertModel(bert_net_cfg_base, *args)
is_training = kwargs.get("is_training", default=False)
return BertModel(bert_net_cfg_base, is_training, *args)
if name == 'bert_nezha':
if "batch_size" in kwargs:
bert_net_cfg_nezha.batch_size = kwargs["batch_size"]
if "seq_length" in kwargs:
bert_net_cfg_nezha.seq_length = kwargs["seq_length"]
return BertModel(bert_net_cfg_nezha, *args)
is_training = kwargs.get("is_training", default=False)
return BertModel(bert_net_cfg_nezha, is_training, *args)
raise NotImplementedError(f"{name} is not implemented in the repo")

+ 1
- 1
model_zoo/official/nlp/tinybert/scripts/run_distributed_gd_gpu.sh View File

@@ -38,5 +38,5 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR \
--dataset_type="tfrecord" \
--enable_data_sink=False \
--enable_data_sink="false" \
--load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &

Loading…
Cancel
Save