|
|
|
@@ -18,6 +18,8 @@ |
|
|
|
import inspect |
|
|
|
import copy |
|
|
|
from mindspore.common.api import _wrap_func |
|
|
|
from mindspore.common import Parameter |
|
|
|
from mindspore.common._register_for_tensor import tensor_operator_registry |
|
|
|
from .._c_expression import Primitive_, real_run_op, prim_type |
|
|
|
from .._c_expression import signature_rw as sig_rw |
|
|
|
from .._c_expression import signature_kind as sig_kind |
|
|
|
@@ -49,6 +51,7 @@ class Primitive(Primitive_): |
|
|
|
self.name = name |
|
|
|
self.attrs = {} |
|
|
|
self.init_attrs = {"name": name} |
|
|
|
self._update_parameter = False |
|
|
|
Primitive_.__init__(self, name, self) |
|
|
|
if hasattr(self.__class__, '__mindspore_signature__'): |
|
|
|
sig = self._fill_signature(self.__class__.__mindspore_signature__) |
|
|
|
@@ -189,6 +192,11 @@ class Primitive(Primitive_): |
|
|
|
# for checking output number with kernel implementation |
|
|
|
self.add_prim_attr("output_names", outputs) |
|
|
|
|
|
|
|
@property |
|
|
|
def update_parameter(self): |
|
|
|
""" Whether the primitive will update the value of parameter.""" |
|
|
|
return self._update_parameter |
|
|
|
|
|
|
|
|
|
|
|
class PrimitiveWithInfer(Primitive): |
|
|
|
""" |
|
|
|
@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None): |
|
|
|
@_wrap_func |
|
|
|
def _run_op(obj, op_name, args): |
|
|
|
"""Single op execution function supported by ge in PyNative mode.""" |
|
|
|
output = real_run_op(obj, op_name, args) |
|
|
|
cast = tensor_operator_registry.get("cast") |
|
|
|
if op_name == "Cast" or obj.update_parameter: |
|
|
|
cast_args = args |
|
|
|
else: |
|
|
|
cast_args = list() |
|
|
|
for arg in args: |
|
|
|
if isinstance(arg, Parameter): |
|
|
|
if arg.cast_type: |
|
|
|
cast_args.append(cast(arg, arg.cast_type)) |
|
|
|
else: |
|
|
|
cast_args.append(arg) |
|
|
|
else: |
|
|
|
cast_args.append(arg) |
|
|
|
output = real_run_op(obj, op_name, tuple(cast_args)) |
|
|
|
if not output: |
|
|
|
raise RuntimeError("Pynative run op %s failed!" % op_name) |
|
|
|
if len(output) == 1: |
|
|
|
|