|
|
|
@@ -22,6 +22,7 @@ from mindspore import context |
|
|
|
from .._c_expression import Primitive_, real_run_op, prim_type |
|
|
|
from . import signature as sig |
|
|
|
|
|
|
|
|
|
|
|
class Primitive(Primitive_): |
|
|
|
""" |
|
|
|
Primitive is the base class of primitives in python. |
|
|
|
@@ -168,7 +169,7 @@ class Primitive(Primitive_): |
|
|
|
return type(self)(**self.init_attrs) |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
|
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list]) |
|
|
|
attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list]) |
|
|
|
info_str = f'Prim[{self.name}]' |
|
|
|
if attr: |
|
|
|
info_str += f'<{attr}>' |
|
|
|
@@ -425,6 +426,7 @@ def prim_attr_register(fn): |
|
|
|
Returns: |
|
|
|
function, original function. |
|
|
|
""" |
|
|
|
|
|
|
|
def deco(self, *args, **kwargs): |
|
|
|
if isinstance(self, PrimitiveWithInfer): |
|
|
|
PrimitiveWithInfer.__init__(self, self.__class__.__name__) |
|
|
|
@@ -442,6 +444,7 @@ def prim_attr_register(fn): |
|
|
|
self.add_prim_attr(name, value) |
|
|
|
self.init_attrs[name] = value |
|
|
|
fn(self, *args, **kwargs) |
|
|
|
|
|
|
|
deco.decorated_func = fn |
|
|
|
return deco |
|
|
|
|
|
|
|
@@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None): |
|
|
|
>>> return len(x) |
|
|
|
>>> assert tuple_len_class()(a) == 2 |
|
|
|
""" |
|
|
|
|
|
|
|
def deco(fn): |
|
|
|
class CompileOp(PrimitiveWithInfer): |
|
|
|
def __init__(self): |
|
|
|
@@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None): |
|
|
|
|
|
|
|
def infer_value(self, *args): |
|
|
|
return fn(*args) |
|
|
|
|
|
|
|
if get_instance: |
|
|
|
return CompileOp() |
|
|
|
return CompileOp |
|
|
|
|
|
|
|
if fn is not None: |
|
|
|
return deco(fn) |
|
|
|
return deco |
|
|
|
|