iou NMSWithMask larsupdate testcase sgd testcasetags/v0.3.0-alpha
| @@ -114,6 +114,9 @@ def build_op(build_type, json_str): | |||||
| return get_op_pattern() | return get_op_pattern() | ||||
| # call function | # call function | ||||
| if kernel_name[0:19] == "bounding_box_encode": | |||||
| return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name) | |||||
| return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | ||||
| except Exception as e: | except Exception as e: | ||||
| @@ -84,7 +84,11 @@ static std::map<string, string> tbe_func_adapter_map = { | |||||
| {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, | {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, | ||||
| {"adam", "apply_adam_d"}, | {"adam", "apply_adam_d"}, | ||||
| {"r_oi_align", "roi_align"}, | {"r_oi_align", "roi_align"}, | ||||
| {"r_oi_align_grad", "roi_align_grad"}}; | |||||
| {"r_oi_align_grad", "roi_align_grad"}, | |||||
| {"i_ou", "iou"}, | |||||
| {"s_gd", "sgd"}, | |||||
| {"l_ars_update", "lars_v2_update"}, | |||||
| {"n_ms_with_mask", "nms_with_mask"}}; | |||||
| void TbeAdapter::NormalizeFuncName(std::string *func_name) { | void TbeAdapter::NormalizeFuncName(std::string *func_name) { | ||||
| if (func_name == nullptr) { | if (func_name == nullptr) { | ||||
| @@ -430,6 +430,18 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo | |||||
| attr_value = GetValue<std::vector<int>>(value); | attr_value = GetValue<std::vector<int>>(value); | ||||
| } | } | ||||
| (*attr_obj)["value"] = attr_value; | (*attr_obj)["value"] = attr_value; | ||||
| } else if (type == "listFloat") { | |||||
| std::vector<float> attr_value; | |||||
| auto value_type = value->type(); | |||||
| MS_EXCEPTION_IF_NULL(value_type); | |||||
| auto value_type_str = value_type->ToString(); | |||||
| if (value_type_str == "float") { | |||||
| float data = GetValue<float>(value); | |||||
| attr_value.push_back(data); | |||||
| } else { | |||||
| attr_value = GetValue<std::vector<float>>(value); | |||||
| } | |||||
| (*attr_obj)["value"] = attr_value; | |||||
| } else if (type == "listListInt") { | } else if (type == "listListInt") { | ||||
| auto attr_value = GetValue<std::vector<std::vector<int>>>(value); | auto attr_value = GetValue<std::vector<std::vector<int>>>(value); | ||||
| (*attr_obj)["value"] = attr_value; | (*attr_obj)["value"] = attr_value; | ||||
| @@ -171,3 +171,11 @@ from .resize_bilinear_grad import _resize_bilinear_grad_tbe | |||||
| from .flatten import _flatten_tbe | from .flatten import _flatten_tbe | ||||
| from .roi_align import _roi_align_tbe | from .roi_align import _roi_align_tbe | ||||
| from .roi_align_grad import _roi_align_grad_tbe | from .roi_align_grad import _roi_align_grad_tbe | ||||
| from .bounding_box_decode import _bounding_box_decode_tbe | |||||
| from .bounding_box_encode import _bounding_box_encode_tbe | |||||
| from .check_valid import _check_valid_tbe | |||||
| from .iou import _iou_tbe | |||||
| from .nms_with_mask import nms_with_mask_op_info | |||||
| from .random_choice_with_mask import random_choice_with_mask_op_info | |||||
| from .sgd import sgd_op_info | |||||
| from .lars_update import lars_update_op_info | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BoundingBoxDecode op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| bounding_box_decode_op_info = TBERegOp("BoundingBoxDecode") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("bounding_box_decode.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("bounding_box_decode") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("means", "optional", "listFloat", "all") \ | |||||
| .attr("stds", "optional", "listFloat", "all") \ | |||||
| .attr("max_shape", "optional", "listInt", "all") \ | |||||
| .attr("wh_ratio_clip", "optional", "float", "all") \ | |||||
| .input(0, "rois", False, "required", "all") \ | |||||
| .input(1, "deltas", False, "required", "all") \ | |||||
| .output(0, "bboxes", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(bounding_box_decode_op_info) | |||||
| def _bounding_box_decode_tbe(): | |||||
| """BoundingBoxDecode TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,38 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """BoundingBoxEncode op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| bounding_box_encode_op_info = TBERegOp("BoundingBoxEncode") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("bounding_box_encode.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("bounding_box_encode") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("means", "optional", "listFloat", "all") \ | |||||
| .attr("stds", "optional", "listFloat", "all") \ | |||||
| .input(0, "anchor_box", False, "required", "all") \ | |||||
| .input(1, "ground_truth_box", False, "required", "all") \ | |||||
| .output(0, "delats", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(bounding_box_encode_op_info) | |||||
| def _bounding_box_encode_tbe(): | |||||
| """BoundingBoxEncode TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """CheckValid op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| check_valid_op_info = TBERegOp("CheckValid") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("check_valid.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("check_valid") \ | |||||
| .partial_flag(True) \ | |||||
| .input(0, "bbox_tensor", False, "required", "all") \ | |||||
| .input(1, "img_tas", False, "required", "all") \ | |||||
| .output(0, "valid_tensor", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(check_valid_op_info) | |||||
| def _check_valid_tbe(): | |||||
| """CheckValid TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,37 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Iou op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| iou_op_info = TBERegOp("IOU") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("iou.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("iou") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("mode", "required", "str", "all") \ | |||||
| .input(0, "bboxes", False, "required", "all") \ | |||||
| .input(1, "gtboxes", False, "required", "all") \ | |||||
| .output(0, "overlap", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(iou_op_info) | |||||
| def _iou_tbe(): | |||||
| """Iou TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,50 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """LarsUpdate op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| lars_update_op_info = TBERegOp("LARSUpdate") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("lars_v2_update.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("lars_v2_update") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("hyperpara", "optional", "float", "all") \ | |||||
| .attr("epsilon", "optional", "float", "all") \ | |||||
| .attr("use_clip", "optional", "bool", "all") \ | |||||
| .input(0, "w", False, "required", "all") \ | |||||
| .input(1, "g", False, "required", "all") \ | |||||
| .input(2, "w_square_sum", False, "required", "all") \ | |||||
| .input(3, "g_square_sum", False, "required", "all") \ | |||||
| .input(4, "weight_decay", False, "required", "all") \ | |||||
| .input(5, "learning_rate", False, "required", "all") \ | |||||
| .output(0, "g_new", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ) \ | |||||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(lars_update_op_info) | |||||
| def _lars_update_tbe(): | |||||
| """LarsUpdate TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,39 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """NMSWithMask op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| nms_with_mask_op_info = TBERegOp("NMSWithMask") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("nms_with_mask.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("nms_with_mask") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("iou_threshold", "optional", "float", "all") \ | |||||
| .input(0, "box_scores", False, "required", "all") \ | |||||
| .output(0, "selected_boxes", False, "required", "all") \ | |||||
| .output(0, "selected_idx", False, "required", "all") \ | |||||
| .output(0, "selected_mask", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.U8_Default) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(nms_with_mask_op_info) | |||||
| def _nms_with_mask_tbe(): | |||||
| """NMSWithMask TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,41 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """RandomChoiceWithMask op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("random_choice_with_mask.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("random_choice_with_mask") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("max_shape", "optional", "listInt", "all") \ | |||||
| .attr("means", "optional", "listFloat", "all") \ | |||||
| .attr("stds", "optional", "listFloat", "all") \ | |||||
| .attr("wh_ratio_clip", "optional", "float", "all") \ | |||||
| .input(0, "rois", False, "required", "all") \ | |||||
| .input(1, "deltas", False, "required", "all") \ | |||||
| .output(0, "bboxes", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | |||||
| @op_info_register(random_choice_with_mask_op_info) | |||||
| def _random_choice_with_mask_tbe(): | |||||
| """RandomChoiceWithMask TBE register""" | |||||
| return | |||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """SGD op""" | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| sgd_op_info = TBERegOp("SGD") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("sgd.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("sgd") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("dampening", "optional", "float", "all") \ | |||||
| .attr("weight_decay", "optional", "float", "all") \ | |||||
| .attr("nesterov", "optional", "bool", "all") \ | |||||
| .input(0, "parameters", False, "required", "all") \ | |||||
| .input(1, "gradient", False, "required", "all") \ | |||||
| .input(2, "learning_rate", False, "required", "all") \ | |||||
| .input(3, "accum", False, "required", "all") \ | |||||
| .input(4, "momentum", False, "required", "all") \ | |||||
| .input(5, "stat", False, "required", "all") \ | |||||
| .output(0, "parameters", False, "required", "all") \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | |||||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | |||||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | |||||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | |||||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||||
| .get_op_info() | |||||
| @op_info_register(sgd_op_info) | |||||
| def _sgd_tbe(): | |||||
| """SGD TBE register""" | |||||
| return | |||||
| @@ -16,6 +16,7 @@ | |||||
| """Operators for math.""" | """Operators for math.""" | ||||
| import numpy as np | import numpy as np | ||||
| from ... import context | |||||
| from ..._c_expression import signature_rw as sig_rw | from ..._c_expression import signature_rw as sig_rw | ||||
| from ..._c_expression import signature_kind as sig_kind | from ..._c_expression import signature_kind as sig_kind | ||||
| from ..._c_expression import signature_dtype as sig_dtype | from ..._c_expression import signature_dtype as sig_dtype | ||||
| @@ -1950,12 +1951,16 @@ class NMSWithMask(PrimitiveWithInfer): | |||||
| """Init NMSWithMask""" | """Init NMSWithMask""" | ||||
| validator.check_value_type("iou_threshold", iou_threshold, [float], self.name) | validator.check_value_type("iou_threshold", iou_threshold, [float], self.name) | ||||
| self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | ||||
| self.is_ge = context.get_context("enable_ge") | |||||
| def infer_shape(self, bboxes_shape): | def infer_shape(self, bboxes_shape): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | ||||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | ||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||||
| if not self.is_ge: | |||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name) | |||||
| else: | |||||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||||
| num = bboxes_shape[0] | num = bboxes_shape[0] | ||||
| return (bboxes_shape, (num,), (num,)) | return (bboxes_shape, (num,), (num,)) | ||||
| @@ -175,10 +175,10 @@ class CheckValid(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output']) | self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output']) | ||||
| def infer_shape(self, bboxes_shape, metas_shape): | def infer_shape(self, bboxes_shape, metas_shape): | ||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name) | |||||
| validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name) | |||||
| validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name) | |||||
| validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name) | |||||
| validator.check("bboxes rank", len(bboxes_shape), "", 2, Rel.EQ, self.name) | |||||
| validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ, self.name) | |||||
| validator.check("img_metas rank", len(metas_shape), "", 1, Rel.EQ, self.name) | |||||
| validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ, self.name) | |||||
| return bboxes_shape[:-1] | return bboxes_shape[:-1] | ||||
| def infer_dtype(self, bboxes_type, metas_type): | def infer_dtype(self, bboxes_type, metas_type): | ||||
| @@ -188,6 +188,10 @@ class InputOpNet(nn.Cell): | |||||
| x = self.op(x1, x2, x3, x4, self.c1) | x = self.op(x1, x2, x3, x4, self.c1) | ||||
| return x | return x | ||||
| def construct4_c2(self, x1, x2, x3, x4): | |||||
| x = self.op(x1, x2, x3, x4, self.c1, self.c2) | |||||
| return x | |||||
| def construct4_c4(self, x1, x2, x3, x4): | def construct4_c4(self, x1, x2, x3, x4): | ||||
| x = self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4) | x = self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4) | ||||
| return x | return x | ||||
| @@ -951,6 +951,17 @@ test_case_nn_ops = [ | |||||
| 'desc_inputs': [[1, 1, 2, 2], [1, 5]], | 'desc_inputs': [[1, 1, 2, 2], [1, 5]], | ||||
| 'desc_bprop': [[1, 1, 2, 2]], | 'desc_bprop': [[1, 1, 2, 2]], | ||||
| 'skip': ['backward']}), | 'skip': ['backward']}), | ||||
| ('LARSUpdate', { | |||||
| 'block': P.LARSUpdate(1e-05, 0.001, False), | |||||
| 'desc_const': [0.0, 0.001], | |||||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], | |||||
| 'desc_bprop': [3, 3], | |||||
| 'skip': ['backward']}), | |||||
| ('SGD', { | |||||
| 'block': P.SGD(0.0, 0.0, False), | |||||
| 'desc_inputs': [[3, 3], [3, 3], Tensor(0.001, mstype.float32), [3, 3], Tensor(0.1, mstype.float32), [3, 3]], | |||||
| 'desc_bprop': [3, 3], | |||||
| 'skip': ['backward']}), | |||||
| ] | ] | ||||
| test_case_array_ops = [ | test_case_array_ops = [ | ||||