|
|
|
@@ -13,95 +13,28 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
from tests.st.ops.custom_ops_tbe.conv2d import conv2d |
|
|
|
from mindspore.ops.op_info_register import op_info_register |
|
|
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType |
|
|
|
|
|
|
|
@op_info_register("""{ |
|
|
|
"op_name": "Cus_Conv2D", |
|
|
|
"imply_type": "TBE", |
|
|
|
"fusion_type": "CONVLUTION", |
|
|
|
"async_flag": false, |
|
|
|
"binfile_name": "conv2d.so", |
|
|
|
"compute_cost": 10, |
|
|
|
"kernel_name": "Cus_Conv2D", |
|
|
|
"partial_flag": true, |
|
|
|
"attr": [ |
|
|
|
{ |
|
|
|
"name": "stride", |
|
|
|
"param_type": "required", |
|
|
|
"type": "listInt", |
|
|
|
"value": "all" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "pad_list", |
|
|
|
"param_type": "required", |
|
|
|
"type": "listInt", |
|
|
|
"value": "all" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"name": "dilation", |
|
|
|
"param_type": "required", |
|
|
|
"type": "listInt", |
|
|
|
"value": "all" |
|
|
|
} |
|
|
|
], |
|
|
|
"inputs": [ |
|
|
|
{ |
|
|
|
"index": 0, |
|
|
|
"dtype": [ |
|
|
|
"float16" |
|
|
|
], |
|
|
|
"format": [ |
|
|
|
"NC1HWC0" |
|
|
|
], |
|
|
|
"name": "x", |
|
|
|
"need_compile": false, |
|
|
|
"param_type": "required", |
|
|
|
"shape": "all" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"index": 1, |
|
|
|
"dtype": [ |
|
|
|
"float16" |
|
|
|
], |
|
|
|
"format": [ |
|
|
|
"FracZ" |
|
|
|
], |
|
|
|
"name": "filter", |
|
|
|
"need_compile": false, |
|
|
|
"param_type": "required", |
|
|
|
"shape": "all" |
|
|
|
}, |
|
|
|
{ |
|
|
|
"index": 2, |
|
|
|
"dtype": [ |
|
|
|
"float16" |
|
|
|
], |
|
|
|
"format": [ |
|
|
|
"DefaultFormat" |
|
|
|
], |
|
|
|
"name": "bias", |
|
|
|
"need_compile": false, |
|
|
|
"param_type": "optional", |
|
|
|
"shape": "all" |
|
|
|
} |
|
|
|
], |
|
|
|
"outputs": [ |
|
|
|
{ |
|
|
|
"index": 0, |
|
|
|
"dtype": [ |
|
|
|
"float16" |
|
|
|
], |
|
|
|
"format": [ |
|
|
|
"NC1HWC0" |
|
|
|
], |
|
|
|
"name": "y", |
|
|
|
"need_compile": true, |
|
|
|
"param_type": "required", |
|
|
|
"shape": "all" |
|
|
|
} |
|
|
|
] |
|
|
|
}""") |
|
|
|
cus_conv2D_op_info = TBERegOp("Cus_Conv2D") \ |
|
|
|
.fusion_type("CONVLUTION") \ |
|
|
|
.async_flag(False) \ |
|
|
|
.binfile_name("conv2d.so") \ |
|
|
|
.compute_cost(10) \ |
|
|
|
.kernel_name("Cus_Conv2D") \ |
|
|
|
.partial_flag(True) \ |
|
|
|
.attr("stride", "required", "listInt", "all") \ |
|
|
|
.attr("pad_list", "required", "listInt", "all") \ |
|
|
|
.attr("dilation", "required", "listInt", "all") \ |
|
|
|
.input(0, "x", False, "required", "all") \ |
|
|
|
.input(1, "filter", False, "required", "all") \ |
|
|
|
.input(2, "bias", False, "optional", "all") \ |
|
|
|
.output(0, "y", True, "required", "all") \ |
|
|
|
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F32_Default, DataType.F16_5HD) \ |
|
|
|
.get_op_info() |
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(cus_conv2D_op_info) |
|
|
|
def Cus_Conv2D(inputs, weights, bias, outputs, strides, pads, dilations, |
|
|
|
kernel_name="conv2d"): |
|
|
|
conv2d(inputs, weights, bias, outputs, strides, pads, dilations, |
|
|
|
kernel_name) |
|
|
|
kernel_name) |