iou NMSWithMask larsupdate testcase sgd testcasetags/v0.3.0-alpha
| @@ -114,6 +114,9 @@ def build_op(build_type, json_str): | |||
| return get_op_pattern() | |||
| # call function | |||
| if kernel_name[0:19] == "bounding_box_encode": | |||
| return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name) | |||
| return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) | |||
| except Exception as e: | |||
| @@ -84,7 +84,11 @@ static std::map<string, string> tbe_func_adapter_map = { | |||
| {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, | |||
| {"adam", "apply_adam_d"}, | |||
| {"r_oi_align", "roi_align"}, | |||
| {"r_oi_align_grad", "roi_align_grad"}}; | |||
| {"r_oi_align_grad", "roi_align_grad"}, | |||
| {"i_ou", "iou"}, | |||
| {"s_gd", "sgd"}, | |||
| {"l_ars_update", "lars_v2_update"}, | |||
| {"n_ms_with_mask", "nms_with_mask"}}; | |||
| void TbeAdapter::NormalizeFuncName(std::string *func_name) { | |||
| if (func_name == nullptr) { | |||
| @@ -430,6 +430,18 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo | |||
| attr_value = GetValue<std::vector<int>>(value); | |||
| } | |||
| (*attr_obj)["value"] = attr_value; | |||
| } else if (type == "listFloat") { | |||
| std::vector<float> attr_value; | |||
| auto value_type = value->type(); | |||
| MS_EXCEPTION_IF_NULL(value_type); | |||
| auto value_type_str = value_type->ToString(); | |||
| if (value_type_str == "float") { | |||
| float data = GetValue<float>(value); | |||
| attr_value.push_back(data); | |||
| } else { | |||
| attr_value = GetValue<std::vector<float>>(value); | |||
| } | |||
| (*attr_obj)["value"] = attr_value; | |||
| } else if (type == "listListInt") { | |||
| auto attr_value = GetValue<std::vector<std::vector<int>>>(value); | |||
| (*attr_obj)["value"] = attr_value; | |||
| @@ -171,3 +171,11 @@ from .resize_bilinear_grad import _resize_bilinear_grad_tbe | |||
| from .flatten import _flatten_tbe | |||
| from .roi_align import _roi_align_tbe | |||
| from .roi_align_grad import _roi_align_grad_tbe | |||
| from .bounding_box_decode import _bounding_box_decode_tbe | |||
| from .bounding_box_encode import _bounding_box_encode_tbe | |||
| from .check_valid import _check_valid_tbe | |||
| from .iou import _iou_tbe | |||
| from .nms_with_mask import nms_with_mask_op_info | |||
| from .random_choice_with_mask import random_choice_with_mask_op_info | |||
| from .sgd import sgd_op_info | |||
| from .lars_update import lars_update_op_info | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BoundingBoxDecode op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bounding_box_decode_op_info = TBERegOp("BoundingBoxDecode") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bounding_box_decode.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bounding_box_decode") \ | |||
| .partial_flag(True) \ | |||
| .attr("means", "optional", "listFloat", "all") \ | |||
| .attr("stds", "optional", "listFloat", "all") \ | |||
| .attr("max_shape", "optional", "listInt", "all") \ | |||
| .attr("wh_ratio_clip", "optional", "float", "all") \ | |||
| .input(0, "rois", False, "required", "all") \ | |||
| .input(1, "deltas", False, "required", "all") \ | |||
| .output(0, "bboxes", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(bounding_box_decode_op_info) | |||
| def _bounding_box_decode_tbe(): | |||
| """BoundingBoxDecode TBE register""" | |||
| return | |||
| @@ -0,0 +1,38 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """BoundingBoxEncode op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| bounding_box_encode_op_info = TBERegOp("BoundingBoxEncode") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("bounding_box_encode.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("bounding_box_encode") \ | |||
| .partial_flag(True) \ | |||
| .attr("means", "optional", "listFloat", "all") \ | |||
| .attr("stds", "optional", "listFloat", "all") \ | |||
| .input(0, "anchor_box", False, "required", "all") \ | |||
| .input(1, "ground_truth_box", False, "required", "all") \ | |||
| .output(0, "delats", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(bounding_box_encode_op_info) | |||
| def _bounding_box_encode_tbe(): | |||
| """BoundingBoxEncode TBE register""" | |||
| return | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """CheckValid op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| check_valid_op_info = TBERegOp("CheckValid") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("check_valid.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("check_valid") \ | |||
| .partial_flag(True) \ | |||
| .input(0, "bbox_tensor", False, "required", "all") \ | |||
| .input(1, "img_tas", False, "required", "all") \ | |||
| .output(0, "valid_tensor", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(check_valid_op_info) | |||
| def _check_valid_tbe(): | |||
| """CheckValid TBE register""" | |||
| return | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Iou op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| iou_op_info = TBERegOp("IOU") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("iou.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("iou") \ | |||
| .partial_flag(True) \ | |||
| .attr("mode", "required", "str", "all") \ | |||
| .input(0, "bboxes", False, "required", "all") \ | |||
| .input(1, "gtboxes", False, "required", "all") \ | |||
| .output(0, "overlap", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(iou_op_info) | |||
| def _iou_tbe(): | |||
| """Iou TBE register""" | |||
| return | |||
| @@ -0,0 +1,50 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """LarsUpdate op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| lars_update_op_info = TBERegOp("LARSUpdate") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("lars_v2_update.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("lars_v2_update") \ | |||
| .partial_flag(True) \ | |||
| .attr("hyperpara", "optional", "float", "all") \ | |||
| .attr("epsilon", "optional", "float", "all") \ | |||
| .attr("use_clip", "optional", "bool", "all") \ | |||
| .input(0, "w", False, "required", "all") \ | |||
| .input(1, "g", False, "required", "all") \ | |||
| .input(2, "w_square_sum", False, "required", "all") \ | |||
| .input(3, "g_square_sum", False, "required", "all") \ | |||
| .input(4, "weight_decay", False, "required", "all") \ | |||
| .input(5, "learning_rate", False, "required", "all") \ | |||
| .output(0, "g_new", False, "required", "all") \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_FracZ) \ | |||
| .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_C1HWNCoC0) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(lars_update_op_info) | |||
| def _lars_update_tbe(): | |||
| """LarsUpdate TBE register""" | |||
| return | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """NMSWithMask op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| nms_with_mask_op_info = TBERegOp("NMSWithMask") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("nms_with_mask.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("nms_with_mask") \ | |||
| .partial_flag(True) \ | |||
| .attr("iou_threshold", "optional", "float", "all") \ | |||
| .input(0, "box_scores", False, "required", "all") \ | |||
| .output(0, "selected_boxes", False, "required", "all") \ | |||
| .output(0, "selected_idx", False, "required", "all") \ | |||
| .output(0, "selected_mask", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.U8_Default) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.BOOL_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(nms_with_mask_op_info) | |||
| def _nms_with_mask_tbe(): | |||
| """NMSWithMask TBE register""" | |||
| return | |||
| @@ -0,0 +1,41 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """RandomChoiceWithMask op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| random_choice_with_mask_op_info = TBERegOp("RandomChoiceWithMask") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("random_choice_with_mask.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("random_choice_with_mask") \ | |||
| .partial_flag(True) \ | |||
| .attr("max_shape", "optional", "listInt", "all") \ | |||
| .attr("means", "optional", "listFloat", "all") \ | |||
| .attr("stds", "optional", "listFloat", "all") \ | |||
| .attr("wh_ratio_clip", "optional", "float", "all") \ | |||
| .input(0, "rois", False, "required", "all") \ | |||
| .input(1, "deltas", False, "required", "all") \ | |||
| .output(0, "bboxes", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .get_op_info() | |||
| @op_info_register(random_choice_with_mask_op_info) | |||
| def _random_choice_with_mask_tbe(): | |||
| """RandomChoiceWithMask TBE register""" | |||
| return | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """SGD op""" | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| sgd_op_info = TBERegOp("SGD") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("sgd.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("sgd") \ | |||
| .partial_flag(True) \ | |||
| .attr("dampening", "optional", "float", "all") \ | |||
| .attr("weight_decay", "optional", "float", "all") \ | |||
| .attr("nesterov", "optional", "bool", "all") \ | |||
| .input(0, "parameters", False, "required", "all") \ | |||
| .input(1, "gradient", False, "required", "all") \ | |||
| .input(2, "learning_rate", False, "required", "all") \ | |||
| .input(3, "accum", False, "required", "all") \ | |||
| .input(4, "momentum", False, "required", "all") \ | |||
| .input(5, "stat", False, "required", "all") \ | |||
| .output(0, "parameters", False, "required", "all") \ | |||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_Default, DataType.F16_5HD, | |||
| DataType.F16_Default, DataType.F16_5HD, DataType.F16_5HD) \ | |||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | |||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | |||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | |||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | |||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | |||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | |||
| .get_op_info() | |||
| @op_info_register(sgd_op_info) | |||
| def _sgd_tbe(): | |||
| """SGD TBE register""" | |||
| return | |||
| @@ -16,6 +16,7 @@ | |||
| """Operators for math.""" | |||
| import numpy as np | |||
| from ... import context | |||
| from ..._c_expression import signature_rw as sig_rw | |||
| from ..._c_expression import signature_kind as sig_kind | |||
| from ..._c_expression import signature_dtype as sig_dtype | |||
| @@ -1950,12 +1951,16 @@ class NMSWithMask(PrimitiveWithInfer): | |||
| """Init NMSWithMask""" | |||
| validator.check_value_type("iou_threshold", iou_threshold, [float], self.name) | |||
| self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) | |||
| self.is_ge = context.get_context("enable_ge") | |||
| def infer_shape(self, bboxes_shape): | |||
| cls_name = self.name | |||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | |||
| validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | |||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||
| if not self.is_ge: | |||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name) | |||
| else: | |||
| validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | |||
| num = bboxes_shape[0] | |||
| return (bboxes_shape, (num,), (num,)) | |||
| @@ -175,10 +175,10 @@ class CheckValid(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output']) | |||
| def infer_shape(self, bboxes_shape, metas_shape): | |||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name) | |||
| validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name) | |||
| validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name) | |||
| validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name) | |||
| validator.check("bboxes rank", len(bboxes_shape), "", 2, Rel.EQ, self.name) | |||
| validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ, self.name) | |||
| validator.check("img_metas rank", len(metas_shape), "", 1, Rel.EQ, self.name) | |||
| validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ, self.name) | |||
| return bboxes_shape[:-1] | |||
| def infer_dtype(self, bboxes_type, metas_type): | |||
| @@ -188,6 +188,10 @@ class InputOpNet(nn.Cell): | |||
| x = self.op(x1, x2, x3, x4, self.c1) | |||
| return x | |||
| def construct4_c2(self, x1, x2, x3, x4): | |||
| x = self.op(x1, x2, x3, x4, self.c1, self.c2) | |||
| return x | |||
| def construct4_c4(self, x1, x2, x3, x4): | |||
| x = self.op(x1, x2, x3, x4, self.c1, self.c2, self.c3, self.c4) | |||
| return x | |||
| @@ -951,6 +951,17 @@ test_case_nn_ops = [ | |||
| 'desc_inputs': [[1, 1, 2, 2], [1, 5]], | |||
| 'desc_bprop': [[1, 1, 2, 2]], | |||
| 'skip': ['backward']}), | |||
| ('LARSUpdate', { | |||
| 'block': P.LARSUpdate(1e-05, 0.001, False), | |||
| 'desc_const': [0.0, 0.001], | |||
| 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]], | |||
| 'desc_bprop': [3, 3], | |||
| 'skip': ['backward']}), | |||
| ('SGD', { | |||
| 'block': P.SGD(0.0, 0.0, False), | |||
| 'desc_inputs': [[3, 3], [3, 3], Tensor(0.001, mstype.float32), [3, 3], Tensor(0.1, mstype.float32), [3, 3]], | |||
| 'desc_bprop': [3, 3], | |||
| 'skip': ['backward']}), | |||
| ] | |||
| test_case_array_ops = [ | |||