Browse Source

adopt BoundingBoxEncode modification and disable save graph for maskrcnn

tags/v1.1.0
gengdongjie 5 years ago
parent
commit
41fd8608f8
4 changed files with 4 additions and 4 deletions
  1. +1
    -1
      model_zoo/official/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py
  2. +1
    -1
      model_zoo/official/cv/maskrcnn/eval.py
  3. +1
    -1
      model_zoo/official/cv/maskrcnn/src/maskrcnn/bbox_assign_sample_stage2.py
  4. +1
    -1
      model_zoo/official/cv/maskrcnn/train.py

+ 1
- 1
model_zoo/official/cv/faster_rcnn/src/FasterRcnn/bbox_assign_sample_stage2.py View File

@@ -77,7 +77,7 @@ class BboxAssignSampleForRcnn(nn.Cell):
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(10.0, 10.0, 5.0, 5.0))
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()


+ 1
- 1
model_zoo/official/cv/maskrcnn/eval.py View File

@@ -37,7 +37,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

def MaskRcnn_eval(dataset_path, ckpt_path, ann_file):
"""MaskRcnn evaluation."""


+ 1
- 1
model_zoo/official/cv/maskrcnn/src/maskrcnn/bbox_assign_sample_stage2.py View File

@@ -72,7 +72,7 @@ class BboxAssignSampleForRcnn(nn.Cell):
self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg)
self.reshape = P.Reshape()
self.equal = P.Equal()
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(10.0, 10.0, 5.0, 5.0))
self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2))
self.concat_axis1 = P.Concat(axis=1)
self.logicalnot = P.LogicalNot()
self.tile = P.Tile()


+ 1
- 1
model_zoo/official/cv/maskrcnn/train.py View File

@@ -51,7 +51,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

if __name__ == '__main__':
print("Start train for maskrcnn!")


Loading…
Cancel
Save