|
|
|
@@ -341,6 +341,15 @@ class _Executor: |
|
|
|
param.init_data(layout, set_sliced=True) |
|
|
|
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) |
|
|
|
|
|
|
|
def _set_dataset_mode(self, args_list): |
|
|
|
"""set dataset mode.""" |
|
|
|
# decide whether to sink based on whether the inputs is virtual or args_list is () |
|
|
|
if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \ |
|
|
|
(args_list is not None and args_list == ()): |
|
|
|
_set_dataset_mode_config('sink') |
|
|
|
else: |
|
|
|
_set_dataset_mode_config('normal') |
|
|
|
|
|
|
|
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): |
|
|
|
""" |
|
|
|
Compiles graph. |
|
|
|
@@ -371,6 +380,8 @@ class _Executor: |
|
|
|
|
|
|
|
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE) |
|
|
|
|
|
|
|
self._set_dataset_mode(args_list) |
|
|
|
|
|
|
|
if phase in self.compile_cache.keys(): |
|
|
|
logger.debug("%r graph has existed.", phase) |
|
|
|
return phase, False |
|
|
|
@@ -399,12 +410,6 @@ class _Executor: |
|
|
|
|
|
|
|
# the following GE init process is not needed when use vm or ms backend |
|
|
|
if enable_ge: |
|
|
|
# decide whether to sink based on whether the inputs is virtual or not |
|
|
|
if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag: |
|
|
|
_set_dataset_mode_config('sink') |
|
|
|
else: |
|
|
|
_set_dataset_mode_config('normal') |
|
|
|
|
|
|
|
self._build_data_graph(obj, params, phase) |
|
|
|
|
|
|
|
if "export" not in phase: |
|
|
|
|