|
|
|
@@ -103,6 +103,19 @@ class Primitive(Primitive_): |
|
|
|
self.add_attr(name, value) |
|
|
|
return self |
|
|
|
|
|
|
|
def del_prim_attr(self, name): |
|
|
|
""" |
|
|
|
Del primitive attribute. |
|
|
|
|
|
|
|
Args: |
|
|
|
name (str): Attribute Name. |
|
|
|
""" |
|
|
|
if name in self.__dict__ and name in self.attrs: |
|
|
|
del self.__dict__[name] |
|
|
|
del self.attrs[name] |
|
|
|
self.del_attr(name) |
|
|
|
return self |
|
|
|
|
|
|
|
def set_stage(self, stage): |
|
|
|
""" |
|
|
|
Add stage id to primitive attribute. |
|
|
|
@@ -191,7 +204,7 @@ class Primitive(Primitive_): |
|
|
|
|
|
|
|
def init_prim_io_names(self, inputs, outputs): |
|
|
|
""" |
|
|
|
Initializes the name of inputs and outpus of Tensor or attributes. |
|
|
|
Initializes the name of inputs and outputs of Tensor or attributes. |
|
|
|
|
|
|
|
Args: |
|
|
|
inputs (list[str]): list of inputs names. |
|
|
|
@@ -222,9 +235,9 @@ class Primitive(Primitive_): |
|
|
|
class PrimitiveWithCheck(Primitive): |
|
|
|
""" |
|
|
|
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments |
|
|
|
but used the infer method registed in c++ source codes. |
|
|
|
but used the infer method registered in c++ source codes. |
|
|
|
|
|
|
|
There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(), |
|
|
|
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(), |
|
|
|
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called. |
|
|
|
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of |
|
|
|
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation. |
|
|
|
@@ -301,7 +314,7 @@ class PrimitiveWithInfer(Primitive): |
|
|
|
""" |
|
|
|
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python. |
|
|
|
|
|
|
|
There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(), |
|
|
|
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(), |
|
|
|
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority |
|
|
|
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer |
|
|
|
logic of the shape and type. The infer_value() is used for constant propagation. |
|
|
|
|