| @@ -26,9 +26,8 @@ tile_op_info = TBERegOp("Tile") \ | |||
| .attr("multiples", "optional", "listInt", "all")\ | |||
| .input(0, "x1", False, "required", "all") \ | |||
| .output(0, "y", False, "required", "all") \ | |||
| .dtype_format(DataType.I32_Default, DataType.I32_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||
| .op_pattern("dynamicFormat") \ | |||
| .dtype_format(DataType.None_None, DataType.None_None) \ | |||
| .get_op_info() | |||
| @@ -30,7 +30,8 @@ class CropAndResize(PrimitiveWithInfer): | |||
| Args: | |||
| method (str): An optional string that specifies the sampling method for resizing. | |||
| It can be either "bilinear" or "nearest". Default: "bilinear" | |||
| It can be "bilinear", "nearest" or "bilinear_v2". The option "bilinear" stands for standard bilinear | |||
| interpolation algorithm, while "bilinear_v2" may result in better result in some cases. Default: "bilinear" | |||
| extrapolation_value (float): An optional float value used extrapolation, if applicable. Default: 0. | |||
| Inputs: | |||
| @@ -81,7 +82,7 @@ class CropAndResize(PrimitiveWithInfer): | |||
| """init CropAndResize""" | |||
| self.init_prim_io_names(inputs=['x', 'boxes', 'box_index', 'crop_size'], outputs=['y']) | |||
| validator.check_value_type("method", method, [str], self.name) | |||
| validator.check_string("method", method, ["bilinear", "nearest"], self.name) | |||
| validator.check_string("method", method, ["bilinear", "nearest", "bilinear_v2"], self.name) | |||
| self.method = method | |||
| validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name) | |||
| self.extrapolation_value = extrapolation_value | |||
| @@ -406,9 +406,14 @@ def create_coco_label(is_training): | |||
| image_anno_dict = {} | |||
| masks = {} | |||
| masks_shape = {} | |||
| for img_id in image_ids: | |||
| images_num = len(image_ids) | |||
| for ind, img_id in enumerate(image_ids): | |||
| image_info = coco.loadImgs(img_id) | |||
| file_name = image_info[0]["file_name"] | |||
| image_path = os.path.join(coco_root, data_type, file_name) | |||
| if not os.path.isfile(image_path): | |||
| print("{}/{}: {} is in annotations but not exist".format(ind + 1, images_num, image_path)) | |||
| continue | |||
| anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) | |||
| anno = coco.loadAnns(anno_ids) | |||
| image_path = os.path.join(coco_root, data_type, file_name) | |||
| @@ -416,7 +421,8 @@ def create_coco_label(is_training): | |||
| instance_masks = [] | |||
| image_height = coco.imgs[img_id]["height"] | |||
| image_width = coco.imgs[img_id]["width"] | |||
| print("image file name: ", file_name) | |||
| if (ind + 1) % 10 == 0: | |||
| print("{}/{}: parsing annotation for image={}".format(ind + 1, images_num, file_name)) | |||
| if not is_training: | |||
| image_files.append(image_path) | |||
| image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1]) | |||
| @@ -478,13 +484,16 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="mask | |||
| } | |||
| writer.add_schema(maskrcnn_json, "maskrcnn_json") | |||
| for image_name in image_files: | |||
| image_files_num = len(image_files) | |||
| for ind, image_name in enumerate(image_files): | |||
| with open(image_name, 'rb') as f: | |||
| img = f.read() | |||
| annos = np.array(image_anno_dict[image_name], dtype=np.int32) | |||
| mask = masks[image_name] | |||
| mask_shape = masks_shape[image_name] | |||
| row = {"image": img, "annotation": annos, "mask": mask, "mask_shape": mask_shape} | |||
| if (ind + 1) % 10 == 0: | |||
| print("writing {}/{} into mindrecord".format(ind + 1, image_files_num)) | |||
| writer.write_raw_data([row]) | |||
| writer.commit() | |||
| @@ -108,7 +108,7 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||
| self.round = P.Round() | |||
| self.image_h_w = Tensor([cfg.img_height, cfg.img_width, cfg.img_height, cfg.img_width], dtype=mstype.float16) | |||
| self.range = nn.Range(start=0, limit=cfg.num_expected_pos_stage2) | |||
| self.crop_and_resize = P.CropAndResize() | |||
| self.crop_and_resize = P.CropAndResize(method="bilinear_v2") | |||
| self.mask_shape = (cfg.mask_shape[0], cfg.mask_shape[1]) | |||
| self.squeeze_mask_last = P.Squeeze(axis=-1) | |||
| def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids, gt_masks_i): | |||
| @@ -84,9 +84,10 @@ class FeatPyramidNeck(nn.Cell): | |||
| self.fpn_convs_.append(fpn_conv) | |||
| self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) | |||
| self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) | |||
| self.interpolate1 = P.ResizeNearestNeighbor((48, 80)) | |||
| self.interpolate2 = P.ResizeNearestNeighbor((96, 160)) | |||
| self.interpolate3 = P.ResizeNearestNeighbor((192, 320)) | |||
| self.interpolate1 = P.ResizeBilinear((48, 80)) | |||
| self.interpolate2 = P.ResizeBilinear((96, 160)) | |||
| self.interpolate3 = P.ResizeBilinear((192, 320)) | |||
| self.cast = P.Cast() | |||
| self.maxpool = P.MaxPool(ksize=1, strides=2, padding="same") | |||
| def construct(self, inputs): | |||
| @@ -95,9 +96,9 @@ class FeatPyramidNeck(nn.Cell): | |||
| x += (self.lateral_convs_list[i](inputs[i]),) | |||
| y = (x[3],) | |||
| y = y + (x[2] + self.interpolate1(y[self.fpn_layer - 4]),) | |||
| y = y + (x[1] + self.interpolate2(y[self.fpn_layer - 3]),) | |||
| y = y + (x[0] + self.interpolate3(y[self.fpn_layer - 2]),) | |||
| y = y + (x[2] + self.cast(self.interpolate1(y[self.fpn_layer - 4]), mstype.float16),) | |||
| y = y + (x[1] + self.cast(self.interpolate2(y[self.fpn_layer - 3]), mstype.float16),) | |||
| y = y + (x[0] + self.cast(self.interpolate3(y[self.fpn_layer - 2]), mstype.float16),) | |||
| z = () | |||
| for i in range(self.fpn_layer - 1, -1, -1): | |||
| @@ -247,7 +247,6 @@ def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_clas | |||
| else: | |||
| img_h = np.round(ori_shape[0] * scale_factor[0]).astype(np.int32) | |||
| img_w = np.round(ori_shape[1] * scale_factor[1]).astype(np.int32) | |||
| scale_factor = 1.0 | |||
| for i in range(bboxes.shape[0]): | |||
| bbox = (bboxes[i, :] / 1.0).astype(np.int32) | |||
| @@ -256,6 +255,10 @@ def get_seg_masks(mask_pred, det_bboxes, det_labels, img_meta, rescale, num_clas | |||
| h = max(bbox[3] - bbox[1] + 1, 1) | |||
| w = min(w, img_w - bbox[0]) | |||
| h = min(h, img_h - bbox[1]) | |||
| if w <= 0 or h <= 0: | |||
| print("there is invalid proposal bbox, index={} bbox={} w={} h={}".format(i, bbox, w, h)) | |||
| w = max(w, 1) | |||
| h = max(h, 1) | |||
| mask_pred_ = mask_pred[i, :, :] | |||
| im_mask = np.zeros((img_h, img_w), dtype=np.uint8) | |||
| bbox_mask = mmcv.imresize(mask_pred_, (w, h)) | |||
| @@ -16,6 +16,7 @@ | |||
| """train MaskRcnn and get checkpoint files.""" | |||
| import os | |||
| import time | |||
| import argparse | |||
| import ast | |||
| @@ -26,7 +27,7 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMoni | |||
| from mindspore.train import Model | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.nn import SGD | |||
| from mindspore.nn import Momentum | |||
| from mindspore.common import set_seed | |||
| from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 | |||
| @@ -71,7 +72,7 @@ if __name__ == '__main__': | |||
| prefix = "MaskRcnn.mindrecord" | |||
| mindrecord_dir = config.mindrecord_dir | |||
| mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") | |||
| if not os.path.exists(mindrecord_file): | |||
| if rank == 0 and not os.path.exists(mindrecord_file): | |||
| if not os.path.isdir(mindrecord_dir): | |||
| os.makedirs(mindrecord_dir) | |||
| if args_opt.dataset == "coco": | |||
| @@ -80,14 +81,16 @@ if __name__ == '__main__': | |||
| data_to_mindrecord_byte_image("coco", True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("coco_root not exits.") | |||
| raise Exception("coco_root not exits.") | |||
| else: | |||
| if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): | |||
| print("Create Mindrecord.") | |||
| data_to_mindrecord_byte_image("other", True, prefix) | |||
| print("Create Mindrecord Done, at {}".format(mindrecord_dir)) | |||
| else: | |||
| print("IMAGE_DIR or ANNO_PATH not exits.") | |||
| raise Exception("IMAGE_DIR or ANNO_PATH not exits.") | |||
| while not os.path.exists(mindrecord_file+".db"): | |||
| time.sleep(5) | |||
| if not args_opt.only_create_dataset: | |||
| loss_scale = float(config.loss_scale) | |||
| @@ -115,8 +118,8 @@ if __name__ == '__main__': | |||
| loss = LossNet() | |||
| lr = Tensor(dynamic_lr(config, rank_size=device_num, start_steps=config.pretrain_epoch_size * dataset_size), | |||
| mstype.float32) | |||
| opt = SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | |||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | |||
| opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, | |||
| weight_decay=config.weight_decay, loss_scale=config.loss_scale) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| if args_opt.run_distribute: | |||