| @@ -77,7 +77,7 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.equal = P.Equal() | self.equal = P.Equal() | ||||
| self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(10.0, 10.0, 5.0, 5.0)) | |||||
| self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2)) | |||||
| self.concat_axis1 = P.Concat(axis=1) | self.concat_axis1 = P.Concat(axis=1) | ||||
| self.logicalnot = P.LogicalNot() | self.logicalnot = P.LogicalNot() | ||||
| self.tile = P.Tile() | self.tile = P.Tile() | ||||
| @@ -37,7 +37,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi | |||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| def MaskRcnn_eval(dataset_path, ckpt_path, ann_file): | def MaskRcnn_eval(dataset_path, ckpt_path, ann_file): | ||||
| """MaskRcnn evaluation.""" | """MaskRcnn evaluation.""" | ||||
| @@ -72,7 +72,7 @@ class BboxAssignSampleForRcnn(nn.Cell): | |||||
| self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) | ||||
| self.reshape = P.Reshape() | self.reshape = P.Reshape() | ||||
| self.equal = P.Equal() | self.equal = P.Equal() | ||||
| self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(10.0, 10.0, 5.0, 5.0)) | |||||
| self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(0.1, 0.1, 0.2, 0.2)) | |||||
| self.concat_axis1 = P.Concat(axis=1) | self.concat_axis1 = P.Concat(axis=1) | ||||
| self.logicalnot = P.LogicalNot() | self.logicalnot = P.LogicalNot() | ||||
| self.tile = P.Tile() | self.tile = P.Tile() | ||||
| @@ -51,7 +51,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums, | |||||
| parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.") | parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.") | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| print("Start train for maskrcnn!") | print("Start train for maskrcnn!") | ||||