Browse Source

Train with Ascend; modify API and debug

tags/v1.0.0
Payne 5 years ago
parent
commit
d6278c2bc6
6 changed files with 13 additions and 12 deletions
  1. model_zoo/official/cv/mobilenetv2/eval.py (+2, -2)
  2. model_zoo/official/cv/mobilenetv2/scripts/run_train.sh (+2, -2)
  3. model_zoo/official/cv/mobilenetv2/src/args.py (+2, -2)
  4. model_zoo/official/cv/mobilenetv2/src/launch.py (+2, -2)
  5. model_zoo/official/cv/mobilenetv2/src/utils.py (+2, -1)
  6. model_zoo/official/cv/mobilenetv2/train.py (+3, -3)

+ 2
- 2
model_zoo/official/cv/mobilenetv2/eval.py View File

@@ -32,7 +32,7 @@ if __name__ == '__main__':


backbone_net = MobileNetV2Backbone(platform=args_opt.platform) backbone_net = MobileNetV2Backbone(platform=args_opt.platform)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(feature_net, head_net)
net = mobilenet_v2(backbone_net, head_net)


#load the trained checkpoint file to the net for evaluation #load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt: if args_opt.head_ckpt:
@@ -51,7 +51,7 @@ if __name__ == '__main__':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, metrics={'acc'}) model = Model(net, loss_fn=loss, metrics={'acc'})


res = model.eval(dataset)
res = model.eval(dataset, dataset_sink_mode=False)
print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}") print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}")
if args_opt.head_ckpt: if args_opt.head_ckpt:
print(f"head_ckpt={args_opt.head_ckpt}") print(f"head_ckpt={args_opt.head_ckpt}")

+ 2
- 2
model_zoo/official/cv/mobilenetv2/scripts/run_train.sh View File

@@ -84,9 +84,9 @@ run_gpu()
run_cpu() run_cpu()
{ {


if [ ! -d $4 ]
if [ ! -d $2 ]
then then
echo "error: DATASET_PATH=$4 is not a directory"
echo "error: DATASET_PATH=$2 is not a directory"
exit 1 exit 1
fi fi




+ 2
- 2
model_zoo/official/cv/mobilenetv2/src/args.py View File

@@ -22,7 +22,7 @@ def launch_parse_args():
that will spawn up multiple distributed processes") that will spawn up multiple distributed processes")
launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend') help='run platform, only support GPU, CPU and Ascend')
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(0, 1, 2, 3, 4, 5, 6, 7), \
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(1, 2, 3, 4, 5, 6, 7, 8), \
help="The number of processes to launch on each node, for D training, this is recommended to be set \ help="The number of processes to launch on each node, for D training, this is recommended to be set \
to the number of D in your system so that each process can be bound to a single D.") to the number of D in your system so that each process can be bound to a single D.")
launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \ launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \
@@ -32,7 +32,7 @@ def launch_parse_args():
the training script") the training script")


launch_args, unknown = launch_parser.parse_known_args() launch_args, unknown = launch_parser.parse_known_args()
launch_args.train_script_args = unknown
launch_args.training_script_args = unknown
launch_args.training_script_args += ["--platform", launch_args.platform] launch_args.training_script_args += ["--platform", launch_args.platform]
return launch_args return launch_args




+ 2
- 2
model_zoo/official/cv/mobilenetv2/src/launch.py View File

@@ -46,8 +46,8 @@ def main():
os.mkdir(device_dir) os.mkdir(device_dir)
os.chdir(device_dir) os.chdir(device_dir)
cmd = [sys.executable, '-u'] cmd = [sys.executable, '-u']
cmd.append(args.train_script)
cmd.extend(args.train_script_args)
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process) processes.append(process)


+ 2
- 1
model_zoo/official/cv/mobilenetv2/src/utils.py View File

@@ -17,8 +17,9 @@ from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode from mindspore.train.model import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init
from mindspore.communication.management import get_rank, init, get_group_size
from src.models import Monitor from src.models import Monitor


+ 3
- 3
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed from mindspore.common import set_seed


from src.dataset import create_dataset, extract_features from src.dataset import create_dataset, extract_features
@@ -88,7 +88,7 @@ if __name__ == '__main__':
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)


network = WithLossCell(net, loss) network = WithLossCell(net, loss)
network = TrainOneStepCell(net, opt)
network = TrainOneStepCell(network, opt)
network.set_train() network.set_train()


features_path = args_opt.dataset_path + '_features' features_path = args_opt.dataset_path + '_features'
@@ -116,7 +116,7 @@ if __name__ == '__main__':
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
end="") end="")
if (epoch + 1) % config.save_checkpoint_epochs == 0: if (epoch + 1) % config.save_checkpoint_epochs == 0:
_exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
f"mobilenetv2_head_{epoch+1}.ckpt")) f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start)) print("total cost {:5.4f} s".format(time.time() - start))




Loading…
Cancel
Save