| @@ -26,9 +26,8 @@ tile_op_info = TBERegOp("Tile") \ | |||||
| .attr("multiples", "optional", "listInt", "all")\ | .attr("multiples", "optional", "listInt", "all")\ | ||||
| .input(0, "x1", False, "required", "all") \ | .input(0, "x1", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -30,7 +30,8 @@ class CropAndResize(PrimitiveWithInfer): | |||||
| Args: | Args: | ||||
| method (str): An optional string that specifies the sampling method for resizing. | method (str): An optional string that specifies the sampling method for resizing. | ||||
| It can be either "bilinear" or "nearest". Default: "bilinear" | |||||
| It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for standard bilinear | |||||
| interpolation algorithm, while "bilinear_v2" may result in better result in some cases. Default: "bilinear" | |||||
| extrapolation_value (float): An optional float value used extrapolation, if applicable. Default: 0. | extrapolation_value (float): An optional float value used extrapolation, if applicable. Default: 0. | ||||
| Inputs: | Inputs: | ||||
| @@ -81,7 +82,7 @@ class CropAndResize(PrimitiveWithInfer): | |||||
| """init CropAndResize""" | """init CropAndResize""" | ||||
| self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y']) | self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y']) | ||||
| validator.check_value_type("method", method, [str], self.name) | validator.check_value_type("method", method, [str], self.name) | ||||
| validator.check_string("method", method, ["bilinear", "nearest"], self.name) | |||||
| validator.check_string("method", method, ["bilinear", "nearest", "bilinear_v2"], self.name) | |||||
| self.method = method | self.method = method | ||||
| validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name) | validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name) | ||||
| self.extrapolation_value = extrapolation_value | self.extrapolation_value = extrapolation_value | ||||
| @@ -406,9 +406,14 @@ def create_coco_label(is_training): | |||||
| image_anno_dict = {} | image_anno_dict = {} | ||||
| masks = {} | masks = {} | ||||
| masks_shape = {} | masks_shape = {} | ||||
| for img_id in image_ids: | |||||
| images_num = len(image_ids) | |||||
| for ind, img_id in enumerate(image_ids): | |||||
| image_info = coco.loadImgs(img_id) | image_info = coco.loadImgs(img_id) | ||||
| file_name = image_info[0]["file_name"] | file_name = image_info[0]["file_name"] | ||||
| image_path = os.path.join(coco_root, data_type, file_name) | |||||
| if not os.path.isfile(image_path): | |||||
| print("{}/{}: {} is in annotations but not exist".format(ind + 1, images_num, image_path)) | |||||
| continue | |||||
| anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) | anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) | ||||
| anno = coco.loadAnns(anno_ids) | anno = coco.loadAnns(anno_ids) | ||||
| image_path = os.path.join(coco_root, data_type, file_name) | image_path = os.path.join(coco_root, data_type, file_name) | ||||
| @@ -416,7 +421,8 @@ def create_coco_label(is_training): | |||||
| instance_masks = [] | instance_masks = [] | ||||
| image_height = coco.imgs[img_id]["height"] | image_height = coco.imgs[img_id]["height"] | ||||
| image_width = coco.imgs[img_id]["width"] | image_width = coco.imgs[img_id]["width"] | ||||
| print("image file name: ", file_name) | |||||
| if (ind + 1) % 10 == 0: | |||||
| print("{}/{}: parsing annotation for image={}".format(ind + 1, images_num, file_name)) | |||||
| if not is_training: | if not is_training: | ||||
| image_files.append(image_path) | image_files.append(image_path) | ||||
| image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) | image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) | ||||
| @@ -478,13 +484,16 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="mask | |||||
| } | } | ||||
| writer.add_schema(maskrcnn_json, "maskrcnn_json") | writer.add_schema(maskrcnn_json, "maskrcnn_json") | ||||
| for image_name in image_files: | |||||
| image_files_num = len(image_files) | |||||
| for ind, image_name in enumerate(image_files): | |||||
| with open(image_name, 'rb') as f: | with open(image_name, 'rb') as f: | ||||
| img = f.read() | img = f.read() | ||||
| annos = np.array(image_anno_dict[image_name], dtype=np.int32) | annos = np.array(image_anno_dict[image_name], dtype=np.int32) | ||||
| mask = masks[image_name] | mask = masks[image_name] | ||||
| mask_shape = masks_shape[image_name] | mask_shape = masks_shape[image_name] | ||||
| row = {"image": img, "annotation": annos, "mask": mask, "mask_shape": mask_shape} | row = {"image": img, "annotation": annos, "mask": mask, "mask_shape": mask_shape} | ||||
| if (ind + 1) % 10 == 0: | |||||
| print("writing {}/{} into mindrecord".format(ind + 1, image_files_num)) | |||||
| writer.write_raw_data([row]) | writer.write_raw_data([row]) | ||||
| writer.commit() | writer.commit() | ||||
| @@ -108,7 +108,7 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| self.round = P.Round() | self.round = P.Round() | ||||
| self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16) | self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16) | ||||
| self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2) | self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2) | ||||
| self.crop_and_resize = P.CropAndResize() | |||||
| self.crop_and_resize = P.CropAndResize(method="bilinear_v2") | |||||
| self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1]) | self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1]) | ||||
| self.squeeze_mask_last = P.Squeeze(axis=-1) | self.squeeze_mask_last = P.Squeeze(axis=-1) | ||||
| def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i): | def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i): | ||||
| @@ -84,9 +84,10 @@ class FeatPyramidNeck(nn.Cell): | |||||
| self.fpn_convs_.append(fpn_conv) | self.fpn_convs_.append(fpn_conv) | ||||
| self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) | self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) | ||||
| self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) | self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) | ||||
| self.interpolate1 = P.ResizeNearestNeighbor((48, 80)) | |||||
| self.interpolate2 = P.ResizeNearestNeighbor((96, 160)) | |||||
| self.interpolate3 = P.ResizeNearestNeighbor((192, 320)) | |||||
| self.interpolate1 = P.ResizeBilinear((48, 80)) | |||||
| self.interpolate2 = P.ResizeBilinear((96, 160)) | |||||
| self.interpolate3 = P.ResizeBilinear((192, 320)) | |||||
| self.cast = P.Cast() | |||||
| self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same") | self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same") | ||||
| def construct(self, inputs): | def construct(self, inputs): | ||||
| @@ -95,9 +96,9 @@ class FeatPyramidNeck(nn.Cell): | |||||
| x += (self.lateral_convs_list[i](inputs[i]),) | x += (self.lateral_convs_list[i](inputs[i]),) | ||||
| y = (x[3],) | y = (x[3],) | ||||
| y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),) | |||||
| y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),) | |||||
| y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),) | |||||
| y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), mstype.float16),) | |||||
| y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), mstype.float16),) | |||||
| y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), mstype.float16),) | |||||
| z = () | z = () | ||||
| for i in range(self.fpn_layer - 1, -1, -1): | for i in range(self.fpn_layer - 1, -1, -1): | ||||
| @@ -247,7 +247,6 @@ def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_clas | |||||
| else: | else: | ||||
| img_h = np.round(ori_shape[0] * scale_factor[0]).astype(np.int32) | img_h = np.round(ori_shape[0] * scale_factor[0]).astype(np.int32) | ||||
| img_w = np.round(ori_shape[1] * scale_factor[1]).astype(np.int32) | img_w = np.round(ori_shape[1] * scale_factor[1]).astype(np.int32) | ||||
| scale_factor = 1.0 | |||||
| for i in range(bboxes.shape[0]): | for i in range(bboxes.shape[0]): | ||||
| bbox = (bboxes[i, :] / 1.0).astype(np.int32) | bbox = (bboxes[i, :] / 1.0).astype(np.int32) | ||||
| @@ -256,6 +255,10 @@ def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_clas | |||||
| h = max(bbox[3] - bbox[1] + 1, 1) | h = max(bbox[3] - bbox[1] + 1, 1) | ||||
| w = min(w, img_w - bbox[0]) | w = min(w, img_w - bbox[0]) | ||||
| h = min(h, img_h - bbox[1]) | h = min(h, img_h - bbox[1]) | ||||
| if w <= 0 or h <= 0: | |||||
| print("there is invalid proposal bbox, index={} bbox={} w={} h={}".format(i, bbox, w, h)) | |||||
| w = max(w, 1) | |||||
| h = max(h, 1) | |||||
| mask_pred_ = mask_pred[i, :, :] | mask_pred_ = mask_pred[i, :, :] | ||||
| im_mask = np.zeros((img_h, img_w), dtype=np.uint8) | im_mask = np.zeros((img_h, img_w), dtype=np.uint8) | ||||
| bbox_mask = mmcv.imresize(mask_pred_, (w, h)) | bbox_mask = mmcv.imresize(mask_pred_, (w, h)) | ||||
| @@ -16,6 +16,7 @@ | |||||
| """train MaskRcnn and get checkpoint files.""" | """train MaskRcnn and get checkpoint files.""" | ||||
| import os | import os | ||||
| import time | |||||
| import argparse | import argparse | ||||
| import ast | import ast | ||||
| @@ -26,7 +27,7 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMoni | |||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.nn import SGD | |||||
| from mindspore.nn import Momentum | |||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | ||||
| @@ -71,7 +72,7 @@ if __name__ == '__main__': | |||||
| prefix = "MaskRcnn.mindrecord" | prefix = "MaskRcnn.mindrecord" | ||||
| mindrecord_dir = config.mindrecord_dir | mindrecord_dir = config.mindrecord_dir | ||||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | ||||
| if not os.path.exists(mindrecord_file): | |||||
| if rank == 0 and not os.path.exists(mindrecord_file): | |||||
| if not os.path.isdir(mindrecord_dir): | if not os.path.isdir(mindrecord_dir): | ||||
| os.makedirs(mindrecord_dir) | os.makedirs(mindrecord_dir) | ||||
| if args_opt.dataset == "coco": | if args_opt.dataset == "coco": | ||||
| @@ -80,14 +81,16 @@ if __name__ == '__main__': | |||||
| data_to_mindrecord_byte_image("coco", True, prefix) | data_to_mindrecord_byte_image("coco", True, prefix) | ||||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | ||||
| else: | else: | ||||
| print("coco_root not exits.") | |||||
| raise Exception("coco_root not exits.") | |||||
| else: | else: | ||||
| if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): | if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): | ||||
| print("Create Mindrecord.") | print("Create Mindrecord.") | ||||
| data_to_mindrecord_byte_image("other", True, prefix) | data_to_mindrecord_byte_image("other", True, prefix) | ||||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | ||||
| else: | else: | ||||
| print("IMAGE_DIR or ANNO_PATH not exits.") | |||||
| raise Exception("IMAGE_DIR or ANNO_PATH not exits.") | |||||
| while not os.path.exists(mindrecord_file+".db"): | |||||
| time.sleep(5) | |||||
| if not args_opt.only_create_dataset: | if not args_opt.only_create_dataset: | ||||
| loss_scale = float(config.loss_scale) | loss_scale = float(config.loss_scale) | ||||
| @@ -115,8 +118,8 @@ if __name__ == '__main__': | |||||
| loss = LossNet() | loss = LossNet() | ||||
| lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size), | lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size), | ||||
| mstype.float32) | mstype.float32) | ||||
| opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | |||||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | |||||
| opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | |||||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | |||||
| net_with_loss = WithLossCell(net, loss) | net_with_loss = WithLossCell(net, loss) | ||||
| if args_opt.run_distribute: | if args_opt.run_distribute: | ||||