|
|
|
@@ -421,6 +421,7 @@ class _Executor: |
|
|
|
Bool, if the graph has been compiled before, return False, else return True. |
|
|
|
""" |
|
|
|
from mindspore import nn |
|
|
|
from mindspore.ops.composite import GradOperation |
|
|
|
|
|
|
|
class InputsToAttrCell(nn.Cell): |
|
|
|
"""The cell that converts non-tensor inputs to attr.""" |
|
|
|
@@ -457,17 +458,21 @@ class _Executor: |
|
|
|
logger.debug("%r graph has existed.", phase) |
|
|
|
return phase, False |
|
|
|
|
|
|
|
if getattr(obj, "support_non_tensor_inputs", None): |
|
|
|
attrs = {} |
|
|
|
inputs = [] |
|
|
|
for key, value in dic.items(): |
|
|
|
if not isinstance(value, (Tensor, MetaTensor)): |
|
|
|
attrs[key] = value |
|
|
|
else: |
|
|
|
inputs.append(value) |
|
|
|
if attrs: |
|
|
|
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) |
|
|
|
return self.compile(inputs_to_attr_cell, *inputs, phase=phase) |
|
|
|
if getattr(obj, "support_non_tensor_inputs", None): |
|
|
|
for i in obj.__dict__.values(): |
|
|
|
if isinstance(i, GradOperation): |
|
|
|
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, " |
|
|
|
"only support forward net.") |
|
|
|
attrs = {} |
|
|
|
inputs = [] |
|
|
|
for key, value in dic.items(): |
|
|
|
if not isinstance(value, (Tensor, MetaTensor)): |
|
|
|
attrs[key] = value |
|
|
|
else: |
|
|
|
inputs.append(value) |
|
|
|
if attrs: |
|
|
|
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) |
|
|
|
return self.compile(inputs_to_attr_cell, *inputs, phase=phase) |
|
|
|
|
|
|
|
obj.check_names() |
|
|
|
_check_full_batch() |
|
|
|
|