Browse Source

!5849 mobilenetv2 modify api and debug

Merge pull request !5849 from yepei6/mobilenetv2
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d14221ea8b
6 changed files with 11 additions and 10 deletions
  1. +2
    -2
      model_zoo/official/cv/mobilenetv2/eval.py
  2. +2
    -2
      model_zoo/official/cv/mobilenetv2/scripts/run_train.sh
  3. +2
    -2
      model_zoo/official/cv/mobilenetv2/src/args.py
  4. +2
    -2
      model_zoo/official/cv/mobilenetv2/src/launch.py
  5. +2
    -1
      model_zoo/official/cv/mobilenetv2/src/utils.py
  6. +1
    -1
      model_zoo/official/cv/mobilenetv2/train.py

+ 2
- 2
model_zoo/official/cv/mobilenetv2/eval.py View File

@@ -32,7 +32,7 @@ if __name__ == '__main__':


backbone_net = MobileNetV2Backbone(platform=args_opt.platform) backbone_net = MobileNetV2Backbone(platform=args_opt.platform)
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(feature_net, head_net)
net = mobilenet_v2(backbone_net, head_net)


#load the trained checkpoint file to the net for evaluation #load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt: if args_opt.head_ckpt:
@@ -51,7 +51,7 @@ if __name__ == '__main__':
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = Model(net, loss_fn=loss, metrics={'acc'}) model = Model(net, loss_fn=loss, metrics={'acc'})


res = model.eval(dataset)
res = model.eval(dataset, dataset_sink_mode=False)
print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}") print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}")
if args_opt.head_ckpt: if args_opt.head_ckpt:
print(f"head_ckpt={args_opt.head_ckpt}") print(f"head_ckpt={args_opt.head_ckpt}")

+ 2
- 2
model_zoo/official/cv/mobilenetv2/scripts/run_train.sh View File

@@ -84,9 +84,9 @@ run_gpu()
run_cpu() run_cpu()
{ {


if [ ! -d $4 ]
if [ ! -d $2 ]
then then
echo "error: DATASET_PATH=$4 is not a directory"
echo "error: DATASET_PATH=$2 is not a directory"
exit 1 exit 1
fi fi




+ 2
- 2
model_zoo/official/cv/mobilenetv2/src/args.py View File

@@ -22,7 +22,7 @@ def launch_parse_args():
that will spawn up multiple distributed processes") that will spawn up multiple distributed processes")
launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \ launch_parser.add_argument('--platform', type=str, default="Ascend", choices=("Ascend", "GPU", "CPU"), \
help='run platform, only support GPU, CPU and Ascend') help='run platform, only support GPU, CPU and Ascend')
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(0, 1, 2, 3, 4, 5, 6, 7), \
launch_parser.add_argument("--nproc_per_node", type=int, default=1, choices=(1, 2, 3, 4, 5, 6, 7, 8), \
help="The number of processes to launch on each node, for D training, this is recommended to be set \ help="The number of processes to launch on each node, for D training, this is recommended to be set \
to the number of D in your system so that each process can be bound to a single D.") to the number of D in your system so that each process can be bound to a single D.")
launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \ launch_parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the \
@@ -32,7 +32,7 @@ def launch_parse_args():
the training script") the training script")


launch_args, unknown = launch_parser.parse_known_args() launch_args, unknown = launch_parser.parse_known_args()
launch_args.train_script_args = unknown
launch_args.training_script_args = unknown
launch_args.training_script_args += ["--platform", launch_args.platform] launch_args.training_script_args += ["--platform", launch_args.platform]
return launch_args return launch_args




+ 2
- 2
model_zoo/official/cv/mobilenetv2/src/launch.py View File

@@ -46,8 +46,8 @@ def main():
os.mkdir(device_dir) os.mkdir(device_dir)
os.chdir(device_dir) os.chdir(device_dir)
cmd = [sys.executable, '-u'] cmd = [sys.executable, '-u']
cmd.append(args.train_script)
cmd.extend(args.train_script_args)
cmd.append(args.training_script)
cmd.extend(args.training_script_args)
log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w') log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env) process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
processes.append(process) processes.append(process)


+ 2
- 1
model_zoo/official/cv/mobilenetv2/src/utils.py View File

@@ -17,8 +17,9 @@ from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import ParallelMode from mindspore.train.model import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import get_rank, init
from mindspore.communication.management import get_rank, init, get_group_size
from src.models import Monitor from src.models import Monitor


+ 1
- 1
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -88,7 +88,7 @@ if __name__ == '__main__':
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)


network = WithLossCell(net, loss) network = WithLossCell(net, loss)
network = TrainOneStepCell(net, opt)
network = TrainOneStepCell(network, opt)
network.set_train() network.set_train()


features_path = args_opt.dataset_path + '_features' features_path = args_opt.dataset_path + '_features'


Loading…
Cancel
Save