Browse Source

debug load_ckpt,rename the network parameters;move the raiseerror of dataset size from dataset.py to train.py, give the control switch to user

tags/v1.0.0
Payne 5 years ago
parent
commit
ac8f7734bb
6 changed files with 42 additions and 35 deletions
  1. +5
    -1
      model_zoo/official/cv/mobilenetv2/eval.py
  2. +3
    -3
      model_zoo/official/cv/mobilenetv2/scripts/run_eval.sh
  3. +6
    -7
      model_zoo/official/cv/mobilenetv2/src/dataset.py
  4. +9
    -9
      model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py
  5. +1
    -12
      model_zoo/official/cv/mobilenetv2/src/models.py
  6. +18
    -3
      model_zoo/official/cv/mobilenetv2/train.py

+ 5
- 1
model_zoo/official/cv/mobilenetv2/eval.py View File

@@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context
if __name__ == '__main__': if __name__ == '__main__':
args_opt = eval_parse_args() args_opt = eval_parse_args()
config = set_config(args_opt) config = set_config(args_opt)
backbone_net, head_net, net = define_net(args_opt, config)
backbone_net, head_net, net = define_net(config)


#load the trained checkpoint file to the net for evaluation #load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt: if args_opt.head_ckpt:
@@ -42,6 +42,10 @@ if __name__ == '__main__':


dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config) dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
if step_size == 0:
raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
than batch_size in config.py")

net.set_train(False) net.set_train(False)


loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


+ 3
- 3
model_zoo/official/cv/mobilenetv2/scripts/run_eval.sh View File

@@ -103,9 +103,9 @@ run_cpu()
if [ $# -gt 4 ] || [ $# -lt 3 ] if [ $# -gt 4 ] || [ $# -lt 3 ]
then then
echo "Usage: echo "Usage:
Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
GPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
CPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]"
Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]"
exit 1 exit 1
fi fi




+ 6
- 7
model_zoo/official/cv/mobilenetv2/src/dataset.py View File

@@ -36,7 +36,6 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1):
config(struct): the config of train and eval in diffirent platform. config(struct): the config of train and eval in diffirent platform.
repeat_num(int): the repeat times of dataset. Default: 1. repeat_num(int): the repeat times of dataset. Default: 1.



Returns: Returns:
dataset dataset
""" """
@@ -96,11 +95,7 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1):
# apply dataset repeat operation # apply dataset repeat operation
ds = ds.repeat(repeat_num) ds = ds.repeat(repeat_num)


step_size = ds.get_dataset_size()
if step_size == 0:
raise ValueError("The step_size of dataset is zero. Check if the images of train dataset is more than batch_\
size in config.py")
return ds, step_size
return ds




def extract_features(net, dataset_path, config): def extract_features(net, dataset_path, config):
@@ -112,12 +107,16 @@ def extract_features(net, dataset_path, config):
config=config, config=config,
repeat_num=1) repeat_num=1)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
if step_size == 0:
raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
than batch_size in config.py")

model = Model(net) model = Model(net)


for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)):
features_path = os.path.join(features_folder, f"feature_{i}.npy") features_path = os.path.join(features_folder, f"feature_{i}.npy")
label_path = os.path.join(features_folder, f"label_{i}.npy") label_path = os.path.join(features_folder, f"label_{i}.npy")
if not os.path.exists(features_path or not os.path.exists(label_path)):
if not os.path.exists(features_path) or not os.path.exists(label_path):
image = data["image"] image = data["image"]
label = data["label"] label = data["label"]
features = model.predict(Tensor(image)) features = model.predict(Tensor(image))


+ 9
- 9
model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py View File

@@ -284,8 +284,11 @@ class MobileNetV2(nn.Cell):
MobileNetV2 architecture. MobileNetV2 architecture.


Args: Args:
backbone(nn.Cell):
head(nn.Cell):
class_num (Cell): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
Returns: Returns:
Tensor, output tensor. Tensor, output tensor.


@@ -310,14 +313,11 @@ class MobileNetV2(nn.Cell):


class MobileNetV2Combine(nn.Cell): class MobileNetV2Combine(nn.Cell):
""" """
MobileNetV2 architecture.
MobileNetV2Combine architecture.


Args: Args:
class_num (Cell): number of classes.
width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1.
has_dropout (bool): Is dropout used. Default is false
inverted_residual_setting (list): Inverted residual settings. Default is None
round_nearest (list): Channel round to . Default is 8
backbone (Cell): the features extract layers.
head (Cell): the fully connected layers.
Returns: Returns:
Tensor, output tensor. Tensor, output tensor.


@@ -326,7 +326,7 @@ class MobileNetV2Combine(nn.Cell):
""" """


def __init__(self, backbone, head): def __init__(self, backbone, head):
super(MobileNetV2Combine, self).__init__()
super(MobileNetV2Combine, self).__init__(auto_prefix=False)
self.backbone = backbone self.backbone = backbone
self.head = head self.head = head




+ 1
- 12
model_zoo/official/cv/mobilenetv2/src/models.py View File

@@ -119,20 +119,9 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
for param in network.get_parameters(): for param in network.get_parameters():
param.requires_grad = False param.requires_grad = False


def define_net(args, config):
def define_net(config):
backbone_net = MobileNetV2Backbone() backbone_net = MobileNetV2Backbone()
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(backbone_net, head_net) net = mobilenet_v2(backbone_net, head_net)


# load the ckpt file to the network for fine tune or incremental leaning
if args.pretrain_ckpt:
if args.train_method == "fine_tune":
load_ckpt(net, args.pretrain_ckpt)
elif args.train_method == "incremental_learn":
load_ckpt(backbone_net, args.pretrain_ckpt, trainable=False)
elif args.train_method == "train":
pass
else:
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")

return backbone_net, head_net, net return backbone_net, head_net, net

+ 18
- 3
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -35,7 +35,7 @@ from src.config import set_config


from src.args import train_parse_args from src.args import train_parse_args
from src.utils import context_device_init, switch_precision, config_ckpoint from src.utils import context_device_init, switch_precision, config_ckpoint
from src.models import CrossEntropyWithLabelSmooth, define_net
from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt


set_seed(1) set_seed(1)


@@ -50,7 +50,18 @@ if __name__ == '__main__':
context_device_init(config) context_device_init(config)


# define network # define network
backbone_net, head_net, net = define_net(args_opt, config)
backbone_net, head_net, net = define_net(config)

# load the ckpt file to the network for fine tune or incremental leaning
if args_opt.pretrain_ckpt:
if args_opt.train_method == "fine_tune":
load_ckpt(net, args_opt.pretrain_ckpt)
elif args_opt.train_method == "incremental_learn":
load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
elif args_opt.train_method == "train":
pass
else:
raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")


# CPU only support "incremental_learn" # CPU only support "incremental_learn"
if args_opt.train_method == "incremental_learn": if args_opt.train_method == "incremental_learn":
@@ -60,7 +71,11 @@ if __name__ == '__main__':
elif args_opt.train_method in ("train", "fine_tune"): elif args_opt.train_method in ("train", "fine_tune"):
if args_opt.platform == "CPU": if args_opt.platform == "CPU":
raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".") raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".")
dataset, step_size = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
step_size = dataset.get_dataset_size()
if step_size == 0:
raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
than batch_size in config.py")


# Currently, only Ascend support switch precision. # Currently, only Ascend support switch precision.
switch_precision(net, mstype.float16, config) switch_precision(net, mstype.float16, config)


Loading…
Cancel
Save