|
|
|
@@ -537,6 +537,9 @@ def constexpr(fn=None, get_instance=True, name=None): |
|
|
|
|
|
|
|
def deco(fn): |
|
|
|
class CompileOp(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
CompileOp is a temporary operator used to execute the constexpr function. |
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
op_name = name if name else fn.__name__ |
|
|
|
PrimitiveWithInfer.__init__(self, op_name) |
|
|
|
@@ -545,6 +548,9 @@ def constexpr(fn=None, get_instance=True, name=None): |
|
|
|
def infer_value(self, *args): |
|
|
|
return fn(*args) |
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
return fn(*args) |
|
|
|
|
|
|
|
if get_instance: |
|
|
|
return CompileOp() |
|
|
|
return CompileOp |
|
|
|
|