Browse Source

modify gpu operator information registration

tags/v0.2.0-alpha
maoweiyong 5 years ago
parent
commit
a007e4812b
16 changed files with 225 additions and 544 deletions
  1. +2
    -2
      mindspore/ops/__init__.py
  2. +12
    -38
      mindspore/ops/_op_impl/akg/gpu/cast.py
  3. +12
    -43
      mindspore/ops/_op_impl/akg/gpu/equal.py
  4. +11
    -33
      mindspore/ops/_op_impl/akg/gpu/hsigmoid.py
  5. +12
    -43
      mindspore/ops/_op_impl/akg/gpu/hsigmoid_grad.py
  6. +11
    -33
      mindspore/ops/_op_impl/akg/gpu/hswish.py
  7. +12
    -43
      mindspore/ops/_op_impl/akg/gpu/hswish_grad.py
  8. +11
    -33
      mindspore/ops/_op_impl/akg/gpu/mean.py
  9. +12
    -38
      mindspore/ops/_op_impl/akg/gpu/mean_grad.py
  10. +12
    -43
      mindspore/ops/_op_impl/akg/gpu/mul.py
  11. +11
    -33
      mindspore/ops/_op_impl/akg/gpu/relu6.py
  12. +12
    -43
      mindspore/ops/_op_impl/akg/gpu/relu6_grad.py
  13. +12
    -38
      mindspore/ops/_op_impl/akg/gpu/squeeze.py
  14. +13
    -43
      mindspore/ops/_op_impl/akg/gpu/squeeze_grad.py
  15. +12
    -38
      mindspore/ops/_op_impl/akg/gpu/tile.py
  16. +58
    -0
      mindspore/ops/op_info_register.py

+ 2
- 2
mindspore/ops/__init__.py View File

@@ -30,7 +30,7 @@ Note:


from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, AiCPURegOp, TBERegOp, DataType
from .op_info_register import op_info_register, AkgRegOp, AiCPURegOp, TBERegOp, DataType
from .primitive import constexpr from .primitive import constexpr
from .._c_expression import signature_rw, signature_kind from .._c_expression import signature_rw, signature_kind


@@ -40,6 +40,6 @@ __primitive__ = [
] ]


__all__ = ["get_vm_impl_fn", "vm_impl_registry", __all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register", "AiCPURegOp", "TBERegOp", "DataType",
"op_info_register", "AkgRegOp", "AiCPURegOp", "TBERegOp", "DataType",
"constexpr"] "constexpr"]
__all__.extend(__primitive__) __all__.extend(__primitive__)

+ 12
- 38
mindspore/ops/_op_impl/akg/gpu/cast.py View File

@@ -13,45 +13,19 @@
# limitations under the License. # limitations under the License.


"""Cast op""" """Cast op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "Cast",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
{
"name": "dst_type",
"param_type": "required",
"type": "str"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16", "float32"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU Cast kernel: casts x between float16 and float32,
# controlled by the required "dst_type" string attribute.
cast_op_info = AkgRegOp("Cast") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .attr("dst_type", "required", "str") \
    .dtype_format(DataType.F16_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(cast_op_info)
def _cast_akg():
    """Cast AutoDiff register"""
    return

+ 12
- 43
mindspore/ops/_op_impl/akg/gpu/equal.py View File

@@ -13,50 +13,19 @@
# limitations under the License. # limitations under the License.


"""Equal op""" """Equal op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "Equal",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"bool", "bool"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU Equal kernel: element-wise comparison of x and y,
# float16/float32 inputs, boolean output.
equal_op_info = AkgRegOp("Equal") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .input(1, "y") \
    .output(0, "output") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
    .get_op_info()


@op_info_register(equal_op_info)
def _equal_akg():
    """Equal AutoDiff register"""
    return

+ 11
- 33
mindspore/ops/_op_impl/akg/gpu/hsigmoid.py View File

@@ -13,40 +13,18 @@
# limitations under the License. # limitations under the License.


