Merge pull request !3949 from meixiaowei/mastertags/v0.7.0-beta
| @@ -35,7 +35,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop | |||||
| └─train2017 | └─train2017 | ||||
| ``` | ``` | ||||
| Note that the coco2017 dataset will be converted to MindRecord, which is a data format in MindSpore. The dataset conversion may take about 4 hours. | |||||
| 2. If your own dataset is used, **set dataset to other when running the script.** | 2. If your own dataset is used, **set dataset to other when running the script.** | ||||
| Organize the dataset information into a TXT file, each row in the file is as follows: | Organize the dataset information into a TXT file, each row in the file is as follows: | ||||
| @@ -134,6 +134,7 @@ config = ed({ | |||||
| "loss_scale": 1, | "loss_scale": 1, | ||||
| "momentum": 0.91, | "momentum": 0.91, | ||||
| "weight_decay": 1e-4, | "weight_decay": 1e-4, | ||||
| "pretrain_epoch_size": 0, | |||||
| "epoch_size": 12, | "epoch_size": 12, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_epochs": 1, | "save_checkpoint_epochs": 1, | ||||
| @@ -25,7 +25,7 @@ def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||||
| learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | ||||
| return learning_rate | return learning_rate | ||||
| def dynamic_lr(config, rank_size=1): | |||||
| def dynamic_lr(config, rank_size=1, start_steps=0): | |||||
| """dynamic learning rate generator""" | """dynamic learning rate generator""" | ||||
| base_lr = config.base_lr | base_lr = config.base_lr | ||||
| @@ -38,5 +38,5 @@ def dynamic_lr(config, rank_size=1): | |||||
| lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) | lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) | ||||
| else: | else: | ||||
| lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | ||||
| return lr | |||||
| learning_rate = lr[start_steps:] | |||||
| return learning_rate | |||||
| @@ -108,13 +108,15 @@ if __name__ == '__main__': | |||||
| load_path = args_opt.pre_trained | load_path = args_opt.pre_trained | ||||
| if load_path != "": | if load_path != "": | ||||
| param_dict = load_checkpoint(load_path) | param_dict = load_checkpoint(load_path) | ||||
| for item in list(param_dict.keys()): | |||||
| if not (item.startswith('backbone') or item.startswith('rcnn_mask')): | |||||
| param_dict.pop(item) | |||||
| if config.pretrain_epoch_size == 0: | |||||
| for item in list(param_dict.keys()): | |||||
| if not (item.startswith('backbone') or item.startswith('rcnn_mask')): | |||||
| param_dict.pop(item) | |||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| loss = LossNet() | loss = LossNet() | ||||
| lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32) | |||||
| lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size), | |||||
| mstype.float32) | |||||
| opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | ||||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | weight_decay=config.weight_decay, loss_scale=config.loss_scale) | ||||