| @@ -30,8 +30,8 @@ lrn_op_info = TBERegOp("LRN") \ | |||||
| .attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \ | .attr("norm_region", "optional", "str", "all", "ACROSS_CHANNELS") \ | ||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -31,8 +31,8 @@ lrn_grad_op_info = TBERegOp("LRNGrad") \ | |||||
| .input(1, "x", False, "required", "all") \ | .input(1, "x", False, "required", "all") \ | ||||
| .input(2, "y", False, "required", "all") \ | .input(2, "y", False, "required", "all") \ | ||||
| .output(0, "z", False, "required", "all") \ | .output(0, "z", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -50,27 +50,17 @@ parallel_concat_op_info = TBERegOp("ParallelConcat") \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ | ||||
| .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \ | |||||
| .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ | |||||
| .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \ | .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \ | ||||
| .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \ | |||||
| .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \ | .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \ | ||||
| .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \ | |||||
| .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \ | .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \ | ||||
| .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \ | |||||
| .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \ | .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \ | ||||
| .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \ | |||||
| .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \ | .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \ | ||||
| .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \ | |||||
| .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \ | .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \ | ||||
| .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \ | |||||
| .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \ | .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \ | ||||
| .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \ | |||||
| .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \ | .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \ | ||||
| .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \ | |||||
| .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ | ||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \ | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "weight", False, "required", "all") \ | .input(1, "weight", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_NCHW, DataType.F16_Default, DataType.F16_NCHW) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_Default, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| .get_op_info() | .get_op_info() | ||||
| @@ -28,8 +28,6 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \ | |||||
| .input(2, "weights", False, "required", "all") \ | .input(2, "weights", False, "required", "all") \ | ||||
| .output(0, "dx", False, "required", "all") \ | .output(0, "dx", False, "required", "all") \ | ||||
| .output(0, "da", False, "required", "all") \ | .output(0, "da", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_Default, | |||||
| DataType.F32_NCHW, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD) \ | DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| @@ -32,8 +32,6 @@ sparse_apply_adagrad_d_op_info = TBERegOp("SparseApplyAdagrad") \ | |||||
| .input(3, "indices", False, "required", "all") \ | .input(3, "indices", False, "required", "all") \ | ||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, | ||||
| DataType.F32_NHWC, DataType.F32_NHWC) \ | DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | ||||
| @@ -33,8 +33,6 @@ sparse_apply_adagrad_v2_d_op_info = TBERegOp("SparseApplyAdagradV2") \ | |||||
| .input(3, "indices", False, "required", "all") \ | .input(3, "indices", False, "required", "all") \ | ||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.I32_NHWC, | ||||
| DataType.F32_NHWC, DataType.F32_NHWC) \ | DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, | ||||
| @@ -36,14 +36,10 @@ sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \ | |||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .output(2, "linear", False, "required", "all") \ | .output(2, "linear", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | ||||
| DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.I64_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | ||||
| DataType.I64_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | DataType.I64_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| @@ -37,8 +37,6 @@ sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \ | |||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .output(2, "linear", False, "required", "all") \ | .output(2, "linear", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | ||||
| DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| @@ -37,8 +37,6 @@ sparse_apply_ftrl_v2_d_op_info = TBERegOp("SparseApplyFtrlV2") \ | |||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .output(2, "linear", False, "required", "all") \ | .output(2, "linear", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, | ||||
| DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| @@ -33,9 +33,6 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") | |||||
| .input(6, "indices", False, "required", "all") \ | .input(6, "indices", False, "required", "all") \ | ||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD, | DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD) \ | DataType.F32_5HD) \ | ||||
| @@ -48,9 +45,6 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | ||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ, | DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ, | ||||
| DataType.F32_FracZ) \ | DataType.F32_FracZ) \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD, | DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD) \ | DataType.F32_5HD) \ | ||||
| @@ -34,9 +34,6 @@ sparse_apply_proximal_adagrad_d_ds_op_info = TBERegOp("SparseApplyProximalAdagra | |||||
| .input(6, "indices", False, "required", "all") \ | .input(6, "indices", False, "required", "all") \ | ||||
| .output(0, "var", False, "required", "all") \ | .output(0, "var", False, "required", "all") \ | ||||
| .output(1, "accum", False, "required", "all") \ | .output(1, "accum", False, "required", "all") \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD, | DataType.F32_5HD, DataType.F32_5HD, DataType.I32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD) \ | DataType.F32_5HD) \ | ||||
| @@ -49,9 +46,6 @@ sparse_apply_proximal_adagrad_d_ds_op_info = TBERegOp("SparseApplyProximalAdagra | |||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, | ||||
| DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ, | DataType.F32_FracZ, DataType.F32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ, | ||||
| DataType.F32_FracZ) \ | DataType.F32_FracZ) \ | ||||
| .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW, DataType.F32_NCHW, DataType.I64_NCHW, DataType.F32_NCHW, | |||||
| DataType.F32_NCHW) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD, | DataType.F32_5HD, DataType.F32_5HD, DataType.I64_5HD, DataType.F32_5HD, | ||||
| DataType.F32_5HD) \ | DataType.F32_5HD) \ | ||||