"""HSigmoid op""" """HSigmoid op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "HSigmoid",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU HSigmoid kernel: single input/output, float16/float32.
hsigmoid_op_info = AkgRegOp("HSigmoid") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


# Fix: the decorator referenced hsigmoidgrad_op_info, which is not defined in
# this module (it lives in hsigmoid_grad.py) and would raise NameError at import.
@op_info_register(hsigmoid_op_info)
def _hsigmoid_akg():
    """HSigmoid AutoDiff register"""
    return

+ 12
- 43
mindspore/ops/_op_impl/akg/gpu/hsigmoid_grad.py View File

@@ -13,50 +13,19 @@
# limitations under the License. # limitations under the License.


"""HSigmoidGrad op""" """HSigmoidGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "HSigmoidGrad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y_grad"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU HSigmoidGrad kernel: takes the upstream gradient y_grad
# and the forward input x, produces the input gradient. float16/float32.
hsigmoidgrad_op_info = AkgRegOp("HSigmoidGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "y_grad") \
    .input(1, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(hsigmoidgrad_op_info)
def _hsigmoid_grad_akg():
    """HSigmoidGrad AutoDiff register"""
    return

+ 11
- 33
mindspore/ops/_op_impl/akg/gpu/hswish.py View File

@@ -13,40 +13,18 @@
# limitations under the License. # limitations under the License.


"""HSwish op""" """HSwish op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "HSwish",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU HSwish kernel: single input/output, float16/float32.
hswish_op_info = AkgRegOp("HSwish") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


# Fix: the decorator referenced hsigmoidgrad_op_info (copy-paste from
# hsigmoid_grad.py); that name is undefined here and the wrong op info.
@op_info_register(hswish_op_info)
def _hswish_akg():
    """HSwish AutoDiff register"""
    return

+ 12
- 43
mindspore/ops/_op_impl/akg/gpu/hswish_grad.py View File

@@ -13,50 +13,19 @@
# limitations under the License. # limitations under the License.


"""HSwishGrad op""" """HSwishGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "HSwishGrad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y_grad"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU HSwishGrad kernel: upstream gradient y_grad plus the
# forward input x, produces the input gradient. float16/float32.
hswishgrad_op_info = AkgRegOp("HSwishGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "y_grad") \
    .input(1, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


# Fix: the decorator referenced hsigmoidgrad_op_info (copy-paste from
# hsigmoid_grad.py); register the HSwishGrad info defined above instead.
@op_info_register(hswishgrad_op_info)
def _hswish_grad_akg():
    """HSwishGrad AutoDiff register"""
    return

+ 11
- 33
mindspore/ops/_op_impl/akg/gpu/mean.py View File

@@ -13,40 +13,18 @@
# limitations under the License. # limitations under the License.


"""SimpleMean op""" """SimpleMean op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "SimpleMean",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU SimpleMean kernel: single input/output, float16/float32.
mean_op_info = AkgRegOp("SimpleMean") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(mean_op_info)
def _simple_mean_akg():
    """SimpleMean AutoDiff register"""
    return

+ 12
- 38
mindspore/ops/_op_impl/akg/gpu/mean_grad.py View File

@@ -13,45 +13,19 @@
# limitations under the License. # limitations under the License.


"""SimpleMeanGrad op""" """SimpleMeanGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "SimpleMeanGrad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
{
"name": "input_shape",
"param_type": "required",
"type": "listInt"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "HEAD"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU SimpleMeanGrad kernel: input HEAD is the upstream
# gradient; the required "input_shape" attribute carries the forward shape.
mean_grad_op_info = AkgRegOp("SimpleMeanGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "HEAD") \
    .output(0, "output") \
    .attr("input_shape", "required", "listInt") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(mean_grad_op_info)
def _simple_mean_grad_akg():
    """SimpleMeanGrad AutoDiff register"""
    return

+ 12
- 43
mindspore/ops/_op_impl/akg/gpu/mul.py View File

@@ -13,50 +13,19 @@
# limitations under the License. # limitations under the License.


"""Mul op""" """Mul op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "Mul",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU Mul kernel: element-wise product of x and y,
# float16/float32.
mul_op_info = AkgRegOp("Mul") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .input(1, "y") \
    .output(0, "output") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(mul_op_info)
def _mul_akg():
    """Mul AutoDiff register"""
    return

+ 11
- 33
mindspore/ops/_op_impl/akg/gpu/relu6.py View File

@@ -13,40 +13,18 @@
# limitations under the License. # limitations under the License.


"""ReLU6 op""" """ReLU6 op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "ReLU6",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU ReLU6 kernel: single input/output, float16/float32.
relu_op_info = AkgRegOp("ReLU6") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(relu_op_info)
def _relu6_akg():
    """ReLU6 AutoDiff register"""
    return

+ 12
- 43
mindspore/ops/_op_impl/akg/gpu/relu6_grad.py View File

@@ -13,50 +13,19 @@
# limitations under the License. # limitations under the License.


"""ReLU6Grad op""" """ReLU6Grad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "ReLU6Grad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y_grad"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU ReLU6Grad kernel: upstream gradient y_grad plus the
# forward input x, produces the input gradient. float16/float32.
relu_grad_op_info = AkgRegOp("ReLU6Grad") \
    .fusion_type("OPAQUE") \
    .input(0, "y_grad") \
    .input(1, "x") \
    .output(0, "output") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(relu_grad_op_info)
def _relu6_grad_akg():
    """ReLU6Grad AutoDiff register"""
    return

+ 12
- 38
mindspore/ops/_op_impl/akg/gpu/squeeze.py View File

@@ -13,45 +13,19 @@
# limitations under the License. # limitations under the License.


"""Squeeze op""" """Squeeze op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "Squeeze",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
{
"name": "axis",
"param_type": "optional",
"type": "listInt"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU Squeeze kernel with an optional "axis" attribute.
# Fix: the op name was "SqueezeGrad" (copy-paste from squeeze_grad.py), which
# both mislabels this registration and collides with the SqueezeGrad one; the
# replaced JSON registration used op_name "Squeeze".
squeeze_op_info = AkgRegOp("Squeeze") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .attr("axis", "optional", "listInt") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(squeeze_op_info)
def _squeeze_akg():
    """Squeeze AutoDiff register"""
    return

+ 13
- 43
mindspore/ops/_op_impl/akg/gpu/squeeze_grad.py View File

@@ -13,50 +13,20 @@
# limitations under the License. # limitations under the License.


"""SqueezeGrad op""" """SqueezeGrad op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "SqueezeGrad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
{
"name": "x_shape",
"param_type": "required",
"type": "listInt"
},
{
"name": "axis",
"param_type": "optional",
"type": "listInt"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y_grad"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU SqueezeGrad kernel: input y_grad is the upstream
# gradient; "x_shape" (required) restores the forward shape, "axis" optional.
squeeze_grad_op_info = AkgRegOp("SqueezeGrad") \
    .fusion_type("OPAQUE") \
    .input(0, "y_grad") \
    .output(0, "output") \
    .attr("x_shape", "required", "listInt") \
    .attr("axis", "optional", "listInt") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(squeeze_grad_op_info)
def _squeeze_grad_akg():
    """SqueezeGrad AutoDiff register"""
    return

+ 12
- 38
mindspore/ops/_op_impl/akg/gpu/tile.py View File

@@ -13,45 +13,19 @@
# limitations under the License. # limitations under the License.


"""Tile op""" """Tile op"""
from mindspore.ops.op_info_register import op_info_register
from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType


@op_info_register("""{
"op_name": "Tile",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
{
"name": "multiples",
"param_type": "required",
"type": "listInt"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
# Register the akg GPU Tile kernel: the required "multiples" attribute gives
# the per-axis repetition counts. float16/float32.
tile_op_info = AkgRegOp("Tile") \
    .fusion_type("OPAQUE") \
    .input(0, "x") \
    .output(0, "output") \
    .attr("multiples", "required", "listInt") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(tile_op_info)
def _tile_akg():
    """Tile AutoDiff register"""
    return

+ 58
- 0
mindspore/ops/op_info_register.py View File

@@ -205,6 +205,64 @@ class RegOp():
return op_info return op_info




class AkgRegOp(RegOp):
    """Op info register builder for akg (AutoDiff) kernels running on cuda."""

    def __init__(self, op_name):
        super(AkgRegOp, self).__init__(op_name)
        # Akg kernels are implied as AutoDiff and target the cuda processor.
        self.imply_type = "AutoDiff"
        self.processor = "cuda"

    def input(self, index=None, name=None, **kwargs):
        """
        Register Akg op input information.

        Args:
            index (int): Order of the input. Default: None.
            name (str): Name of the input. Default: None.
            kwargs (dict): Other information for the input.
        """
        param_list = [index, name]
        key_list = ["index", "name"]
        fn_list = [self._is_int, self._is_string]
        input_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.inputs.append(input_dict)
        # Return self so registrations can be written as a fluent chain.
        return self

    def output(self, index=None, name=None, **kwargs):
        """
        Register Akg op output information.

        Args:
            index (int): Order of the output. Default: None.
            name (str): Name of the output. Default: None.
            kwargs (dict): Other information for the output.
        """
        param_list = [index, name]
        key_list = ["index", "name"]
        fn_list = [self._is_int, self._is_string]
        output_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.outputs.append(output_dict)
        return self

    def attr(self, name=None, param_type=None, value_type=None, **kwargs):
        """
        Register Akg op attribute information.

        Args:
            name (str): Name of the attribute. Default: None.
            param_type (str): Param type of the attribute. Default: None.
            value_type (str): Value type of the attribute. Default: None.
            kwargs (dict): Other information for the attribute.
        """
        param_list = [name, param_type, value_type]
        key_list = ["name", "param_type", "type"]
        # NOTE(review): a single validator for three params assumes
        # RegOp._check_param reuses the last checker — confirm against RegOp.
        fn_list = [self._is_string]
        attr_dict = self._check_param(param_list, key_list, fn_list, kwargs)
        self.attr_.append(attr_dict)
        return self


class AiCPURegOp(RegOp): class AiCPURegOp(RegOp):
"""Class for AiCPU op info register""" """Class for AiCPU op info register"""




Loading…
Cancel
Save