Browse Source

!2792 fix fastrcnn accuracy error

Merge pull request !2792 from yanghaitao/yht_fastrcnn_accu
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a1ea599461
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      model_zoo/faster_rcnn/src/dataset.py

+ 7
- 4
model_zoo/faster_rcnn/src/dataset.py View File

@@ -441,6 +441,7 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi
hwc_to_chw = C.HWC2CHW() hwc_to_chw = C.HWC2CHW()
normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
horizontally_op = C.RandomHorizontalFlip(1) horizontally_op = C.RandomHorizontalFlip(1)
type_cast0 = CC.TypeCast(mstype.float32)
type_cast1 = CC.TypeCast(mstype.float16) type_cast1 = CC.TypeCast(mstype.float16)
type_cast2 = CC.TypeCast(mstype.int32) type_cast2 = CC.TypeCast(mstype.int32)
type_cast3 = CC.TypeCast(mstype.bool_) type_cast3 = CC.TypeCast(mstype.bool_)
@@ -453,13 +454,15 @@ def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, devi


flip = (np.random.rand() < config.flip_ratio) flip = (np.random.rand() < config.flip_ratio)
if flip: if flip:
ds = ds.map(input_columns=["image"], operations=[normalize_op, horizontally_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0, horizontally_op],
num_parallel_workers=12)
ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"], ds = ds.map(input_columns=["image", "image_shape", "box", "label", "valid_num"],
operations=flipped_generation, num_parallel_workers=num_parallel_workers) operations=flipped_generation, num_parallel_workers=num_parallel_workers)
else: else:
ds = ds.map(input_columns=["image"], operations=[normalize_op, hwc_to_chw, type_cast1],
num_parallel_workers=24)
ds = ds.map(input_columns=["image"], operations=[normalize_op, type_cast0],
num_parallel_workers=12)
ds = ds.map(input_columns=["image"], operations=[hwc_to_chw, type_cast1],
num_parallel_workers=12)


else: else:
ds = ds.map(input_columns=["image", "annotation"], ds = ds.map(input_columns=["image", "annotation"],


Loading…
Cancel
Save