|
|
@@ -441,6 +441,7 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi |
|
|
hwc_to_chw = C.HWC2CHW() |
|
|
hwc_to_chw = C.HWC2CHW() |
|
|
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) |
|
|
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) |
|
|
horizontally_op = C.RandomHorizontalFlip(1) |
|
|
horizontally_op = C.RandomHorizontalFlip(1) |
|
|
|
|
|
type_cast0 = CC.TypeCast(mstype.float32) |
|
|
type_cast1 = CC.TypeCast(mstype.float16) |
|
|
type_cast1 = CC.TypeCast(mstype.float16) |
|
|
type_cast2 = CC.TypeCast(mstype.int32) |
|
|
type_cast2 = CC.TypeCast(mstype.int32) |
|
|
type_cast3 = CC.TypeCast(mstype.bool_) |
|
|
type_cast3 = CC.TypeCast(mstype.bool_) |
|
|
@@ -453,13 +454,15 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi |
|
|
|
|
|
|
|
|
flip = (np.random.rand() < config.flip_ratio) |
|
|
flip = (np.random.rand() < config.flip_ratio) |
|
|
if flip: |
|
|
if flip: |
|
|
ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1], |
|
|
|
|
|
num_parallel_workers=24) |
|
|
|
|
|
|
|
|
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0, horizontally_op], |
|
|
|
|
|
num_parallel_workers=12) |
|
|
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], |
|
|
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], |
|
|
operations=flipped_generation, num_parallel_workers=num_parallel_workers) |
|
|
operations=flipped_generation, num_parallel_workers=num_parallel_workers) |
|
|
else: |
|
|
else: |
|
|
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1], |
|
|
|
|
|
num_parallel_workers=24) |
|
|
|
|
|
|
|
|
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0], |
|
|
|
|
|
num_parallel_workers=12) |
|
|
|
|
|
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1], |
|
|
|
|
|
num_parallel_workers=12) |
|
|
|
|
|
|
|
|
else: |
|
|
else: |
|
|
ds = ds.map(input_columns=["image", "annotation"], |
|
|
ds = ds.map(input_columns=["image", "annotation"], |
|
|
|