|
|
|
@@ -440,59 +440,19 @@ class _Executor: |
|
|
|
Str, the full phase of the cell. |
|
|
|
Bool, if the graph has been compiled before, return False, else return True. |
|
|
|
""" |
|
|
|
from mindspore import nn |
|
|
|
from mindspore.ops.composite import GradOperation |
|
|
|
|
|
|
|
class InputsToAttrCell(nn.Cell): |
|
|
|
"""The cell that converts non-tensor inputs to attr.""" |
|
|
|
|
|
|
|
def __init__(self, net, args_names, non_tensor_inputs): |
|
|
|
super(InputsToAttrCell, self).__init__() |
|
|
|
self.net = net |
|
|
|
self.args_names = args_names |
|
|
|
self.non_tensor_inputs = non_tensor_inputs |
|
|
|
self.inputs_to_attr = True |
|
|
|
|
|
|
|
def construct(self, *tensor_inputs): |
|
|
|
real_inputs = () |
|
|
|
index = 0 |
|
|
|
for i in args_names: |
|
|
|
if i in self.non_tensor_inputs.keys(): |
|
|
|
real_inputs += (self.non_tensor_inputs[i],) |
|
|
|
else: |
|
|
|
real_inputs += (tensor_inputs[index],) |
|
|
|
index += 1 |
|
|
|
return self.net(*real_inputs) |
|
|
|
|
|
|
|
args_names, args_list = _generate_pip_args(obj, *args) |
|
|
|
if not hasattr(obj, "inputs_to_attr"): |
|
|
|
dic = dict(zip(args_names, args_list)) |
|
|
|
key = generate_key(phase, dic) |
|
|
|
obj.phase_prefix = str(key[1]) |
|
|
|
if 'export' in phase: |
|
|
|
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) |
|
|
|
else: |
|
|
|
phase = obj.phase_prefix + phase + '.' + str(obj.create_time) |
|
|
|
|
|
|
|
if phase in self.compile_cache.keys(): |
|
|
|
logger.debug("%r graph has existed.", phase) |
|
|
|
return phase, False |
|
|
|
|
|
|
|
if getattr(obj, "support_non_tensor_inputs", None): |
|
|
|
for i in obj.__dict__.values(): |
|
|
|
if isinstance(i, GradOperation): |
|
|
|
raise ValueError("Not support set 'support_non_tensor_inputs' to the 'True' for grad net, " |
|
|
|
"only support forward net.") |
|
|
|
attrs = {} |
|
|
|
inputs = [] |
|
|
|
for key, value in dic.items(): |
|
|
|
if not isinstance(value, (Tensor, MetaTensor)): |
|
|
|
attrs[key] = value |
|
|
|
else: |
|
|
|
inputs.append(value) |
|
|
|
if attrs: |
|
|
|
inputs_to_attr_cell = InputsToAttrCell(obj, args_names, attrs) |
|
|
|
return self.compile(inputs_to_attr_cell, *inputs, phase=phase) |
|
|
|
dic = dict(zip(args_names, args_list)) |
|
|
|
key = generate_key(phase, dic) |
|
|
|
obj.phase_prefix = str(key[1]) |
|
|
|
if 'export' in phase: |
|
|
|
phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) |
|
|
|
else: |
|
|
|
phase = obj.phase_prefix + phase + '.' + str(obj.create_time) |
|
|
|
|
|
|
|
if phase in self.compile_cache.keys(): |
|
|
|
logger.debug("%r graph has existed.", phase) |
|
|
|
return phase, False |
|
|
|
|
|
|
|
obj.check_names() |
|
|
|
_check_full_batch() |
|
|
|
|