Merge pull request !6498 from zhaoting/hubtags/v1.0.0
| @@ -0,0 +1,32 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2 | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "mobilenetv2": | |||||
| backbone_net = MobileNetV2Backbone() | |||||
| include_top = kwargs["include_top"] | |||||
| if include_top is None: | |||||
| include_top = True | |||||
| if include_top: | |||||
| activation = kwargs["activation"] | |||||
| head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, | |||||
| num_classes=int(kwargs["num_classes"]), | |||||
| activation=activation) | |||||
| net = mobilenet_v2(backbone_net, head_net) | |||||
| return net | |||||
| return backbone_net | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -48,6 +48,7 @@ def train_parse_args(): | |||||
| for fine tune or incremental learning') | for fine tune or incremental learning') | ||||
| train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | ||||
| train_args = train_parser.parse_args() | train_args = train_parser.parse_args() | ||||
| train_args.is_training = True | |||||
| return train_args | return train_args | ||||
| def eval_parse_args(): | def eval_parse_args(): | ||||
| @@ -61,5 +62,6 @@ def eval_parse_args(): | |||||
| for incremental learning') | for incremental learning') | ||||
| eval_parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='If run distribute in GPU.') | eval_parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='If run distribute in GPU.') | ||||
| eval_args = eval_parser.parse_args() | eval_args = eval_parser.parse_args() | ||||
| eval_args.is_training = False | |||||
| return eval_args | return eval_args | ||||
| @@ -38,7 +38,8 @@ def set_config(args): | |||||
| "keep_checkpoint_max": 20, | "keep_checkpoint_max": 20, | ||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "platform": args.platform, | "platform": args.platform, | ||||
| "run_distribute": False | |||||
| "run_distribute": False, | |||||
| "activation": "Softmax" | |||||
| }) | }) | ||||
| config_gpu = ed({ | config_gpu = ed({ | ||||
| "num_classes": 1000, | "num_classes": 1000, | ||||
| @@ -60,7 +61,8 @@ def set_config(args): | |||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "platform": args.platform, | "platform": args.platform, | ||||
| "ccl": "nccl", | "ccl": "nccl", | ||||
| "run_distribute": args.run_distribute | |||||
| "run_distribute": args.run_distribute, | |||||
| "activation": "Softmax" | |||||
| }) | }) | ||||
| config_ascend = ed({ | config_ascend = ed({ | ||||
| "num_classes": 1000, | "num_classes": 1000, | ||||
| @@ -85,7 +87,8 @@ def set_config(args): | |||||
| "device_id": int(os.getenv('DEVICE_ID', '0')), | "device_id": int(os.getenv('DEVICE_ID', '0')), | ||||
| "rank_id": int(os.getenv('RANK_ID', '0')), | "rank_id": int(os.getenv('RANK_ID', '0')), | ||||
| "rank_size": int(os.getenv('RANK_SIZE', '1')), | "rank_size": int(os.getenv('RANK_SIZE', '1')), | ||||
| "run_distribute": int(os.getenv('RANK_SIZE', '1')) > 1. | |||||
| "run_distribute": int(os.getenv('RANK_SIZE', '1')) > 1., | |||||
| "activation": "Softmax" | |||||
| }) | }) | ||||
| config = ed({"CPU": config_cpu, | config = ed({"CPU": config_cpu, | ||||
| "GPU": config_gpu, | "GPU": config_gpu, | ||||
| @@ -242,16 +242,25 @@ class MobileNetV2Head(nn.Cell): | |||||
| >>> MobileNetV2(num_classes=1000) | >>> MobileNetV2(num_classes=1000) | ||||
| """ | """ | ||||
| def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False): | |||||
| def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"): | |||||
| super(MobileNetV2Head, self).__init__() | super(MobileNetV2Head, self).__init__() | ||||
| # mobilenet head | # mobilenet head | ||||
| head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else | head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else | ||||
| [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)]) | [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)]) | ||||
| self.head = nn.SequentialCell(head) | self.head = nn.SequentialCell(head) | ||||
| self.need_activation = True | |||||
| if activation == "Sigmoid": | |||||
| self.activation = P.Sigmoid() | |||||
| elif activation == "Softmax": | |||||
| self.activation = P.Softmax() | |||||
| else: | |||||
| self.need_activation = False | |||||
| self._initialize_weights() | self._initialize_weights() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.head(x) | x = self.head(x) | ||||
| if self.need_activation: | |||||
| x = self.activation(x) | |||||
| return x | return x | ||||
| def _initialize_weights(self): | def _initialize_weights(self): | ||||
| @@ -121,7 +121,7 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): | |||||
| def define_net(config): | def define_net(config): | ||||
| backbone_net = MobileNetV2Backbone() | backbone_net = MobileNetV2Backbone() | ||||
| activation = config.activation if not args.is_training else "None" | |||||
| head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) | head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) | ||||
| net = mobilenet_v2(backbone_net, head_net) | |||||
| net = mobilenet_v2(backbone_net, head_net, activation=activation) | |||||
| return backbone_net, head_net, net | return backbone_net, head_net, net | ||||
| @@ -158,7 +158,7 @@ Parameters for both training and evaluation can be set in config.py. | |||||
| "epoch_size": 90, # only valid for taining, which is always 1 for inference | "epoch_size": 90, # only valid for taining, which is always 1 for inference | ||||
| "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size | "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size | ||||
| "save_checkpoint": True, # whether save checkpoint or not | "save_checkpoint": True, # whether save checkpoint or not | ||||
| "save_checkpoint_steps": 195, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step | |||||
| "save_checkpoint_epochs": 5, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last step | |||||
| "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | ||||
| "save_checkpoint_path": "./", # path to save checkpoint | "save_checkpoint_path": "./", # path to save checkpoint | ||||
| "warmup_epochs": 5, # number of warmup epoch | "warmup_epochs": 5, # number of warmup epoch | ||||
| @@ -179,15 +179,16 @@ Parameters for both training and evaluation can be set in config.py. | |||||
| "epoch_size": 90, # only valid for taining, which is always 1 for inference | "epoch_size": 90, # only valid for taining, which is always 1 for inference | ||||
| "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size | "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size | ||||
| "save_checkpoint": True, # whether save checkpoint or not | "save_checkpoint": True, # whether save checkpoint or not | ||||
| "save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch | |||||
| "save_checkpoint_epochs": 5, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch | |||||
| "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | ||||
| "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path | "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path | ||||
| "warmup_epochs": 0, # number of warmup epoch | "warmup_epochs": 0, # number of warmup epoch | ||||
| "lr_decay_mode": "cosine", # decay mode for generating learning rate | |||||
| "lr_decay_mode": "Linear", # decay mode for generating learning rate | |||||
| "label_smooth": True, # label smooth | "label_smooth": True, # label smooth | ||||
| "label_smooth_factor": 0.1, # label smooth factor | "label_smooth_factor": 0.1, # label smooth factor | ||||
| "lr_init": 0, # initial learning rate | "lr_init": 0, # initial learning rate | ||||
| "lr_max": 0.1, # maximum learning rate | "lr_max": 0.1, # maximum learning rate | ||||
| "lr_end": 0.0, # minimum learning rate | |||||
| ``` | ``` | ||||
| - Config for ResNet101, ImageNet2012 dataset | - Config for ResNet101, ImageNet2012 dataset | ||||
| @@ -201,7 +202,7 @@ Parameters for both training and evaluation can be set in config.py. | |||||
| "epoch_size": 120, # epoch size for training | "epoch_size": 120, # epoch size for training | ||||
| "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size | "pretrain_epoch_size": 0, # epoch size that model has been trained before loading pretrained checkpoint, actual training epoch size is equal to epoch_size minus pretrain_epoch_size | ||||
| "save_checkpoint": True, # whether save checkpoint or not | "save_checkpoint": True, # whether save checkpoint or not | ||||
| "save_checkpoint_epochs": 1, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch | |||||
| "save_checkpoint_epochs": 5, # the epoch interval between two checkpoints. By default, the last checkpoint will be saved after the last epoch | |||||
| "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | "keep_checkpoint_max": 10, # only keep the last keep_checkpoint_max checkpoint | ||||
| "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path | "save_checkpoint_path": "./", # path to save checkpoint relative to the executed path | ||||
| "warmup_epochs": 0, # number of warmup epoch | "warmup_epochs": 0, # number of warmup epoch | ||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.ssd import SSD300, ssd_mobilenet_v2 | |||||
| from src.config import config | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "ssd300": | |||||
| backbone = ssd_mobilenet_v2() | |||||
| ssd = SSD300(backbone=backbone, config=config, *args, **kwargs) | |||||
| return ssd | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||