Browse Source

!2610 optimize fastrcnn training process

Merge pull request !2610 from yanghaitao/yht_fastrcnn
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
f0d40475b0
1 changed files with 14 additions and 18 deletions
  1. +14
    -18
      model_zoo/faster_rcnn/src/dataset.py

+ 14
- 18
model_zoo/faster_rcnn/src/dataset.py View File

@@ -318,10 +318,6 @@ def preprocess_fn(image, box, is_training):
else: else:
input_data = resize_column(*input_data) input_data = resize_column(*input_data)


photo = (np.random.rand() < config.photo_ratio)
if photo:
input_data = photo_crop_column(*input_data)

input_data = image_bgr_rgb(*input_data) input_data = image_bgr_rgb(*input_data)


output_data = input_data output_data = input_data
@@ -432,19 +428,19 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fast
writer.write_raw_data([row]) writer.write_raw_data([row])
writer.commit() writer.commit()



def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0, def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
is_training=True, num_parallel_workers=8):
is_training=True, num_parallel_workers=4):
"""Creatr FasterRcnn dataset with MindDataset.""" """Creatr FasterRcnn dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id, ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
num_parallel_workers=num_parallel_workers, shuffle=is_training)
num_parallel_workers=1, shuffle=is_training)
decode = C.Decode() decode = C.Decode()
ds = ds.map(input_columns=["image"], operations=decode)
ds = ds.map(input_columns=["image"], operations=decode, num_parallel_workers=1)
compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))


hwc_to_chw = C.HWC2CHW() hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1) horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16) type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32) type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_) type_cast3 = CC.TypeCast(mstype.bool_)
@@ -453,17 +449,18 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
ds = ds.map(input_columns=["image", "annotation"], ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"], output_columns=["image", "image_shape", "box", "label", "valid_num"],
columns_order=["image", "image_shape", "box", "label", "valid_num"], columns_order=["image", "image_shape", "box", "label", "valid_num"],
operations=compose_map_func, num_parallel_workers=4)

ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
operations=compose_map_func, num_parallel_workers=num_parallel_workers)


flip = (np.random.rand() < config.flip_ratio) flip = (np.random.rand() < config.flip_ratio)
if flip: if flip:
ds = ds.map(input_columns=["image"], operations=[horizontally_op],
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
operations=flipped_generation, num_parallel_workers=4)
operations=flipped_generation, num_parallel_workers=num_parallel_workers)
else:
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)

else: else:
ds = ds.map(input_columns=["image", "annotation"], ds = ds.map(input_columns=["image", "annotation"],
output_columns=["image", "image_shape", "box", "label", "valid_num"], output_columns=["image", "image_shape", "box", "label", "valid_num"],
@@ -471,11 +468,10 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
operations=compose_map_func, operations=compose_map_func,
num_parallel_workers=num_parallel_workers) num_parallel_workers=num_parallel_workers)


ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=num_parallel_workers)
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)


# transpose_column from python to c # transpose_column from python to c
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1])
ds = ds.map(input_columns=["image_shape"], operations=[type_cast1]) ds = ds.map(input_columns=["image_shape"], operations=[type_cast1])
ds = ds.map(input_columns=["box"], operations=[type_cast1]) ds = ds.map(input_columns=["box"], operations=[type_cast1])
ds = ds.map(input_columns=["label"], operations=[type_cast2]) ds = ds.map(input_columns=["label"], operations=[type_cast2])


Loading…
Cancel
Save