Browse Source

Modify the default checkpoint (ckpt) save path

tags/v1.0.0
root 5 years ago
parent
commit
c8200af525
8 changed files with 17 additions and 15 deletions
  1. +3
    -3
      model_zoo/official/cv/mobilenetv2/src/config.py
  2. +6
    -5
      model_zoo/official/cv/mobilenetv2/train.py
  3. +1
    -1
      model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
  4. +1
    -0
      model_zoo/official/nlp/transformer/train.py
  5. +1
    -1
      model_zoo/official/recommend/deepfm/scripts/run_distribute_train.sh
  6. +1
    -1
      model_zoo/official/recommend/deepfm/scripts/run_distribute_train_gpu.sh
  7. +2
    -2
      model_zoo/official/recommend/wide_and_deep/src/config.py
  8. +2
    -2
      model_zoo/official/recommend/wide_and_deep_multitable/src/config.py

+ 3
- 3
model_zoo/official/cv/mobilenetv2/src/config.py View File

@@ -36,7 +36,7 @@ def set_config(args):
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 20, "keep_checkpoint_max": 20,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_path": "./",
"platform": args.platform, "platform": args.platform,
"run_distribute": False "run_distribute": False
}) })
@@ -57,7 +57,7 @@ def set_config(args):
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200, "keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_path": "./",
"platform": args.platform, "platform": args.platform,
"ccl": "nccl", "ccl": "nccl",
"run_distribute": args.run_distribute "run_distribute": args.run_distribute
@@ -79,7 +79,7 @@ def set_config(args):
"save_checkpoint": True, "save_checkpoint": True,
"save_checkpoint_epochs": 1, "save_checkpoint_epochs": 1,
"keep_checkpoint_max": 200, "keep_checkpoint_max": 200,
"save_checkpoint_path": "./checkpoint",
"save_checkpoint_path": "./",
"platform": args.platform, "platform": args.platform,
"ccl": "hccl", "ccl": "hccl",
"device_id": int(os.getenv('DEVICE_ID', '0')), "device_id": int(os.getenv('DEVICE_ID', '0')),


+ 6
- 5
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -24,6 +24,7 @@ from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.communication.management import get_rank
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import save_checkpoint from mindspore.train.serialization import save_checkpoint
@@ -94,10 +95,6 @@ if __name__ == '__main__':
features_path = args_opt.dataset_path + '_features' features_path = args_opt.dataset_path + '_features'
idx_list = list(range(step_size)) idx_list = list(range(step_size))


if os.path.isdir(config.save_checkpoint_path):
os.rename(config.save_checkpoint_path, "{}_{}".format(config.save_checkpoint_path, time.time()))
os.mkdir(config.save_checkpoint_path)

for epoch in range(epoch_size): for epoch in range(epoch_size):
random.shuffle(idx_list) random.shuffle(idx_list)
epoch_start = time.time() epoch_start = time.time()
@@ -112,7 +109,11 @@ if __name__ == '__main__':
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
end="") end="")
if (epoch + 1) % config.save_checkpoint_epochs == 0: if (epoch + 1) % config.save_checkpoint_epochs == 0:
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
rank = 0
if config.run_distribute:
rank = get_rank()
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
save_checkpoint(network, os.path.join(save_ckpt_path, \
f"mobilenetv2_head_{epoch+1}.ckpt")) f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start)) print("total cost {:5.4f} s".format(time.time() - start))




+ 1
- 1
model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini View File

@@ -7,6 +7,6 @@ do_shuffle=true
enable_data_sink=true enable_data_sink=true
data_sink_steps=100 data_sink_steps=100
accumulation_steps=1 accumulation_steps=1
save_checkpoint_path=./checkpoint/
save_checkpoint_path=./
save_checkpoint_steps=10000 save_checkpoint_steps=10000
save_checkpoint_num=1 save_checkpoint_num=1

+ 1
- 0
model_zoo/official/nlp/transformer/train.py View File

@@ -131,6 +131,7 @@ def run_transformer_train():
else: else:
device_num = 1 device_num = 1
rank_id = 0 rank_id = 0
save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_0/')
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num, dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle, rank_id=rank_id, do_shuffle=args.do_shuffle,
dataset_path=args.data_path, dataset_path=args.data_path,


+ 1
- 1
model_zoo/official/recommend/deepfm/scripts/run_distribute_train.sh View File

@@ -36,7 +36,7 @@ do
env > env.log env > env.log
python -u train.py \ python -u train.py \
--dataset_path=$DATA_URL \ --dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--ckpt_path="./" \
--eval_file_name='auc.log' \ --eval_file_name='auc.log' \
--loss_file_name='loss.log' \ --loss_file_name='loss.log' \
--do_eval=True > output.log 2>&1 & --do_eval=True > output.log 2>&1 &


+ 1
- 1
model_zoo/official/recommend/deepfm/scripts/run_distribute_train_gpu.sh View File

@@ -31,7 +31,7 @@ env > env.log
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \ mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python -u train.py \ python -u train.py \
--dataset_path=$DATA_URL \ --dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--ckpt_path="./" \
--eval_file_name='auc.log' \ --eval_file_name='auc.log' \
--loss_file_name='loss.log' \ --loss_file_name='loss.log' \
--device_target='GPU' \ --device_target='GPU' \


+ 2
- 2
model_zoo/official/recommend/wide_and_deep/src/config.py View File

@@ -38,7 +38,7 @@ def argparse_init():
parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.") parser.add_argument("--keep_prob", type=float, default=1.0, help="The keep rate in dropout layer.")
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout") parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
parser.add_argument("--output_path", type=str, default="./output/") parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/", help="The location of the checkpoint file.")
parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.")
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt", parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
help="The strategy checkpoint file.") help="The strategy checkpoint file.")
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.") parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
@@ -77,7 +77,7 @@ class WideDeepConfig():
self.output_path = "./output" self.output_path = "./output"
self.eval_file_name = "eval.log" self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log" self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.ckpt_path = "./"
self.stra_ckpt = './checkpoints/strategy.ckpt' self.stra_ckpt = './checkpoints/strategy.ckpt'
self.host_device_mix = 0 self.host_device_mix = 0
self.dataset_type = "tfrecord" self.dataset_type = "tfrecord"


+ 2
- 2
model_zoo/official/recommend/wide_and_deep_multitable/src/config.py View File

@@ -35,7 +35,7 @@ def argparse_init():
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file. parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
parser.add_argument("--ckpt_path", type=str, default="./") # The location of the checkpoints file.
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file. parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file. parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
return parser return parser
@@ -67,7 +67,7 @@ class WideDeepConfig():
self.output_path = "./output/" self.output_path = "./output/"
self.eval_file_name = "eval.log" self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log" self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.ckpt_path = "./"
def argparse_init(self): def argparse_init(self):
""" """


Loading…
Cancel
Save