Browse Source

compile profiling

tags/v1.1.0
wilfChen 5 years ago
parent
commit
eb48c8d647
1 changed files with 7 additions and 8 deletions
  1. +7
    -8
      mindspore/common/api.py

+ 7
- 8
mindspore/common/api.py View File

@@ -383,8 +383,6 @@ class _Executor:
Str, the full phase of the cell. Str, the full phase of the cell.
Bool, if the graph has been compiled before, return False, else return True. Bool, if the graph has been compiled before, return False, else return True.
""" """
obj.check_names()
_check_full_batch()
args_names, args_list = _generate_pip_args(obj, *args) args_names, args_list = _generate_pip_args(obj, *args)
dic = dict(zip(args_names, args_list)) dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic) key = generate_key(phase, dic)
@@ -393,22 +391,23 @@ class _Executor:
phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
else: else:
phase = self.phase_prefix + phase + '.' + str(obj.create_time) phase = self.phase_prefix + phase + '.' + str(obj.create_time)
enable_debug_runtime = context.get_context("enable_debug_runtime")
enable_ge = context.get_context("enable_ge")

use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)

self._set_dataset_mode(args_list)


if phase in self.compile_cache.keys(): if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase) logger.debug("%r graph has existed.", phase)
return phase, False return phase, False


obj.check_names()
_check_full_batch()
self._set_dataset_mode(args_list)

is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag
if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run(): if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank()) args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
_, args_list = _generate_pip_args(obj, *args_full) _, args_list = _generate_pip_args(obj, *args_full)


enable_debug_runtime = context.get_context("enable_debug_runtime")
enable_ge = context.get_context("enable_ge")
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
result = self._executor.compile(obj, args_list, phase, use_vm) result = self._executor.compile(obj, args_list, phase, use_vm)
self.compile_cache[phase] = phase self.compile_cache[phase] = phase
if not result: if not result:


Loading…
Cancel
Save