You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_distribute_train.sh 1.5 kB

5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344
  #!/bin/bash
  # Copyright 2020 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ============================================================================
  #
  # Launch one background train.py process per device.
  #
  # Usage (run from the directory that contains train.py and src/):
  #   bash scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PATH
  #
  # For every rank i in [0, DEVICE_NUM) the script:
  #   1. recreates ./log<i> and copies ./*.py and ./src into it,
  #   2. dumps the environment to ./log<i>/env.log,
  #   3. starts train.py inside ./log<i> in the background, with stdout and
  #      stderr redirected to ./log<i>/output.log.
  #
  # Exported for the child processes: RANK_SIZE, MINDSPORE_HCCL_CONFIG_PATH,
  # and, per rank, DEVICE_ID and RANK_ID (both set to the rank index).
  #
  # Exit status: 0 once every rank has been launched (the script does not wait
  # for training to finish); 1 on invalid arguments or a failed setup step.

  set -u

  # Print an error message to stderr and abort with status 1.
  die() {
    printf 'Error: %s\n' "$*" >&2
    exit 1
  }

  # Print the invocation help. It is shown on every run, as before.
  # The help names bash explicitly: the C-style for loop below is a Bash
  # extension and does not parse under a POSIX-only sh such as dash.
  print_usage() {
    echo "Please run the script as: "
    echo "bash scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PATH"
    echo "for example: bash scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
    echo "After running the script, the network runs in the background, The log will be generated in logx/output.log"
  }

  # Echo $1 as an absolute path if it is a relative path that exists on disk;
  # otherwise echo it unchanged. Needed because train.py is started from inside
  # ./log<i>, where a path relative to the launch directory no longer resolves.
  absolutize() {
    local path=$1
    if [[ "$path" != /* && -e "$path" ]]; then
      path="$PWD/$path"
    fi
    printf '%s\n' "$path"
  }

  # Validate the arguments, then prepare and launch one training process per rank.
  # Arguments: $1 - DEVICE_NUM, $2 - DATASET_PATH, $3 - MINDSPORE_HCCL_CONFIG_PATH
  main() {
    print_usage

    (( $# >= 3 )) \
      || die "expected DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PATH, got $# argument(s)"
    # Leading zeros are rejected too: Bash arithmetic would read them as octal.
    [[ "$1" =~ ^[1-9][0-9]*$ ]] \
      || die "DEVICE_NUM must be a positive integer, got '$1'"
    [[ -n "$2" ]] || die "DATASET_PATH must not be empty"
    [[ -n "$3" ]] || die "MINDSPORE_HCCL_CONFIG_PATH must not be empty"

    # Check the inputs of the per-rank copy before any rank is started, so a
    # wrong working directory cannot leave a partially launched job behind.
    [[ -f train.py ]] || die "train.py not found in $PWD"
    [[ -d src ]] || die "src directory not found in $PWD"

    local data_url hccl_config
    data_url=$(absolutize "$2")
    hccl_config=$(absolutize "$3")

    export RANK_SIZE=$1
    export MINDSPORE_HCCL_CONFIG_PATH=$hccl_config

    local i rank_dir
    for ((i = 0; i < RANK_SIZE; i++)); do
      export DEVICE_ID=$i
      export RANK_ID=$i
      rank_dir="./log$i"

      # Start every run from a fresh per-rank working directory.
      rm -rf -- "$rank_dir"
      mkdir -- "$rank_dir" || die "cannot create $rank_dir"
      cp -- ./*.py "$rank_dir" || die "cannot copy *.py into $rank_dir"
      cp -r -- src "$rank_dir" || die "cannot copy src into $rank_dir"

      echo "start training for rank $i, device $DEVICE_ID"
      # The subshell confines the directory change, so the launcher itself
      # never leaves its starting directory.
      (
        cd -- "$rank_dir" || exit 1
        env > env.log
        python -u train.py \
          --dataset_path="$data_url" \
          --ckpt_path="checkpoint" \
          --eval_file_name='auc.log' \
          --loss_file_name='loss.log' \
          --do_eval=True > output.log 2>&1 &
      ) || die "cannot enter $rank_dir"
    done
  }

  main "$@"