GitOrigin-RevId: ccc984acbd
tags/v1.3.0
| @@ -3,6 +3,7 @@ from collections import defaultdict | |||
| from contextlib import contextmanager | |||
| from typing import Callable | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
| from ..core.autodiff.grad import Grad | |||
| from ..logger import get_logger | |||
| from ..tensor import Tensor | |||
| @@ -239,6 +240,7 @@ class GradManager: | |||
| :param y: tensor or list of tensors | |||
| :param dy: tensor or list of tensors. Defaults to 1 if y is scalar | |||
| """ | |||
| push_scope("backward") | |||
| from ..functional import ones_like | |||
| global backwarding_grad_manager | |||
| @@ -280,6 +282,7 @@ class GradManager: | |||
| finally: | |||
| self.release() | |||
| backwarding_grad_manager = cache | |||
| pop_scope("backward") | |||
| def record(self): | |||
| r""" | |||
| @@ -8,5 +8,17 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import sys | |||
| from contextlib import contextmanager | |||
| from ._imperative_rt.core2 import get_option, set_option | |||
| from .tensor.megbrain_graph import Graph | |||
| @contextmanager | |||
| def option(key, value): | |||
| value = int(value) | |||
| old = get_option(key) | |||
| set_option(key, value) | |||
| yield | |||
| assert get_option(key) == value | |||
| set_option(key, old) | |||
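A short usage sketch for the option helper defined above; "async_level" is one of the interpreter options wired up later in this diff:

```python
from megengine.core import option
from megengine.core._imperative_rt.core2 import get_option

# Temporarily force fully synchronous execution, then restore the old level.
with option("async_level", 0):
    assert get_option("async_level") == 0
    # ... run code that should surface device errors eagerly ...
# the previous value is restored on exit
```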
| @@ -12,6 +12,7 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
| from ..core.tensor.utils import make_shape_tuple | |||
| from ..logger import get_logger | |||
| from ..tensor import Parameter, Tensor | |||
| @@ -78,6 +79,7 @@ class Module(metaclass=ABCMeta): | |||
| self._forward_hooks = OrderedDict() | |||
| self._modules = [] | |||
| self._name = "{anonymous}" | |||
| @abstractmethod | |||
| def forward(self, inputs): | |||
| @@ -103,6 +105,7 @@ class Module(metaclass=ABCMeta): | |||
| return HookHandler(self._forward_hooks, hook) | |||
| def __call__(self, *inputs, **kwargs): | |||
| push_scope(self._name) | |||
| for hook in self._forward_pre_hooks.values(): | |||
| modified_inputs = hook(self, inputs) | |||
| if modified_inputs is not None: | |||
| @@ -116,6 +119,7 @@ class Module(metaclass=ABCMeta): | |||
| modified_outputs = hook(self, inputs, outputs) | |||
| if modified_outputs is not None: | |||
| outputs = modified_outputs | |||
| pop_scope(self._name) | |||
| return outputs | |||
| def _flatten( | |||
| @@ -571,6 +575,14 @@ class Module(metaclass=ABCMeta): | |||
| return set(loaded), set(skipped) | |||
| def __getattribute__(self, name: str): | |||
| value = super().__getattribute__(name) | |||
| if name == "_name": | |||
| return value | |||
| if _is_module(value): | |||
| value._name = name | |||
| return value | |||
| def __setattr__(self, name: str, value): | |||
| if _is_module(value): | |||
| modules = self.__dict__.get("_modules") | |||
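With push_scope(self._name)/pop_scope(self._name) in __call__ and the _name propagation in __getattribute__, each submodule's forward appears in the trace as a scope named after its attribute. An illustrative sketch (Linear and relu are assumed from megengine's standard module/functional namespaces):

```python
import megengine.functional as F
from megengine import tensor
from megengine.module import Linear, Module
from megengine.utils.profiler import Profiler

class Net(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear(4, 2)  # its scope will be named "fc"

    def forward(self, x):
        return F.relu(self.fc(x))

with Profiler("module_profile"):
    Net()(tensor([[1.0, 2.0, 3.0, 4.0]]))
# the trace nests "fc" inside the top-level module's "{anonymous}" scope
```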
| @@ -15,6 +15,7 @@ from typing import Union | |||
| import numpy as np | |||
| from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
| from ..core.tensor.utils import set_convert_inputs | |||
| from ..tensor import Parameter, Tensor | |||
| from ..utils.deprecation import deprecated | |||
| @@ -155,7 +156,9 @@ class Optimizer(metaclass=ABCMeta): | |||
| "but the ordering of parameters in sets will change between runs. " | |||
| "Please use a list instead." | |||
| ) | |||
| push_scope("step") | |||
| self._updates(group) | |||
| pop_scope("step") | |||
# restore the global state `_enable_convert_inputs`
| set_convert_inputs(backup) | |||
| return self | |||
| @@ -172,8 +175,10 @@ class Optimizer(metaclass=ABCMeta): | |||
| Set the grad attribute to None for all parameters. | |||
| """ | |||
| for param_group in self.param_groups: | |||
| push_scope("clear_grad") | |||
| for param in param_group["params"]: | |||
| param.grad = None | |||
| pop_scope("clear_grad") | |||
| def state_dict(self) -> Dict: | |||
| r""" | |||
| @@ -6,159 +6,17 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import base64 | |||
| import json | |||
| import os | |||
| import re | |||
| from typing import Iterable, List, Optional | |||
| from contextlib import contextmanager | |||
| from typing import List | |||
| from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry | |||
| from ..core._imperative_rt import ProfilerImpl as _Profiler | |||
| from ..core._imperative_rt.core2 import sync | |||
| from ..core._imperative_rt.ops import CollectiveComm | |||
| def _make_dict(**kwargs): | |||
| unused_keys = [] | |||
| for k, v in kwargs.items(): | |||
| if v is None: | |||
| unused_keys.append(k) | |||
| for k in unused_keys: | |||
| del kwargs[k] | |||
| return kwargs | |||
| def _print_opnode_config(config): | |||
| return _make_dict( | |||
| name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr, | |||
| ) | |||
| def _dump_chrome_timeline(entries: List[ProfileEntry], path: str): | |||
| pid = os.getpid() | |||
| trace_events = [] | |||
| def append_event(**kwargs): | |||
| trace_events.append(_make_dict(**kwargs)) | |||
| for id, entry in enumerate(entries): | |||
| op = entry.op | |||
| name = type(op).__name__ | |||
| host_begin, host_end = entry.host | |||
| device_list = entry.device_list | |||
| args = Profiler.fetch_attrs(op) | |||
| args["__id__"] = "[{}]".format(id) | |||
| cat = name | |||
| for ts, ph in [(host_begin, "B"), (host_end, "E")]: | |||
| append_event( | |||
| name=name, ph=ph, ts=ts * 1000, pid=pid, tid="host", args=args, cat=cat, | |||
| ) | |||
| for device, device_begin, device_end in device_list: | |||
| for ts, ph in [(device_begin(), "B"), (device_end(), "E")]: | |||
| append_event( | |||
| name=name, ph=ph, ts=ts * 1000, pid=pid, tid=str(device), args=args, | |||
| ) | |||
| with open("{}.chrome_timeline.json".format(path), "w") as f: | |||
| json.dump(trace_events, f, indent=2) | |||
| def _dump_compatible(entries: List[ProfileEntry], path: str): | |||
| obj = { | |||
| "graph_exec": {"var": [], "operator": {}}, | |||
| "profiler": {"device": {}, "host": {}, "opr_footprint": {}}, | |||
| } | |||
| var_list = obj["graph_exec"]["var"] | |||
| operator_dict = obj["graph_exec"]["operator"] | |||
| device_dict = obj["profiler"]["device"] | |||
| host_dict = obj["profiler"]["host"] | |||
| opr_foot_print_dict = obj["profiler"]["opr_footprint"] | |||
| def add_var(var) -> int: | |||
| var_id = len(var_list) | |||
| var_list.append( | |||
| {"comp_node": str(var[2]),} | |||
| ) | |||
| return var_id | |||
| for op_id, entry in enumerate(entries): | |||
| operator_dict[op_id] = { | |||
| "input": [add_var(var) for var in entry.inputs], | |||
| "output": [add_var(var) for var in entry.outputs], | |||
| "name": str(entry.op.ctype()), | |||
| "type": "imperative", | |||
| "id": entry.id, | |||
| } | |||
| op_device_dict = {} | |||
| for device, device_begin, device_end in entry.device_list: | |||
| op_device_dict[str(device)] = { | |||
| "start": device_begin(), | |||
| "kern": device_begin(), | |||
| "end": device_end(), | |||
| } | |||
| device_dict[op_id] = op_device_dict | |||
| host_begin, host_end = entry.host | |||
| host_dict[op_id] = { | |||
| "host": {"start": host_begin, "kern": host_begin, "end": host_end} | |||
| } | |||
| opr_footprint = { | |||
| "out_shapes": [oup[1] for oup in entry.outputs], | |||
| "in_shapes": [inp[1] for inp in entry.inputs], | |||
| "params": {}, | |||
| } | |||
| if entry.memory > 0: | |||
| opr_footprint["memory"] = entry.memory | |||
| if entry.computation > 0: | |||
| opr_footprint["computation"] = entry.computation | |||
| opr_foot_print_dict[op_id] = opr_footprint | |||
| with open("{}.compatible.json".format(path), "w") as f: | |||
| json.dump(obj, f, indent=2) | |||
| def _dump_graphviz(entries: List[ProfileEntry], path: str): | |||
| import json | |||
| import graphviz | |||
| graph = graphviz.Digraph() | |||
| graph.graph_attr["ordering"] = "out" | |||
| var_cache = {} | |||
| def cache_var(var_id, var_shape): | |||
| if var_id not in var_cache: | |||
| var_name = "var({})".format(var_id) | |||
| var_label = "{}\nshape:{}\n".format(var_name, shape) | |||
| graph.node(var_name, var_label) | |||
| var_cache[var_id] = var_name | |||
| return var_cache[var_id] | |||
| for op_id, entry in enumerate(entries): | |||
| op = entry.op | |||
| op_name = "op({})".format(op_id) | |||
| op_type = type(op).__name__ | |||
| op_attrs = Profiler.fetch_attrs(op) | |||
| label_lines = [] | |||
| if "param" in op_attrs: | |||
| del op_attrs["param"] | |||
| label_lines.append("{}:{}".format(op_name, op_type)) | |||
| for k, v in op_attrs.items(): | |||
| label_lines.append("attr[{}]: {}".format(k, v)) | |||
| op_param_str = entry.param | |||
| if len(op_param_str) > 0: | |||
| op_param = json.loads(op_param_str) | |||
| for k, v in op_param.items(): | |||
| label_lines.append("param[{}]:{}".format(k, v)) | |||
| host_begin, host_end = entry.host | |||
| label_lines.append("time[host]: {:f}ms".format(host_end - host_begin)) | |||
| for device, device_begin, device_end in entry.device_list: | |||
| device_time = device_end() - device_begin() | |||
| label_lines.append("time[{}]: {:f}ms".format(device, device_time)) | |||
| op_label = "\n".join(label_lines) | |||
| graph.node(op_name, op_label, shape="rectangle") | |||
| for var_id, shape, device in entry.inputs: | |||
| graph.edge(cache_var(var_id, shape), op_name) | |||
| for var_id, shape, device in entry.outputs: | |||
| graph.edge(op_name, cache_var(var_id, shape)) | |||
| graph.save("{}.graphviz.dot".format(path)) | |||
| from ..core._imperative_rt.core2 import ( | |||
| pop_scope, | |||
| push_scope, | |||
| start_profile, | |||
| stop_profile, | |||
| sync, | |||
| ) | |||
| class Profiler: | |||
| @@ -181,85 +39,45 @@ class Profiler: | |||
# Only the profiling records of the last iteration are saved
| with Profiler("profile"): | |||
| # your code here | |||
# Then open the profile file in the chrome://tracing window
| """ | |||
| CHROME_TIMELINE = "chrome_timeline" | |||
| COMPATIBLE = "compatible" | |||
| GRAPHVIZ = "graphviz" | |||
| WITH_FOOTPRINT = 1 | |||
| CHROME_TIMELINE = "chrome_timeline.json" | |||
| _type_map = { | |||
| OperatorNodeConfig: lambda x: _print_opnode_config(x), | |||
| bytes: lambda x: base64.encodebytes(x).decode("ascii"), | |||
| CollectiveComm.Mode: lambda x: str(x), | |||
| } | |||
| _dumper_map = { | |||
| CHROME_TIMELINE: _dump_chrome_timeline, | |||
| COMPATIBLE: _dump_compatible, | |||
| GRAPHVIZ: _dump_graphviz, | |||
| } | |||
| COMMAND = 1 << 0 | |||
| OPERATOR = 1 << 1 | |||
| TENSOR_LIFETIME = 1 << 2 | |||
| TENSOR_PROP = 1 << 3 | |||
| SYNC = 1 << 4 | |||
| SCOPE = 1 << 5 | |||
| ALL = (1 << 6) - 1 | |||
| def __init__( | |||
| self, | |||
| path: str = "profile", | |||
| format: str = CHROME_TIMELINE, | |||
| *, | |||
| formats: Iterable[str] = (CHROME_TIMELINE,), | |||
| type_filter: str = ".*", | |||
| exit_dump: bool = True | |||
| topic=OPERATOR | SCOPE, | |||
| align_time=True, | |||
| show_operator_name=True | |||
| ) -> None: | |||
| self._impl = _Profiler() | |||
| self._path = path | |||
| if isinstance(formats, str): | |||
| formats = (formats,) | |||
| self._filter = type_filter | |||
| self._dumpers = [Profiler._dumper_map[fmt] for fmt in formats] | |||
| self._exit_dump = exit_dump | |||
| self._format = format | |||
| self._options = { | |||
| "topic": int(topic), | |||
| "align_time": int(align_time), | |||
| "show_operator_name": int(show_operator_name), | |||
| } | |||
| def __enter__(self): | |||
| sync() | |||
| self._impl.start(Profiler.WITH_FOOTPRINT) | |||
| start_profile(self._options) | |||
| return self | |||
| def __exit__(self, val, tp, trace): | |||
| if self._exit_dump: | |||
| self.dump() | |||
| sync() | |||
| self._impl.stop() | |||
| self._impl.clear() | |||
| @classmethod | |||
| def fetch_attrs(cls, op): | |||
| attrs = dir(op) | |||
| results = {} | |||
| for attr in attrs: | |||
| if attr.startswith("_"): | |||
| continue | |||
| value = op.__getattribute__(attr) | |||
| if callable(value): | |||
| continue | |||
| value_type = type(value) | |||
| if value_type in cls._type_map: | |||
| value = cls._type_map[value_type](value) | |||
| results[attr] = str(value) | |||
| return results | |||
| def dump(self, path: Optional[str] = None): | |||
| stop_profile(self._path, self._format) | |||
# dump is async, so it's necessary to sync the interpreter
| sync() | |||
| raw = [ | |||
| entry | |||
| for entry in self._impl.dump() | |||
| if re.match(self._filter, type(entry.op).__name__) | |||
| ] | |||
| if path is None: | |||
| path = self._path | |||
| for dumper in self._dumpers: | |||
| dumper(raw, path) | |||
| def __call__(self, func): | |||
| def wrapper(*args, **kwargs): | |||
| @@ -269,4 +87,23 @@ class Profiler: | |||
| return wrapper | |||
| @contextmanager | |||
| def scope(name): | |||
| push_scope(name) | |||
| yield | |||
| pop_scope(name) | |||
| profile = Profiler | |||
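The scope context manager and the profile alias both feed the same push_scope/pop_scope machinery exposed by the interpreter. A brief usage sketch:

```python
from megengine import tensor
from megengine.utils.profiler import Profiler, scope

with Profiler("demo"):
    with scope("preprocess"):   # appears as a named range in the timeline
        x = tensor([1.0, 2.0]) * 2
    with scope("postprocess"):
        y = x + 1
```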
| def merge_trace_events(sources: List[str], target: str): | |||
names = [prefix + ".chrome_timeline.json" for prefix in sources]
| result = [] | |||
| for name in names: | |||
| with open(name, "r", encoding="utf-8") as f: | |||
| content = json.load(f) | |||
| for entry in content: | |||
| result.append(entry) | |||
| with open(target + ".chrome_timeline.json", "w") as f: | |||
| json.dump(result, f, ensure_ascii=False, indent=4) | |||
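merge_trace_events concatenates several chrome_timeline JSON dumps into one, which is handy when each process of a distributed run writes its own trace. Illustrative usage (the file prefixes are hypothetical):

```python
from megengine.utils.profiler import merge_trace_events

# Given "rank0.chrome_timeline.json" and "rank1.chrome_timeline.json",
# this writes a combined "merged.chrome_timeline.json".
merge_trace_events(["rank0", "rank1"], "merged")
```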
| @@ -807,16 +807,34 @@ void init_tensor(py::module m) { | |||
| } | |||
| } | |||
| m.def("set_option", | |||
| [](std::string name, int value){ interpreter_for_py->set_option(name, value); }); | |||
| m.def("get_option", | |||
| [](std::string name){ return interpreter_for_py->get_option(name); }); | |||
| m.def("_set_swap_flag", | |||
| [](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | |||
| [](bool flag) { interpreter_for_py->set_option("enable_swap", flag); }); | |||
| m.def("_set_drop_flag", | |||
| [](bool flag) { interpreter_for_py->set_drop_flag(flag); }); | |||
| [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); }); | |||
| m.def("config_async_level", | |||
| [](int level) { interpreter_for_py->config_async_level(level); }); | |||
| [](int level) { | |||
| mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2"); | |||
| interpreter_for_py->set_option("async_level", level); | |||
| }); | |||
| m.def("get_async_level", | |||
| []() { return interpreter_for_py->get_async_level(); }); | |||
| []() { return interpreter_for_py->get_option("async_level"); }); | |||
| m.def("set_buffer_length", | |||
| [](int length) { interpreter_for_py->set_buffer_length(length); }); | |||
| [](int length) { | |||
| mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)"); | |||
| interpreter_for_py->set_option("buffer_length", length); | |||
| }); | |||
| m.def("push_scope", | |||
| [](std::string name) { interpreter_for_py->push_scope(name); }); | |||
| m.def("pop_scope", | |||
| [](std::string name) { interpreter_for_py->pop_scope(name); }); | |||
| m.def("start_profile", | |||
| [](std::unordered_map<std::string, int> option) { return interpreter_for_py->start_profile(option); }); | |||
| m.def("stop_profile", | |||
| [](std::string basename, std::string format) { interpreter_for_py->stop_profile(basename, format); }); | |||
| m.def("sync", | |||
| []() { | |||
| interpreter_for_py->sync(); | |||
| @@ -200,33 +200,6 @@ void init_utils(py::module m) { | |||
| m.def("_get_device_count", &mgb::CompNode::get_device_count, | |||
| "Get total number of specific devices on this system"); | |||
| using mgb::imperative::ProfileEntry; | |||
| py::class_<ProfileEntry>(m, "ProfileEntry") | |||
| .def_readwrite("op", &ProfileEntry::op) | |||
| .def_readwrite("host", &ProfileEntry::host) | |||
| .def_readwrite("device_list", &ProfileEntry::device_list) | |||
| .def_readwrite("inputs", &ProfileEntry::inputs) | |||
| .def_readwrite("outputs", &ProfileEntry::outputs) | |||
| .def_readwrite("id", &ProfileEntry::id) | |||
| .def_readwrite("parent", &ProfileEntry::parent) | |||
| .def_readwrite("memory", &ProfileEntry::memory) | |||
| .def_readwrite("computation", &ProfileEntry::computation) | |||
| .def_property_readonly("param", [](ProfileEntry& self)->std::string{ | |||
| if(self.param){ | |||
| return self.param->to_string(); | |||
| } else { | |||
| return {}; | |||
| } | |||
| }); | |||
| py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl") | |||
| .def(py::init<>()) | |||
| .def("start", &mgb::imperative::Profiler::start) | |||
| .def("stop", &mgb::imperative::Profiler::stop) | |||
| .def("clear", &mgb::imperative::Profiler::clear) | |||
| .def("dump", &mgb::imperative::Profiler::get_profile); | |||
| using mgb::imperative::TensorSanityCheck; | |||
| py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl") | |||
| .def(py::init<>()) | |||
| @@ -0,0 +1,54 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied | |||
| import json | |||
| import os | |||
| import pytest | |||
| from megengine import Parameter, tensor | |||
| from megengine.core import option | |||
| from megengine.module import Module | |||
| from megengine.utils.profiler import Profiler, scope | |||
| class Simple(Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.a = Parameter([1.23], dtype="float32") | |||
| def forward(self, x): | |||
| x = x * self.a | |||
| return x | |||
| def test_profiler(): | |||
| profile_prefix = "pytest_profile" | |||
| profile_format = "chrome_timeline.json" | |||
| profile_path = "{}.{}".format(profile_prefix, profile_format) | |||
| with Profiler(profile_prefix, format=profile_format): | |||
| with scope("my_scope"): | |||
| oup = Simple()(tensor([1.23], dtype="float32")) | |||
| with open(profile_path, "r") as f: | |||
| events = json.load(f) | |||
| os.remove(profile_path) | |||
| prev_ts = {} | |||
| scope_count = 0 | |||
| for event in events: | |||
| if "dur" in event: | |||
| assert event["dur"] >= 0 | |||
| elif "ts" in event and "tid" in event: | |||
| ts = event["ts"] | |||
| tid = event["tid"] | |||
| if ts == 0: | |||
| continue | |||
| assert (tid not in prev_ts) or prev_ts[tid] <= ts | |||
| prev_ts[tid] = ts | |||
| if "name" in event and event["name"] == "my_scope": | |||
| scope_count += 1 | |||
| assert scope_count > 0 and scope_count % 2 == 0 | |||
| @@ -17,52 +17,37 @@ namespace mgb { | |||
| namespace imperative { | |||
| template <typename TFunction> | |||
| class FunctionHooker; | |||
| class FunctionHook; | |||
| template <typename TRet, typename... TArgs> | |||
| class FunctionHooker<TRet(TArgs...)> { | |||
| template <template <typename> class TFunction, typename TRet, typename... TArgs> | |||
| class FunctionHook<TFunction<TRet(TArgs...)>> { | |||
| public: | |||
| using FunctionType = thin_function<TRet(TArgs...)>; | |||
//Type of hooks. A hook should accept the real function as an argument
//and invoke it at an appropriate time
| using HookType = thin_function<TRet(FunctionType, TArgs...)>; | |||
| explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} { | |||
| m_backup = {nullptr, [](FunctionType*){}}; | |||
| using FunctionType = TFunction<TRet(TArgs...)>; | |||
| explicit FunctionHook(FunctionType* fptr) : m_fptr{fptr} { | |||
| m_backup = *fptr; | |||
| } | |||
| public: | |||
| FunctionHooker& apply_hook(HookType&& hook) { | |||
| if (!m_backup) { | |||
| FunctionType* backup = new FunctionType(*m_fptr); | |||
//Restores the hooked function; invoked on destruction
| std::function<void(FunctionType*)> restorer = | |||
| [fptr = m_fptr](FunctionType* bkp) -> void { | |||
| *fptr = *bkp; | |||
| delete bkp; | |||
| }; | |||
| m_backup = decltype(m_backup)(backup, restorer); | |||
| } | |||
| template <typename THook, typename=std::enable_if_t<std::is_invocable_r_v<TRet, THook, FunctionType, TArgs...>, void>> | |||
| FunctionHook& apply_hook(THook&& hook) { | |||
| //Replace with hooked version | |||
| *m_fptr = [func = *m_fptr, hook](TArgs... args) -> TRet { | |||
| *m_fptr = [func = *m_fptr, hook=std::forward<THook>(hook)](TArgs... args) -> TRet { | |||
| return hook(func, std::forward<TArgs>(args)...); | |||
| }; | |||
//Convenient for chaining calls
| return *this; | |||
| } | |||
| private: | |||
| FunctionType* m_fptr; | |||
| std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup; | |||
| FunctionType m_backup; | |||
| public: | |||
| ~FunctionHook() { | |||
| *m_fptr = std::move(m_backup); | |||
| } | |||
| }; | |||
| //Helps to deduce template args | |||
| template <typename TRet, typename... TArgs> | |||
| FunctionHooker(thin_function<TRet(TArgs...)>* f) | |||
| -> FunctionHooker<TRet(TArgs...)>; | |||
| template<typename TSignature> | |||
| auto make_shared_hook(thin_function<TSignature>* fptr){ | |||
| return std::make_shared<FunctionHooker<TSignature>>(fptr); | |||
| template<typename TFunction> | |||
| auto make_shared_hook(TFunction* fptr){ | |||
| return std::make_shared<FunctionHook<TFunction>>(fptr); | |||
| } | |||
| } // namespace imperative | |||
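The reworked FunctionHook snapshots the original function in its constructor, wraps it in place on apply_hook, and restores the snapshot in the destructor. The same wrap-and-restore pattern sketched in Python for intuition (not part of this diff):

```python
from contextlib import contextmanager

@contextmanager
def function_hook(obj, attr, hook):
    """Temporarily replace obj.attr so that hook(original, *args) runs instead."""
    backup = getattr(obj, attr)             # snapshot, like m_backup
    def hooked(*args, **kwargs):
        return hook(backup, *args, **kwargs)
    setattr(obj, attr, hooked)              # wrap in place, like apply_hook
    try:
        yield
    finally:
        setattr(obj, attr, backup)          # restore, like ~FunctionHook

# usage: double the argument of some_obj.compute while the block is active
# with function_hook(some_obj, "compute", lambda f, x: f(x * 2)):
#     some_obj.compute(3)
```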
| @@ -0,0 +1,231 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/commands.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include <variant> | |||
| #include "megbrain/tensor.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| namespace mgb::imperative { | |||
| namespace interpreter::intl { | |||
| struct TensorInfo; | |||
| class InterpreterProfiler; | |||
| struct Put { | |||
| TensorInfo* dest; | |||
| HostTensorND value; | |||
| bool no_cache = false; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("dest", dest); | |||
| functor("no_cache", no_cache); | |||
| //functor("value", value); | |||
| } | |||
| const char* get_name() const { | |||
| return "Put"; | |||
| } | |||
| }; | |||
| struct ApplyOp { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<TensorInfo*> inputs; | |||
| SmallVector<TensorInfo*> outputs; | |||
| SmallVector<TensorInfo*> dels; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("op", op); | |||
| functor("inputs", inputs); | |||
| functor("outputs", outputs); | |||
| functor("dels", dels); | |||
| } | |||
| const char* get_name() const { | |||
| return "ApplyOp"; | |||
| } | |||
| }; | |||
| struct Del { | |||
| TensorInfo* dest; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("dest", dest); | |||
| } | |||
| const char* get_name() const { | |||
| return "Del"; | |||
| } | |||
| }; | |||
| struct GetValue { | |||
| TensorInfo* dest; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("dest", dest); | |||
| } | |||
| const char* get_name() const { | |||
| return "GetValue"; | |||
| } | |||
| }; | |||
| struct SwapIn { | |||
| TensorInfo* dest; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("dest", dest); | |||
| } | |||
| const char* get_name() const { | |||
| return "SwapIn"; | |||
| } | |||
| }; | |||
| struct SwapOut { | |||
| TensorInfo* dest; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("dest", dest); | |||
| } | |||
| const char* get_name() const { | |||
| return "SwapOut"; | |||
| } | |||
| }; | |||
| struct Drop { | |||
| TensorInfo* dest; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("dest", dest); | |||
| } | |||
| const char* get_name() const { | |||
| return "Drop"; | |||
| } | |||
| }; | |||
| struct SetOption { | |||
| std::string key; | |||
| int value; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("key", key); | |||
| functor("value", value); | |||
| } | |||
| const char* get_name() const { | |||
| return "SetOption"; | |||
| } | |||
| }; | |||
| struct StartProfile { | |||
| InterpreterProfiler* profiler; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const {} | |||
| const char* get_name() const { | |||
| return "StartProfile"; | |||
| } | |||
| }; | |||
| struct StopProfile { | |||
| std::string basename; | |||
| std::string format; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("basename", basename); | |||
| functor("format", format); | |||
| } | |||
| const char* get_name() const { | |||
| return "StopProfile"; | |||
| } | |||
| }; | |||
| struct PushScope { | |||
| std::string scope_name; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("scope_name", scope_name); | |||
| } | |||
| const char* get_name() const { | |||
| return "PushScope"; | |||
| } | |||
| }; | |||
| struct PopScope { | |||
| std::string scope_name; | |||
| template <typename TFunctor> | |||
| void get_props(TFunctor&& functor) const { | |||
| functor("scope_name", scope_name); | |||
| } | |||
| const char* get_name() const { | |||
| return "PopScope"; | |||
| } | |||
| }; | |||
| using Command = std::variant<Put, | |||
| ApplyOp, | |||
| Del, | |||
| GetValue, | |||
| SwapIn, | |||
| SwapOut, | |||
| Drop, | |||
| SetOption, | |||
| StartProfile, | |||
| StopProfile, | |||
| PushScope, | |||
| PopScope>; | |||
| using IdentifiedCommand = std::pair<uint64_t, Command>; | |||
| } | |||
| template <> | |||
| struct ToStringTrait<interpreter::intl::Command>{ | |||
| std::string operator()(const interpreter::intl::Command& cmd) const { | |||
| return std::visit([](auto& cmd){ | |||
| std::string result = cmd.get_name(); | |||
| result += "{"; | |||
| cmd.get_props([&](const char* key, auto&& value) { | |||
| result += key; | |||
| result += ": "; | |||
| result += to_string(value); | |||
| result += ","; | |||
| }); | |||
| result += "}"; | |||
| return result; | |||
| }, cmd); | |||
| } | |||
| }; | |||
| } | |||
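Each command exposes its fields through the generic get_props callback, so ToStringTrait can format any alternative of the Command variant without per-type code. A Python analogue of the same props-callback idea (illustrative only):

```python
from dataclasses import dataclass, fields

@dataclass
class SetOption:
    key: str
    value: int

@dataclass
class PushScope:
    scope_name: str

def to_string(cmd) -> str:
    # mirrors ToStringTrait: "<Name>{field: value, ...}" via a field callback
    props = ", ".join(f"{f.name}: {getattr(cmd, f.name)}" for f in fields(cmd))
    return f"{type(cmd).__name__}{{{props}}}"

assert to_string(SetOption("async_level", 1)) == "SetOption{key: async_level, value: 1}"
```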
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/events.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "./commands.h" | |||
| #include "./tensor_info.h" | |||
| namespace mgb::imperative::interpreter::intl { | |||
| struct CommandEvent { | |||
| IdentifiedCommand icmd; | |||
| }; | |||
| struct CommandEnqueueEvent: CommandEvent {}; | |||
| struct CommandExecuteEvent: CommandEvent {}; | |||
| struct CommandFinishEvent: CommandEvent {}; | |||
| struct OpEvent { | |||
| uint64_t id; | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<uint64_t> inputs; | |||
| SmallVector<uint64_t> outputs; | |||
| }; | |||
| struct HostOpExecuteEvent: OpEvent {}; | |||
| struct DeviceOpExecuteEvent: OpEvent {}; | |||
| struct HostOpFinishEvent: OpEvent {}; | |||
| struct DeviceOpFinishEvent: OpEvent {}; | |||
| struct TensorDeclareEvent { | |||
| uint64_t tensor_id; | |||
| }; | |||
| struct TensorProduceEvent { | |||
| uint64_t tensor_id; | |||
| TensorLayout layout; | |||
| CompNode device; | |||
| }; | |||
| struct TensorEraseEvent { | |||
| uint64_t tensor_id; | |||
| }; | |||
| struct TensorPropEvent { | |||
| uint64_t tensor_id; | |||
| TensorInfo::Prop prop; | |||
| std::string prop_desc; | |||
| }; | |||
| struct TensorGetPropEvent: TensorPropEvent{}; | |||
| struct TensorWaitPropEvent: TensorPropEvent{}; | |||
| struct TensorNotifyPropEvent: TensorPropEvent{}; | |||
| struct TensorWaitPropFinishEvent: TensorPropEvent{}; | |||
| struct SyncStartEvent {}; | |||
| struct SyncFinishEvent {}; | |||
| struct ScopeEvent { | |||
| std::string name; | |||
| }; | |||
| struct ChannelBeginScope: ScopeEvent {}; | |||
| struct ChannelEndScope: ScopeEvent {}; | |||
| struct WorkerBeginScope: ScopeEvent {}; | |||
| struct WorkerEndScope: ScopeEvent {}; | |||
| struct DeviceBeginScope: ScopeEvent {}; | |||
| struct DeviceEndScope: ScopeEvent {}; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter_impl.cpp | |||
| * \file imperative/src/impl/interpreter/interpreter_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| @@ -10,10 +10,14 @@ | |||
| */ | |||
| #include "./interpreter_impl.h" | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/imperative/opr_utility.h" | |||
| #include "megbrain/imperative/ops/backward_graph.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| #include "../op_trait.h" | |||
| using namespace mgb; | |||
| using namespace imperative; | |||
| @@ -48,6 +52,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { | |||
| info->desc.layout = data.layout(); | |||
| info->desc.comp_node = data.comp_node(); | |||
| info->ptr = Tensor::make(data); | |||
| m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node); | |||
| return info; | |||
| } | |||
| @@ -61,7 +66,7 @@ void ChannelImpl::del(Handle handle) { | |||
| } | |||
| void ChannelImpl::swap_in(Handle handle) { | |||
| if (m_enable_evict & SWAP) { | |||
| if (m_worker_state.options.enable_swap) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto* info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -71,7 +76,7 @@ void ChannelImpl::swap_in(Handle handle) { | |||
| } | |||
| void ChannelImpl::swap_out(Handle handle) { | |||
| if (m_enable_evict & SWAP) { | |||
| if (m_worker_state.options.enable_swap) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto* info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -81,7 +86,7 @@ void ChannelImpl::swap_out(Handle handle) { | |||
| } | |||
| void ChannelImpl::drop(Handle handle) { | |||
| if (m_enable_evict & DROP) { | |||
| if (m_worker_state.options.enable_drop) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto* info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -100,6 +105,7 @@ void ChannelImpl::dispatch_default_cpu( | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs) { | |||
| auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
| MGB_MARK_USED_VAR(validated); | |||
| SmallVector<DeviceTensorND> input_tensornds; | |||
| input_tensornds.reserve(input_descs.size()); | |||
| @@ -133,6 +139,17 @@ void ChannelImpl::dispatch_default_cpu( | |||
| output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu()); | |||
| } | |||
| auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { | |||
| SmallVector<uint64_t> tid; | |||
| for (auto* ptinfo: tinfo) { | |||
| tid.push_back(ptinfo->id); | |||
| } | |||
| return tid; | |||
| }; | |||
| OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}}; | |||
| m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data); | |||
| OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); | |||
| SmallVector<TensorInfo*> output_infos; | |||
| @@ -146,9 +163,14 @@ void ChannelImpl::dispatch_default_cpu( | |||
| output_infos.push_back(info); | |||
| outputs->push_back(info); | |||
| } | |||
| if (m_enable_evict & DROP) { | |||
| if (m_channel_state.options.enable_drop) { | |||
| TensorInfo::ComputePath::make(op, input_infos, output_infos); | |||
| } | |||
| event_data.outputs = tinfo_to_tid(output_infos); | |||
| m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data); | |||
| } | |||
| void ChannelImpl::dispatch_kernel( | |||
| @@ -173,13 +195,13 @@ void ChannelImpl::dispatch_kernel( | |||
| cmd.outputs.push_back(info); | |||
| outputs->push_back(info); | |||
| } | |||
| if (m_enable_evict & DROP) { | |||
| if (m_channel_state.options.enable_drop) { | |||
| TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs); | |||
| } | |||
| m_buffer.enqueue(std::move(cmd)); | |||
| if (!validated && m_async_level == 1) { | |||
| if (!validated && m_channel_state.options.async_level == 1) { | |||
| sync(); | |||
| } else if (m_async_level == 0) { | |||
| } else if (m_channel_state.options.async_level == 0) { | |||
| sync(); | |||
| // check device error | |||
| for (auto&& oup : *outputs) { | |||
| @@ -212,7 +234,10 @@ SmallVector<Handle> ChannelImpl::apply_op( | |||
| } | |||
| SmallVector<Handle> outputs; | |||
| switch (OpDef::decide_dispatch_mode(*op, input_descs)) { | |||
| DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute | |||
| ? OpDef::decide_dispatch_mode(*op, input_descs) | |||
| : DispatchMode::KERNEL; | |||
| switch (dispatch_mode) { | |||
| case DEFAULT_CPU: { | |||
| dispatch_default_cpu(op, input_infos, input_descs, &outputs); | |||
| break; | |||
| @@ -242,11 +267,13 @@ HostTensorND ChannelImpl::get_value(Handle handle) { | |||
| m_waitee = info; | |||
| regenerate(info); | |||
| m_buffer.enqueue(GetValue{info}); | |||
| m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue); | |||
| m_cv.wait(lock, [&]() { | |||
| check_worker_exc_unsafe(); | |||
| tensor_ptr = info->ptr; | |||
| return value_fetched(); | |||
| }); | |||
| m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue); | |||
| m_waitee = nullptr; | |||
| } | |||
| return tensor_ptr->get_value(); | |||
| @@ -262,11 +289,13 @@ TensorShape ChannelImpl::get_shape(Handle handle) { | |||
| std::unique_lock<decltype(m_mutex)> lock(m_mutex); | |||
| mgb_assert(!m_waitee); | |||
| m_waitee = info; | |||
| m_buffer.enqueue(Flush{info}); | |||
| m_buffer.flush(); | |||
| m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape); | |||
| m_cv.wait(lock, [&]() { | |||
| check_worker_exc_unsafe(); | |||
| return static_cast<bool>(info->ptr); | |||
| }); | |||
| m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape); | |||
| m_waitee = nullptr; | |||
| TensorShape ret = info->ptr->layout(); | |||
| mgb_assert(ret.ndim != 0); | |||
| @@ -277,6 +306,7 @@ DType ChannelImpl::get_dtype(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType); | |||
| auto ret = info->desc.layout.dtype; | |||
| mgb_assert(ret.valid()); | |||
| return ret; | |||
| @@ -286,6 +316,7 @@ CompNode ChannelImpl::get_device(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device); | |||
| auto ret = info->desc.comp_node; | |||
| mgb_assert(ret.valid()); | |||
| return ret; | |||
| @@ -299,20 +330,23 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { | |||
| mgb_assert(!m_waitee); | |||
| m_waitee = info; | |||
| regenerate(info); | |||
| m_buffer.enqueue(Flush{info}); | |||
| m_buffer.flush(); | |||
| m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue); | |||
| m_cv.wait(lock, [&]() { | |||
| check_worker_exc_unsafe(); | |||
| return static_cast<bool>(info->ptr); | |||
| }); | |||
| m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue); | |||
| m_waitee = nullptr; | |||
| return info->ptr->dev_tensor(); | |||
| } | |||
| void ChannelImpl::sync() { | |||
| if (!m_buffer.empty()) { | |||
| m_buffer.enqueue(Flush{}); | |||
| } | |||
| m_buffer.flush(); | |||
| m_channel_state.profiler->record_host<SyncStartEvent>(); | |||
| m_worker.wait_all_task_finish(); | |||
| CompNode::sync_all(); | |||
| m_channel_state.profiler->record_host<SyncFinishEvent>(); | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| check_worker_exc_unsafe(); | |||
| } | |||
| @@ -321,33 +355,41 @@ void ChannelImpl::close() { | |||
| sync(); | |||
| } | |||
| void ChannelImpl::config_async_level(int level) { | |||
| mgb_assert(level <= 2 && level >= 0, "async_level should be 0, 1 or 2"); | |||
| m_async_level = level; | |||
| int ChannelImpl::get_option(std::string name) { | |||
| return m_channel_state.options.get_option(name); | |||
| } | |||
| int ChannelImpl::get_async_level() { | |||
| return m_async_level; | |||
| void ChannelImpl::set_option(std::string name, int value) { | |||
| m_channel_state.options.set_option(name, value); | |||
| m_buffer.enqueue(SetOption{name, value}); | |||
| } | |||
| TensorInfo* ChannelImpl::alloc() { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| auto info = m_pool.alloc(); | |||
| m_valid_handle.insert(info); | |||
| info->id = m_last_id++; | |||
| m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id); | |||
| return info; | |||
| } | |||
| void ChannelImpl::free(TensorInfo* ptr) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id); | |||
| m_pool.free(ptr); | |||
| } | |||
| ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){ | |||
| m_channel_state.tid = std::this_thread::get_id(); | |||
| } | |||
| ChannelImpl::~ChannelImpl() { | |||
| close(); | |||
| } | |||
| void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node()); | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| @@ -397,55 +439,57 @@ void ChannelImpl::detach_users(TensorInfo* dest) { | |||
| output->detach_producer(); | |||
| } | |||
| } | |||
| dest->users.clear(); | |||
| mgb_assert(dest->users.size() == 0); | |||
| //dest->users.clear(); | |||
| } | |||
| void ChannelImpl::set_swap_flag(bool flag) { | |||
| if ((!flag) && (m_enable_evict & SWAP)) { | |||
| for (auto handle: m_valid_handle) { | |||
| auto* info = reinterpret_cast<TensorInfo*>(handle); | |||
| if (info->evict_type == SWAP) { | |||
| swap_in(info); | |||
| } | |||
| void ChannelImpl::sync_device_scope(CompNode device) { | |||
| auto& prev = m_worker_state.device_scope_map[device]; | |||
| auto& current = m_worker_state.scopes; | |||
| auto push_scope = [&](std::string name) { | |||
| m_worker_state.profiler->record_device<DeviceBeginScope>(device, name); | |||
| }; | |||
| auto pop_scope = [&](std::string name) { | |||
| m_worker_state.profiler->record_device<DeviceEndScope>(device, name); | |||
| }; | |||
| size_t similarity = 0; | |||
| for (size_t i = 0; i < prev.size() && i < current.size(); i++) { | |||
| if (prev[i] == current[i]) { | |||
| similarity++; | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| if (flag) { | |||
| m_enable_evict |= SWAP; | |||
| } else { | |||
| m_enable_evict &= ~SWAP; | |||
| while (prev.size() > similarity) { | |||
| pop_scope(prev.back()); | |||
| prev.pop_back(); | |||
| } | |||
| } | |||
| void ChannelImpl::set_drop_flag(bool flag) { | |||
| if ((!flag) && (m_enable_evict & DROP)) { | |||
| for (auto handle: m_valid_handle) { | |||
| auto* info = reinterpret_cast<TensorInfo*>(handle); | |||
| if (info->evict_type == DROP) { | |||
| recompute(info->producer); | |||
| } | |||
| } | |||
| } | |||
| if (flag) { | |||
| m_enable_evict |= DROP; | |||
| } else { | |||
| m_enable_evict &= ~DROP; | |||
| while (prev.size() < current.size()) { | |||
| prev.push_back(current[prev.size()]); | |||
| push_scope(prev.back()); | |||
| } | |||
| } | |||
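sync_device_scope reconciles a device's recorded scope stack with the worker's current one: keep the longest common prefix, pop the scopes the device still has open but the worker has left, then push the newly entered ones. An illustrative Python restatement of the same diff:

```python
def sync_scopes(prev, current, push_scope, pop_scope):
    """Bring scope stack `prev` in line with `current` via push/pop events."""
    similarity = 0                       # length of the common prefix
    for a, b in zip(prev, current):
        if a != b:
            break
        similarity += 1
    while len(prev) > similarity:        # close scopes no longer active
        pop_scope(prev.pop())
    while len(prev) < len(current):      # open newly entered scopes
        prev.append(current[len(prev)])
        push_scope(prev[-1])

stack = ["step"]
sync_scopes(stack, ["step", "conv1"], print, print)  # prints "conv1"
```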
| void ChannelImpl::set_buffer_length(int length) { | |||
| m_buffer.set_capacity(length); | |||
| } | |||
| void ChannelImpl::process_one_task(Command& cmd) { | |||
| void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { | |||
| m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd); | |||
| bool finished = false; | |||
| auto do_finish_command = [&]{ | |||
| if (finished) { | |||
| return; | |||
| } | |||
| m_worker_state.profiler->record_host<CommandFinishEvent>(icmd); | |||
| finished = true; | |||
| }; | |||
//TODO: remove std::visit to support macOS 10.12
| std::visit([this](auto& cmd) { | |||
| using T = std::remove_reference_t<decltype(cmd)>; | |||
| try { | |||
| auto cmd_visitor = [&](auto& cmd) { | |||
| using T = std::remove_reference_t<decltype(cmd)>; | |||
| if constexpr (std::is_same_v<T, Put>) { | |||
| auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value); | |||
| produce_tensor(cmd.dest, std::move(value)); | |||
| } else if constexpr (std::is_same_v<T, ApplyOp>) { | |||
| uint64_t apply_id = ++m_last_id; | |||
| SmallVector<TensorPtr> tensor_inputs; | |||
| SmallVector<CompNode> devices; | |||
| tensor_inputs.reserve(cmd.inputs.size()); | |||
| // refcnt == 1, owners: [TensorInfo::ptr] | |||
| for (auto i : cmd.inputs) { | |||
| @@ -453,6 +497,23 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| // refcnt ++, owners: [i->ptr, tensor_inputs] | |||
| tensor_inputs.push_back(i->ptr); | |||
| } | |||
| // Begin profiling operator | |||
| auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { | |||
| SmallVector<uint64_t> tid; | |||
| for (auto* ptinfo: tinfo) { | |||
| tid.push_back(ptinfo->id); | |||
| } | |||
| return tid; | |||
| }; | |||
| OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)}; | |||
| // Collecting devices | |||
| for (auto i : cmd.inputs) { | |||
| devices.push_back(i->desc.comp_node); | |||
| } | |||
| for (auto i : cmd.outputs) { | |||
| devices.push_back(i->desc.comp_node); | |||
| } | |||
| devices.erase(std::unique(devices.begin(), devices.end()), devices.end()); | |||
// Fused by command buffer. @see: CommandBuffer::fuse_del
// Now if dest is inplaceable, its refcnt will be decreased to 1 and owned by tensor_inputs after Del.
// Note: for exprs like 'y = x op x', inplace is unsupported yet, but Del will still be fused.
| @@ -461,9 +522,24 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
// if it drops to 1, it will be detected at @see: proxy_graph_detail::apply_on_physical_tensor
| free(del); | |||
| } | |||
| // Before wait | |||
//TODO: split operator wait and execute so that OpWait can be correctly recorded.
| // Before execute | |||
| m_worker_state.profiler->record_host<HostOpExecuteEvent>(event_data); | |||
| for (auto&& device: devices) { | |||
| sync_device_scope(device); | |||
| m_worker_state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data); | |||
| } | |||
| // Apply op | |||
// Here std::move is REQUIRED to drop the duplicated references.
| auto tensor_outputs = OpDef::apply_on_physical_tensor( | |||
| *cmd.op, std::move(tensor_inputs)); | |||
| // After execute | |||
| m_worker_state.profiler->record_host<HostOpFinishEvent>(event_data); | |||
| for (auto&& device: devices) { | |||
| m_worker_state.profiler->record_device<DeviceOpFinishEvent>(device, event_data); | |||
| } | |||
| // End profiling operator | |||
| mgb_assert(tensor_outputs.size() == cmd.outputs.size()); | |||
| for (size_t i = 0; i < tensor_outputs.size(); ++i) { | |||
| if (cmd.outputs[i] == nullptr) { | |||
| @@ -488,13 +564,51 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| release_tensor(cmd.dest); | |||
| } else if constexpr (std::is_same_v<T, Drop>) { | |||
| release_tensor(cmd.dest); | |||
| } else if constexpr (std::is_same_v<T, Move>) { | |||
| produce_tensor(cmd.dest, cmd.src->ptr); | |||
| free(cmd.src); | |||
| } else if constexpr (std::is_same_v<T, SetOption>) { | |||
| m_worker_state.options.set_option(cmd.key, cmd.value); | |||
| } else if constexpr (std::is_same_v<T, StartProfile>) { | |||
| CompNode::sync_all(); | |||
| m_worker_state.profiler.reset(cmd.profiler); | |||
| } else if constexpr (std::is_same_v<T, StopProfile>) { | |||
| for (auto&& [device, scopes]: m_worker_state.device_scope_map) { | |||
| MGB_MARK_USED_VAR(scopes); | |||
| sync_device_scope(device); | |||
| } | |||
| do_finish_command(); | |||
| auto profiler = std::make_unique<InterpreterProfiler>(); | |||
| std::swap(profiler, m_worker_state.profiler); | |||
| auto records = profiler->stop(); | |||
| auto host_map = [this](std::thread::id tid) { | |||
| if (tid == m_channel_state.tid) { | |||
| return "channel"; | |||
| } else if (tid == m_worker_state.tid) { | |||
| return "worker"; | |||
| } else { | |||
| return "unknown"; | |||
| } | |||
| }; | |||
| InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map); | |||
| } else if constexpr (std::is_same_v<T, PushScope>) { | |||
| m_worker_state.scopes.push_back(cmd.scope_name); | |||
| do_finish_command(); | |||
| m_worker_state.profiler->record_host<WorkerBeginScope>(cmd.scope_name); | |||
| } else if constexpr (std::is_same_v<T, PopScope>) { | |||
| mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch"); | |||
| m_worker_state.scopes.pop_back(); | |||
| do_finish_command(); | |||
| m_worker_state.profiler->record_host<WorkerEndScope>(cmd.scope_name); | |||
| } else { | |||
| static_assert(std::is_same_v<T, Flush> || | |||
| std::is_same_v<T, Nop>); | |||
| static_assert(std::is_same_v<T, T>); | |||
| } | |||
| }; | |||
| std::visit([&](auto& cmd){ | |||
| using T = std::decay_t<decltype(cmd)>; | |||
| if (!m_worker_state.options.catch_worker_execption) { | |||
| cmd_visitor(cmd); | |||
| return; | |||
| } | |||
| try { | |||
| cmd_visitor(cmd); | |||
| } catch (...) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| if constexpr (std::is_same_v<T, ApplyOp>) { | |||
| @@ -507,7 +621,8 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| m_worker_exc = std::current_exception(); | |||
| m_cv.notify_all(); | |||
| } | |||
| }, cmd); | |||
| }, icmd.second); | |||
| do_finish_command(); | |||
| } | |||
| void ChannelImpl::check_worker_exc_unsafe() { | |||
| @@ -524,18 +639,22 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) { | |||
| if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) { | |||
| return; | |||
| } | |||
| auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, cmd); | |||
| mgb_log_debug("%s Enqueued", command_repr.c_str()); | |||
| mgb_log_debug("%s Enqueued", to_string(cmd).c_str()); | |||
| m_commands.push_back(std::move(cmd)); | |||
| auto flush_pos = flush_pos_for(m_commands.back()); | |||
| flush(flush_pos); | |||
| } | |||
| void ChannelImpl::CommandBuffer::flush() { | |||
| flush(m_commands.end()); | |||
| } | |||
| void ChannelImpl::CommandBuffer::flush(Handle pos) { | |||
| for (auto iter = m_commands.begin(); iter != pos; ++iter) { | |||
| auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, *iter); | |||
| mgb_log_debug("%s Flushed", command_repr.c_str()); | |||
| m_owner->m_worker.add_task(std::move(*iter)); | |||
| mgb_log_debug("%s Flushed", to_string(*iter).c_str()); | |||
| IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)}; | |||
| m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd); | |||
| m_owner->m_worker.add_task(std::move(icmd)); | |||
| } | |||
| m_commands.erase(m_commands.begin(), pos); | |||
| } | |||
| @@ -555,17 +674,10 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle { | |||
| } | |||
| } else if constexpr (std::is_same_v<T, GetValue>) { | |||
| return m_commands.end(); | |||
| } else if constexpr (std::is_same_v<T, Flush>) { | |||
| if (cmd.dest == nullptr) { | |||
| return m_commands.end(); | |||
| } | |||
| auto produce_iter = find_produce(cmd.dest, {m_commands.begin(), m_commands.end()}); | |||
| if (produce_iter != m_commands.end()) { | |||
| return produce_iter + 1; | |||
| } | |||
| } | |||
| if (m_commands.size() > m_capacity) { | |||
| return m_commands.begin() + (m_commands.size() - m_capacity); | |||
| size_t buffer_length = m_owner->m_channel_state.options.buffer_length; | |||
| if (m_commands.size() > buffer_length) { | |||
| return m_commands.begin() + (m_commands.size() - buffer_length); | |||
| } | |||
| return m_commands.begin(); | |||
| }, cmd); | |||
| @@ -589,7 +701,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) { | |||
| if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) { | |||
| return false; | |||
| } | |||
| mgb_log_debug("%s Fused", cmd.to_string().c_str()); | |||
| mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str()); | |||
| std::get<ApplyOp>(*apply_iter).dels.push_back(dest); | |||
| return true; | |||
| } | |||
| @@ -636,3 +748,41 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) | |||
| }, cmd); | |||
| }); | |||
| } | |||
| void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) { | |||
| auto profiler_option = InterpreterProfiler::Option::from_dict(option); | |||
| auto profiler = std::make_unique<InterpreterProfiler>(); | |||
| profiler->set_option(profiler_option); | |||
| profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic)); | |||
| std::swap(profiler, m_channel_state.profiler); | |||
| m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()}); | |||
| } | |||
| void ChannelImpl::stop_profile(std::string basename, std::string format) { | |||
| m_buffer.flush(); | |||
| auto profiler = std::make_unique<InterpreterProfiler>(); | |||
| std::swap(profiler, m_channel_state.profiler); | |||
| profiler.release(); | |||
| m_buffer.enqueue(StopProfile{basename, format}); | |||
| } | |||
| void ChannelImpl::push_scope(std::string name) { | |||
| m_channel_state.profiler->record_host<ChannelBeginScope>(name); | |||
| m_channel_state.scopes.push_back(name); | |||
| m_buffer.enqueue(PushScope{name}); | |||
| } | |||
| void ChannelImpl::pop_scope(std::string name) { | |||
| mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch"); | |||
| m_channel_state.scopes.pop_back(); | |||
| m_channel_state.profiler->record_host<ChannelEndScope>(name); | |||
| m_buffer.enqueue(PopScope{name}); | |||
| } | |||
| void ChannelImpl::assert_in_channel() { | |||
| mgb_assert(m_channel_state.tid != std::this_thread::get_id()); | |||
| } | |||
| void ChannelImpl::assert_in_worker() { | |||
| mgb_assert(m_worker_state.tid == std::this_thread::get_id()); | |||
| } | |||
| @@ -0,0 +1,205 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/interpreter_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <deque> | |||
| #include <future> | |||
| #include <list> | |||
| #include <thread> | |||
| #include <unordered_set> | |||
| #include <variant> | |||
| #include "megbrain/utils/mempool.h" | |||
| #include "megbrain/imperative/interpreter.h" | |||
| #include "megbrain/imperative/profiler.h" | |||
| #include "./commands.h" | |||
| #include "./events.h" | |||
| #include "./tensor_info.h" | |||
| #include "./option_manager.h" | |||
| #include "./profiler.h" | |||
| namespace mgb::imperative::interpreter::intl { | |||
| using Handle = Interpreter::Handle; | |||
| struct InterpreterImpl : Interpreter { | |||
| std::unique_ptr<Channel> create_channel() override; | |||
| }; | |||
| struct ChannelImpl : Interpreter::Channel { | |||
| ChannelImpl(); | |||
| ~ChannelImpl() override; | |||
| Handle put(const HostTensorND& value, bool no_cache) override; | |||
| Handle put(const DeviceTensorND& value) override; | |||
| void del(Handle) override; | |||
| void swap_in(Handle) override; | |||
| void swap_out(Handle) override; | |||
| void drop(Handle) override; | |||
| SmallVector<Handle> apply_op( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<Handle>& inputs) override; | |||
| HostTensorND get_value(Handle) override; | |||
| TensorShape get_shape(Handle) override; | |||
| DType get_dtype(Handle) override; | |||
| CompNode get_device(Handle) override; | |||
| DeviceTensorND get_dev_tensor(Handle) override; | |||
| void sync() override; | |||
| void close() override; | |||
| int get_option(std::string name) override; | |||
| void set_option(std::string name, int value) override; | |||
| void start_profile(std::unordered_map<std::string, int> option) override; | |||
| void stop_profile(std::string basename, std::string format) override; | |||
| void push_scope(std::string) override; | |||
| void pop_scope(std::string) override; | |||
| private: | |||
| TensorInfo* alloc(); | |||
| void free(TensorInfo*); | |||
| void detach_users(TensorInfo*); | |||
| void process_one_task(IdentifiedCommand&); | |||
| void check_worker_exc_unsafe(); | |||
| void produce_tensor(TensorInfo* dest, TensorPtr ptr); | |||
| void release_tensor(TensorInfo* dest); | |||
| void regenerate(TensorInfo* dest); | |||
| void recompute(TensorInfo::ComputePath* path); | |||
| void dispatch_default_cpu( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs); | |||
| void dispatch_kernel( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs); | |||
| void assert_in_channel(); | |||
| void assert_in_worker(); | |||
| void sync_device_scope(CompNode device); | |||
| template <typename TCommand> | |||
| void enqueue_command(TCommand&& cmd) { | |||
| m_buffer.enqueue(Command{std::forward<TCommand>(cmd)}); | |||
| } | |||
| std::mutex m_mutex; | |||
| std::condition_variable m_cv; | |||
| MemPool<TensorInfo> m_pool; | |||
| std::unordered_set<Handle> m_valid_handle; | |||
| TensorInfo* m_waitee = nullptr; | |||
| std::exception_ptr m_worker_exc; | |||
| std::atomic_uint64_t m_last_id = 0; | |||
| struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> { | |||
// set max_spin=0 to prevent the queue from fetching tasks in a busy-wait manner.
// This won't affect throughput when the Python interpreter is sending enough tasks,
// but will significantly save CPU time when waiting for a task, e.g. waiting for data input
| WorkQueue(ChannelImpl* owner) | |||
| : AsyncQueueSC<IdentifiedCommand, WorkQueue>(0), m_owner(owner) { | |||
| sys::set_thread_name("interpreter"); | |||
| } | |||
| void process_one_task(IdentifiedCommand& icmd) { | |||
| m_owner->process_one_task(icmd); | |||
| } | |||
| void on_async_queue_worker_thread_start() override { | |||
| sys::set_thread_name("worker"); | |||
| m_owner->m_worker_state.tid = std::this_thread::get_id(); | |||
| } | |||
| private: | |||
| ChannelImpl* m_owner; | |||
| } m_worker; | |||
| /** | |||
| * Buffer a command window for subsequent fusion | |||
| * example: | |||
| * --------------------------------------------------------------------- | |||
| * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} | | |||
| * --------------------------------------------------------------------- | |||
| * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} | | |||
| * --------------------------------------------------------------------- | |||
| * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... | | |||
| * --------------------------------------------------------------------- | |||
| * Then the fused Apply may be invoked in place; see ChannelImpl::process_one_task | |||
| */ | |||
| struct CommandBuffer { | |||
| CommandBuffer(ChannelImpl* owner) : m_owner(owner) {} | |||
| void enqueue(Command cmd); | |||
| bool empty() const { | |||
| return m_commands.empty(); | |||
| } | |||
| void flush(); | |||
| private: | |||
| ChannelImpl* m_owner; | |||
| std::deque<Command> m_commands; | |||
| using Handle = decltype(m_commands)::iterator; | |||
| // [begin, end) | |||
| using Range = std::array<Handle, 2>; | |||
| // Launch commands in range [m_commands.begin(), pos) | |||
| void flush(Handle pos); | |||
| // Select flush position for incoming cmd | |||
| Handle flush_pos_for(const Command& cmd); | |||
| // Fuse del command into suitable ApplyOp | |||
| bool fuse_del(const Del& cmd); | |||
| // Returns the last handle that dest is used within range. If dest is not used, returns range[1] | |||
| Handle find_last_usage(TensorInfo* dest, Range range); | |||
| // Returns the produce position of dest. If not found, returns range[1] | |||
| Handle find_produce(TensorInfo* dest, Range range); | |||
| } m_buffer; | |||
| //! configures whether errors are raised exactly when an op is invoked. | |||
| //! level 2: both device and user side errors are async; | |||
| //! level 1: user side errors are sync; | |||
| //! level 0: both sync. | |||
| int m_async_level = 2; | |||
| int m_max_recompute_time = 1; | |||
| struct State { | |||
| std::thread::id tid; | |||
| OptionManager options; | |||
| std::vector<std::string> scopes; | |||
| std::unique_ptr<InterpreterProfiler> profiler; | |||
| State() { | |||
| profiler = std::make_unique<InterpreterProfiler>(); | |||
| } | |||
| }; | |||
| struct ChannelState: State {}; | |||
| struct WorkerState: State { | |||
| CompNode::UnorderedMap<std::vector<std::string>> device_scope_map; | |||
| }; | |||
| ChannelState m_channel_state; | |||
| WorkerState m_worker_state; | |||
| }; | |||
| } // namespace mgb::imperative::interpreter::intl | |||
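To make the Del-fusion described in the CommandBuffer comment concrete, here is a minimal standalone sketch (toy ApplyOp/Del/CommandBuffer types with int tensor ids, not the real MegEngine classes): a trailing Del is folded into the ApplyOp that consumes the tensor, so the fused apply can later run in place.

    #include <algorithm>
    #include <deque>
    #include <iostream>
    #include <variant>
    #include <vector>

    struct ApplyOp {
        std::vector<int> inputs;   // tensor ids consumed by the op
        std::vector<int> outputs;
        std::vector<int> dels;     // tensors freeable right after execution
    };
    struct Del { int dest; };
    using Command = std::variant<ApplyOp, Del>;

    struct CommandBuffer {
        std::deque<Command> commands;
        // scan the window backwards for an ApplyOp that consumes the tensor
        bool fuse_del(const Del& del) {
            for (auto it = commands.rbegin(); it != commands.rend(); ++it) {
                if (auto* apply = std::get_if<ApplyOp>(&*it)) {
                    auto& in = apply->inputs;
                    if (std::find(in.begin(), in.end(), del.dest) != in.end()) {
                        apply->dels.push_back(del.dest);  // free in place later
                        return true;
                    }
                }
            }
            return false;
        }
        void enqueue(Command cmd) {
            if (auto* del = std::get_if<Del>(&cmd); del && fuse_del(*del))
                return;  // fused away; no standalone Del command needed
            commands.push_back(std::move(cmd));
        }
    };

    int main() {
        CommandBuffer buf;
        buf.enqueue(ApplyOp{{0, 1}, {2, 3}, {}});
        buf.enqueue(Del{0});
        buf.enqueue(Del{1});
        std::cout << buf.commands.size() << "\n";  // 1: both Dels were fused
    }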
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/option_manager.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "megbrain/common.h" | |||
| namespace mgb::imperative::interpreter::intl { | |||
| struct OptionManager { | |||
| private: | |||
| std::unordered_map<std::string, int*> m_option_map = {}; | |||
| public: | |||
| #define DEF_OPTION(name, env_key, default_value, desc) \ | |||
| int name = (m_option_map[#name]=&name, get_option_from_env(env_key, default_value)); | |||
| DEF_OPTION(async_level, "MEGENGINE_INTERP_ASYNC_LEVEL", 2, | |||
| "config whether raise error exactly when invoking op.\n" | |||
| "level 2: both device and user side errors are async;\n" | |||
| "level 1: user side errors are sync;\n" | |||
| "level 0: both sync."); | |||
| DEF_OPTION(enable_swap, "MEGENGINE_ENABLE_SWAP", 0, ""); | |||
| DEF_OPTION(enable_drop, "MEGENGINE_ENABLE_DROP", 0, ""); | |||
| DEF_OPTION(max_recompute_time, "MEGENGINE_MAX_RECOMP_TIME", 1, ""); | |||
| DEF_OPTION(catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1, | |||
| "catch worker exception if enabled, close it when debugging"); | |||
| DEF_OPTION(buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3, | |||
| "set command buffer length."); | |||
| DEF_OPTION(enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, | |||
| "enable host compute, thus computation may be done in host event if it's device is gpu."); | |||
| #undef DEF_OPTION | |||
| void set_option(const std::string& name, int value) { | |||
| *m_option_map[name] = value; | |||
| } | |||
| int get_option(const std::string& name) const { | |||
| return *m_option_map.at(name); | |||
| } | |||
| static int get_option_from_env(const std::string& name, int default_value) { | |||
| if (const char* env_val = MGB_GETENV(name.c_str())) { | |||
| default_value = std::atoi(env_val); | |||
| } | |||
| return default_value; | |||
| } | |||
| }; | |||
| } | |||
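The DEF_OPTION macro above leans on the comma operator: each member's initializer first registers the member's address in m_option_map, then evaluates to the env-derived default, so set_option/get_option can address options by string. A minimal self-contained sketch of the same trick (MiniOptionManager and the MINI_* env keys are illustrative names, not real MegEngine ones):

    #include <cstdlib>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    struct MiniOptionManager {
        std::unordered_map<std::string, int*> option_map;

        static int from_env(const char* key, int fallback) {
            if (const char* v = std::getenv(key)) return std::atoi(v);
            return fallback;
        }
        // comma operator: register the member's address, then yield the value
    #define MINI_OPTION(name, env_key, def) \
        int name = (option_map[#name] = &name, from_env(env_key, def));
        MINI_OPTION(async_level, "MINI_ASYNC_LEVEL", 2)
        MINI_OPTION(buffer_length, "MINI_BUFFER_LENGTH", 3)
    #undef MINI_OPTION

        void set(const std::string& name, int value) { *option_map.at(name) = value; }
        int get(const std::string& name) const { return *option_map.at(name); }
    };

    int main() {
        MiniOptionManager opts;
        opts.set("async_level", 0);                     // string-addressed write...
        std::cout << opts.async_level << "\n";          // ...updates the member: 0
        std::cout << opts.get("buffer_length") << "\n"; // 3 unless env overrides
    }

The real OptionManager adds per-option descriptions and MEGENGINE_* environment keys, but the registration mechanics mirror the shown macro.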
| @@ -0,0 +1,280 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/profiler.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./profiler.h" | |||
| #include <sstream> | |||
| #include <cinttypes> | |||
| #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) | |||
| #include <unistd.h> | |||
| #elif defined(_WIN32) | |||
| #include <process.h> | |||
| #else | |||
| #error Unsupported platform | |||
| #endif | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative::interpreter::intl { | |||
| namespace { | |||
| struct InterpreterProfilerDumpChromeTimelineContext { | |||
| // either host_thread(std::thread::id) or device_thread(CompNode) | |||
| using Thread = std::variant<std::thread::id, CompNode>; | |||
| // input params | |||
| std::string base_name; | |||
| std::string format; | |||
| InterpreterProfiler::Data profile_data; | |||
| InterpreterProfiler::Option option; | |||
| std::function<std::string(std::thread::id)> host_map; | |||
| // internal states | |||
| decltype(getpid()) pid; | |||
| CompNode::UnorderedMap<std::map<double, CompNode::Event*>> device_sync_map; | |||
| SmallVector<Thread> thread_list; | |||
| double time_start; | |||
| // options | |||
| bool show_operator_name; | |||
| // results | |||
| ChromeTraceEventList event_list; | |||
| InterpreterProfilerDumpChromeTimelineContext( | |||
| std::string base_name, | |||
| std::string format, | |||
| InterpreterProfiler::Data profile_data, | |||
| InterpreterProfiler::Option option, | |||
| std::function<std::string(std::thread::id)> host_map) | |||
| : base_name{base_name}, format{format}, profile_data{profile_data}, option{option}, host_map{host_map} { | |||
| pid = getpid(); | |||
| time_start = option.align_time ? profile_data.started_at : 0; | |||
| show_operator_name = option.show_operator_name; | |||
| } | |||
| // get device time from event | |||
| double get_device_time(CompNode::Event* device_event, double host_time) { | |||
| device_event->host_wait(); | |||
| auto& sync_map = device_sync_map[device_event->comp_node()]; | |||
| // find sync point | |||
| auto iter = sync_map.begin(); | |||
| auto sync_current = [&] { | |||
| iter = sync_map.insert(iter, {host_time, device_event}); | |||
| return host_time; | |||
| }; | |||
| if (iter == sync_map.end()) { | |||
| // not found, insert sync | |||
| return sync_current(); | |||
| } | |||
| auto& [base_time, base] = *iter; | |||
| // calculate elapsed time | |||
| double delta_time = base->elapsed_time_until(*device_event) * 1e3; | |||
| return base_time + delta_time; | |||
| }; | |||
| template <typename T> | |||
| size_t get_tid(T t) { | |||
| for (size_t i = 0; i < thread_list.size(); i++) { | |||
| if (thread_list[i] == Thread{t}) { | |||
| return i; | |||
| } | |||
| } | |||
| thread_list.push_back(t); | |||
| return thread_list.size() - 1; | |||
| }; | |||
| ChromeTraceEvent& new_event(std::string name, char ph, uint64_t tid, double ts) { | |||
| return event_list.new_event().name(name).ph(ph).tid(tid).ts(ts).pid(pid); | |||
| }; | |||
| // converts a Command to a json object. Has to be a callable object | |||
| static auto constexpr cmd_to_args = [](auto&& cmd) { | |||
| auto args = json::Object::make(); | |||
| cmd.get_props([&](const char* key, auto&& value){ | |||
| (*args)[key] = json::String::make(to_string(value)); | |||
| }); | |||
| (*args)["__type__"] = json::String::make(typeid(cmd).name()); | |||
| return args; | |||
| }; | |||
| void process() { | |||
| // enumerate and process each record | |||
| for (auto&& record: profile_data.records) { | |||
| std::visit([this](auto& record){ | |||
| using TEvent = std::decay_t<decltype(record.data)>; | |||
| Session<TEvent>(*this, record).process(); | |||
| }, record); | |||
| } | |||
| for (size_t tid = 0; tid < thread_list.size(); ++tid) { | |||
| auto tname = std::visit([&](auto& host_or_device) -> std::string{ | |||
| using T = std::decay_t<decltype(host_or_device)>; | |||
| if constexpr (std::is_same_v<T, std::thread::id>) { | |||
| // take name from host_map | |||
| return host_map(host_or_device); | |||
| } else { | |||
| // use CompNode::to_string | |||
| return host_or_device.to_string(); | |||
| } | |||
| }, thread_list[tid]); | |||
| // assign thread name | |||
| new_event("thread_name", 'M', tid, 0) | |||
| .arg("name", tname); | |||
| } | |||
| // write output to file | |||
| std::string out_buf; | |||
| event_list.to_json()->writeto(out_buf, 4); | |||
| std::ofstream output_stream; | |||
| output_stream.open(base_name + "." + format); | |||
| output_stream << out_buf; | |||
| output_stream.flush(); | |||
| output_stream.close(); | |||
| } | |||
| template <typename TEvent> | |||
| struct Session { | |||
| InterpreterProfilerDumpChromeTimelineContext& ctx; | |||
| ProfilerBase::EventRecord<TEvent>& record; | |||
| TEvent& data; | |||
| Session(InterpreterProfilerDumpChromeTimelineContext& ctx, | |||
| ProfilerBase::EventRecord<TEvent>& record) | |||
| : ctx{ctx}, record{record}, data{record.data} {} | |||
| uint64_t get_host_tid() { | |||
| return ctx.get_tid(record.host().tid); | |||
| }; | |||
| double get_host_ts() { | |||
| return (ctx.time_start + record.host().time) * 1e3; | |||
| }; | |||
| uint64_t get_device_tid() { | |||
| return ctx.get_tid(record.device().event->comp_node()); | |||
| }; | |||
| double get_device_ts() { | |||
| return (ctx.time_start + ctx.get_device_time(record.device().event.get(), record.device().after)) * 1e3; | |||
| }; | |||
| ChromeTraceEvent& new_host_event(std::string name, char ph) { | |||
| return ctx.new_event(std::move(name), ph, get_host_tid(), get_host_ts()); | |||
| }; | |||
| ChromeTraceEvent& new_device_event(std::string name, char ph) { | |||
| return ctx.new_event(std::move(name), ph, get_device_tid(), get_device_ts()); | |||
| }; | |||
| void process() { | |||
| // dispatch event by type | |||
| if constexpr (std::is_same_v<TEvent, CommandEnqueueEvent>) { | |||
| auto args = std::visit(cmd_to_args, data.icmd.second); | |||
| new_host_event("CommandEnqueue", 'X').dur(0).args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, CommandExecuteEvent>) { | |||
| auto args = std::visit(cmd_to_args, data.icmd.second); | |||
| new_host_event("CommandExecute", 'B').args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, CommandFinishEvent>) { | |||
| new_host_event("CommandExecute", 'E'); | |||
| } else if constexpr (std::is_same_v<TEvent, HostOpExecuteEvent>) { | |||
| auto args = json::Object::make(); | |||
| auto props = OpDef::props(*data.op); | |||
| auto name = data.op->trait()->name; | |||
| for (auto&& [prop_name, prop_val]: props) { | |||
| (*args)[std::string("op.") + prop_name] = json::String::make(prop_val); | |||
| } | |||
| (*args)["name"] = json::String::make(name); | |||
| (*args)["id"] = json::Number::make(data.id); | |||
| (*args)["inputs"] = json::String::make(to_string(data.inputs)); | |||
| (*args)["outputs"] = json::String::make(to_string(data.outputs)); | |||
| new_host_event(ctx.show_operator_name ? name : "OpExecute", 'B').args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, DeviceOpExecuteEvent>) { | |||
| auto args = json::Object::make(); | |||
| auto props = OpDef::props(*data.op); | |||
| auto name = data.op->trait()->name; | |||
| for (auto&& [prop_name, prop_val]: props) { | |||
| (*args)[std::string("op.") + prop_name] = json::String::make(prop_val); | |||
| } | |||
| (*args)["name"] = json::String::make(name); | |||
| (*args)["id"] = json::Number::make(data.id); | |||
| (*args)["inputs"] = json::String::make(to_string(data.inputs)); | |||
| (*args)["outputs"] = json::String::make(to_string(data.outputs)); | |||
| new_device_event(ctx.show_operator_name ? name : "OpExecute", 'B').args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, HostOpFinishEvent>) { | |||
| auto name = data.op->trait()->name; | |||
| new_host_event(ctx.show_operator_name ? name : "OpExecute", 'E'); | |||
| } else if constexpr (std::is_same_v<TEvent, DeviceOpFinishEvent>) { | |||
| auto name = data.op->trait()->name; | |||
| new_device_event(ctx.show_operator_name ? name : "OpExecute", 'E'); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorDeclareEvent>) { | |||
| new_host_event("TensorLifetime", 'N').id(data.tensor_id); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorProduceEvent>) { | |||
| auto snapshot = json::Object::make(); | |||
| (*snapshot)["shape"] = json::String::make(to_string((TensorShape)data.layout)); | |||
| (*snapshot)["dtype"] = json::String::make(to_string(data.layout.dtype)); | |||
| (*snapshot)["device"] = json::String::make(to_string(data.device)); | |||
| new_host_event("TensorLifetime", 'O').id(data.tensor_id).arg("snapshot", snapshot); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorEraseEvent>) { | |||
| new_host_event("TensorLifetime", 'D').id(data.tensor_id); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) { | |||
| auto args = json::Object::make(); | |||
| (*args)["id"] = json::Number::make(data.tensor_id); | |||
| (*args)["prop"] = json::String::make(to_string(data.prop)); | |||
| (*args)["prop_desc"] = json::String::make(data.prop_desc); | |||
| new_host_event("TensorGetProp", 'X').dur(0).args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorNotifyPropEvent>) { | |||
| // TODO | |||
| } else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) { | |||
| auto args = json::Object::make(); | |||
| (*args)["id"] = json::Number::make(data.tensor_id); | |||
| (*args)["prop"] = json::String::make(to_string(data.prop)); | |||
| (*args)["prop_desc"] = json::String::make(data.prop_desc); | |||
| new_host_event("TensorWaitProp", 'B').args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) { | |||
| auto args = json::Object::make(); | |||
| (*args)["id"] = json::Number::make(data.tensor_id); | |||
| (*args)["prop"] = json::String::make(to_string(data.prop)); | |||
| (*args)["prop_desc"] = json::String::make(data.prop_desc); | |||
| new_host_event("TensorWaitProp", 'E').args(args); | |||
| } else if constexpr (std::is_same_v<TEvent, SyncStartEvent>) { | |||
| new_host_event("SyncEvent", 'B'); | |||
| } else if constexpr (std::is_same_v<TEvent, SyncFinishEvent>) { | |||
| new_host_event("SyncEvent", 'E'); | |||
| } else if constexpr (std::is_same_v<TEvent, ChannelBeginScope>) { | |||
| new_host_event(data.name, 'B'); | |||
| } else if constexpr (std::is_same_v<TEvent, ChannelEndScope>) { | |||
| new_host_event(data.name, 'E'); | |||
| } else if constexpr (std::is_same_v<TEvent, WorkerBeginScope>) { | |||
| new_host_event(data.name, 'B'); | |||
| } else if constexpr (std::is_same_v<TEvent, WorkerEndScope>) { | |||
| new_host_event(data.name, 'E'); | |||
| } else if constexpr (std::is_same_v<TEvent, DeviceBeginScope>) { | |||
| new_device_event(data.name, 'B'); | |||
| } else if constexpr (std::is_same_v<TEvent, DeviceEndScope>) { | |||
| new_device_event(data.name, 'E'); | |||
| } else { | |||
| static_assert(!std::is_same_v<TEvent, TEvent>); | |||
| } | |||
| } | |||
| }; | |||
| }; | |||
| } | |||
| void InterpreterProfiler::dump_data( | |||
| std::string basename, | |||
| std::string format, | |||
| InterpreterProfiler::Data profile_data, | |||
| const InterpreterProfiler::Option& option, | |||
| std::function<std::string(std::thread::id)> host_map) { | |||
| InterpreterProfilerDumpChromeTimelineContext{ | |||
| basename, format, profile_data, option, host_map | |||
| }.process(); | |||
| } | |||
| } | |||
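The timestamp arithmetic in get_device_time above boils down to: pin one (host time, device event) sync point per device, then project later device events onto the host timeline by adding the device-measured elapsed time to the pinned host time. A standalone toy of that projection, where Event is a stub carrying a plain device-clock reading instead of a CompNode::Event:

    #include <iostream>
    #include <map>
    #include <utility>

    // stub standing in for CompNode::Event: only knows its device-clock time (ms)
    struct Event { double device_ms; };

    struct TimelineAligner {
        // one sync point per device: host time of the first observed event
        std::map<int, std::pair<double, Event>> sync_point; // device id -> (host ms, event)

        double to_host_time(int device, const Event& ev, double host_now_ms) {
            auto it = sync_point.find(device);
            if (it == sync_point.end()) {
                // first event on this device becomes the sync point itself
                sync_point.emplace(device, std::make_pair(host_now_ms, ev));
                return host_now_ms;
            }
            auto& [base_host, base_ev] = it->second;
            double elapsed = ev.device_ms - base_ev.device_ms; // device-side elapsed
            return base_host + elapsed;                        // projected host time
        }
    };

    int main() {
        TimelineAligner aligner;
        // sync point: device clock 10.0 observed when host clock reads 1000.0
        std::cout << aligner.to_host_time(0, {10.0}, 1000.0) << "\n"; // 1000
        // a later event at device clock 12.5 lands at host 1002.5
        std::cout << aligner.to_host_time(0, {12.5}, 1003.9) << "\n"; // 1002.5
    }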
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/profiler.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/profiler.h" | |||
| #include "./commands.h" | |||
| #include "./events.h" | |||
| #include "./option_manager.h" | |||
| namespace mgb::imperative::interpreter::intl { | |||
| class InterpreterProfiler: public Profiler< | |||
| CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent, | |||
| HostOpExecuteEvent, HostOpFinishEvent, | |||
| DeviceOpExecuteEvent, DeviceOpFinishEvent, | |||
| TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent, | |||
| TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent, | |||
| SyncStartEvent, SyncFinishEvent, | |||
| ChannelBeginScope, ChannelEndScope, | |||
| WorkerBeginScope, WorkerEndScope, | |||
| DeviceBeginScope, DeviceEndScope> { | |||
| /* 22 events for now; an enum code might be a better solution */ | |||
| public: | |||
| enum Topic { | |||
| Command = 0b000001, | |||
| Operator = 0b000010, | |||
| TensorLifetime = 0b000100, | |||
| TensorProp = 0b001000, | |||
| Sync = 0b010000, | |||
| Scope = 0b100000, | |||
| }; | |||
| struct Option { | |||
| Topic topic; | |||
| bool align_time; | |||
| bool show_operator_name; | |||
| static Option from_dict(std::unordered_map<std::string, int> dict) { | |||
| Option option; | |||
| option.topic = Topic(dict.at("topic")); | |||
| option.align_time = bool(dict.at("align_time")); | |||
| option.show_operator_name = bool(dict.at("show_operator_name")); | |||
| return option; | |||
| } | |||
| }; | |||
| Option get_option() const { | |||
| return m_option; | |||
| } | |||
| void set_option(const Option& option) { | |||
| m_option = option; | |||
| } | |||
| static void dump_data(std::string basename, std::string format, InterpreterProfiler::Data profile_data, const Option& option, std::function<std::string(std::thread::id)> host_map); | |||
| static Mask topic_to_mask(Topic topic) { | |||
| Mask result; | |||
| if (topic & Command) { | |||
| result |= mask_of<CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent>(); | |||
| } | |||
| if (topic & Operator) { | |||
| result |= mask_of<HostOpExecuteEvent, HostOpFinishEvent>(); | |||
| result |= mask_of<DeviceOpExecuteEvent, DeviceOpFinishEvent>(); | |||
| } | |||
| if (topic & TensorLifetime) { | |||
| result |= mask_of<TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent>(); | |||
| } | |||
| if (topic & TensorProp) { | |||
| result |= mask_of<TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent>(); | |||
| } | |||
| if (topic & Sync) { | |||
| result |= mask_of<SyncStartEvent, SyncFinishEvent>(); | |||
| } | |||
| if (topic & Scope) { | |||
| result |= mask_of<ChannelBeginScope, ChannelEndScope, WorkerBeginScope, WorkerEndScope>(); | |||
| result |= mask_of<DeviceBeginScope, DeviceEndScope>(); | |||
| } | |||
| return result; | |||
| } | |||
| private: | |||
| Option m_option; | |||
| }; | |||
| } | |||
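Usage-wise, Topic is a bit-flag set: callers OR topics together (as the start_profile option dict does) and topic_to_mask folds each selected topic into a bitset of enabled event types. A toy sketch with one bit per topic rather than per event type, values illustrative only:

    #include <bitset>
    #include <iostream>

    enum Topic {
        Command        = 0b000001,
        Operator       = 0b000010,
        TensorLifetime = 0b000100,
        TensorProp     = 0b001000,
        Sync           = 0b010000,
        Scope          = 0b100000,
    };

    using Mask = std::bitset<6>; // toy: one bit per topic instead of per event type

    Mask topic_to_mask(int topic) {
        Mask result;
        for (int bit = 0; bit < 6; ++bit)
            if (topic & (1 << bit))
                result.set(bit);
        return result;
    }

    int main() {
        // select operator and scope events, as a profile option dict would
        int topic = Operator | Scope;
        std::cout << topic_to_mask(topic) << "\n"; // 100010
    }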
| @@ -0,0 +1,135 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter/tensor_info.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| namespace mgb::imperative { | |||
| namespace interpreter::intl { | |||
| enum EvictType { | |||
| NONE = 0, | |||
| SWAP = 1, | |||
| DROP = 2, | |||
| }; | |||
| struct TensorInfo; | |||
| using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||
| struct TensorInfo { | |||
| enum Prop { | |||
| Device, Shape, DType, DevValue, HostValue | |||
| }; | |||
| uint64_t id; | |||
| TensorPtr ptr; | |||
| LogicalTensorDesc desc; | |||
| // FIXME: broken by drop | |||
| bool value_fetched = false; | |||
| bool invalid = false; | |||
| bool allow_delete = false; | |||
| EvictType evict_type = NONE; | |||
| HostTensorND h_value; | |||
| // reserved for auto drop | |||
| size_t pinned = 0; | |||
| size_t recompute_times = 0; | |||
| struct ComputePath { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<TensorInfo*> inputs; | |||
| SmallVector<TensorInfo*> unique_inputs; | |||
| SmallVector<TensorInfo*> outputs; | |||
| size_t ref_cnt() { | |||
| return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr); | |||
| } | |||
| static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) { | |||
| auto* path = new TensorInfo::ComputePath(); | |||
| path->op = op; | |||
| path->inputs = inputs; | |||
| path->outputs = outputs; | |||
| // dedup | |||
| SmallVector<TensorInfo*> unique_inputs = inputs; | |||
| std::sort(unique_inputs.begin(), unique_inputs.end()); | |||
| unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end()); | |||
| path->unique_inputs = unique_inputs; | |||
| // attach users | |||
| for (auto input: unique_inputs) { | |||
| input->users.push_back(path); | |||
| } | |||
| // attach producer | |||
| for (auto output: outputs) { | |||
| output->producer = path; | |||
| } | |||
| return path; | |||
| } | |||
| }* producer = nullptr; | |||
| void pin() { | |||
| ++pinned; | |||
| } | |||
| void unpin() { | |||
| --pinned; | |||
| } | |||
| void detach_producer() { | |||
| if (!producer) { | |||
| return; | |||
| } | |||
| auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this); | |||
| mgb_assert(output != producer->outputs.end()); | |||
| *output = nullptr; | |||
| if (producer->ref_cnt() == 0) { | |||
| for (auto* input: producer->unique_inputs) { | |||
| input->users.erase(std::find(input->users.begin(), input->users.end(), producer)); | |||
| } | |||
| delete producer; | |||
| } | |||
| producer = nullptr; | |||
| } | |||
| SmallVector<ComputePath*> users; | |||
| }; | |||
| } | |||
| template <> | |||
| struct ToStringTrait<interpreter::intl::TensorInfo::Prop>{ | |||
| using TensorInfo = interpreter::intl::TensorInfo; | |||
| std::string operator()(TensorInfo::Prop prop) const { | |||
| switch(prop) { | |||
| case TensorInfo::DType: | |||
| return "dtype"; | |||
| case TensorInfo::DevValue: | |||
| return "dev_value"; | |||
| case TensorInfo::Device: | |||
| return "device"; | |||
| case TensorInfo::HostValue: | |||
| return "host_value"; | |||
| case TensorInfo::Shape: | |||
| return "shape"; | |||
| default: | |||
| return "unknown"; | |||
| } | |||
| } | |||
| }; | |||
| } | |||
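A stripped-down model of the ComputePath bookkeeping above (toy Node/Path types; no dedup, pinning, or recompute): make links the path as producer of its outputs and user of its inputs, while detach_producer nulls this output's slot and deletes the path once ref_cnt, the count of non-null output slots, reaches zero.

    #include <algorithm>
    #include <iostream>
    #include <vector>

    struct Node;
    struct Path {
        std::vector<Node*> inputs;   // already deduplicated in this toy
        std::vector<Node*> outputs;
        size_t ref_cnt() const {
            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
        }
    };
    struct Node {
        Path* producer = nullptr;
        std::vector<Path*> users;
        void detach_producer() {
            if (!producer) return;
            // forget this output; the path may still serve sibling outputs
            *std::find(producer->outputs.begin(), producer->outputs.end(), this) = nullptr;
            if (producer->ref_cnt() == 0) {
                for (Node* in : producer->inputs)
                    in->users.erase(std::find(in->users.begin(), in->users.end(), producer));
                delete producer;
            }
            producer = nullptr;
        }
    };

    Path* make_path(std::vector<Node*> inputs, std::vector<Node*> outputs) {
        auto* p = new Path{std::move(inputs), std::move(outputs)};
        for (Node* in : p->inputs) in->users.push_back(p);
        for (Node* out : p->outputs) out->producer = p;
        return p;
    }

    int main() {
        Node a, x, y;
        make_path({&a}, {&x, &y});            // a -> (x, y)
        x.detach_producer();                  // path survives: y still references it
        std::cout << a.users.size() << "\n";  // 1
        y.detach_producer();                  // last output gone: path is freed
        std::cout << a.users.size() << "\n";  // 0
    }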
| @@ -1,351 +0,0 @@ | |||
| /** | |||
| * \file imperative/src/impl/interpreter_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include <deque> | |||
| #include <future> | |||
| #include <list> | |||
| #include <unordered_set> | |||
| #include <variant> | |||
| #include "megbrain/utils/mempool.h" | |||
| #include "megbrain/imperative/interpreter.h" | |||
| namespace mgb::imperative::interpreter::intl { | |||
| using Handle = Interpreter::Handle; | |||
| struct InterpreterImpl : Interpreter { | |||
| std::unique_ptr<Channel> create_channel() override; | |||
| }; | |||
| enum EvictType { | |||
| NONE = 0, | |||
| SWAP = 1, | |||
| DROP = 2, | |||
| }; | |||
| struct TensorInfo; | |||
| using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||
| struct TensorInfo { | |||
| TensorPtr ptr; | |||
| LogicalTensorDesc desc; | |||
| // FIXME: broken by drop | |||
| bool value_fetched = false; | |||
| bool invalid = false; | |||
| EvictType evict_type = NONE; | |||
| HostTensorND h_value; | |||
| // reserved for auto drop | |||
| size_t pinned = 0; | |||
| size_t recompute_times = 0; | |||
| struct ComputePath { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<TensorInfo*> inputs; | |||
| SmallVector<TensorInfo*> unique_inputs; | |||
| SmallVector<TensorInfo*> outputs; | |||
| size_t ref_cnt() { | |||
| return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr); | |||
| } | |||
| static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) { | |||
| auto* path = new TensorInfo::ComputePath(); | |||
| path->op = op; | |||
| path->inputs = inputs; | |||
| path->outputs = outputs; | |||
| // dedup | |||
| SmallVector<TensorInfo*> unique_inputs = inputs; | |||
| std::sort(unique_inputs.begin(), unique_inputs.end()); | |||
| unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end()); | |||
| path->unique_inputs = unique_inputs; | |||
| // attach users | |||
| for (auto input: unique_inputs) { | |||
| input->users.push_back(path); | |||
| } | |||
| // attach producer | |||
| for (auto output: outputs) { | |||
| output->producer = path; | |||
| } | |||
| return path; | |||
| } | |||
| }* producer = nullptr; | |||
| void pin() { | |||
| ++pinned; | |||
| } | |||
| void unpin() { | |||
| --pinned; | |||
| } | |||
| void detach_producer() { | |||
| if (!producer) { | |||
| return; | |||
| } | |||
| auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this); | |||
| mgb_assert(output != producer->outputs.end()); | |||
| *output = nullptr; | |||
| if (producer->ref_cnt() == 0) { | |||
| for (auto* input: producer->unique_inputs) { | |||
| input->users.erase(std::find(input->users.begin(), input->users.end(), producer)); | |||
| } | |||
| delete producer; | |||
| } | |||
| producer = nullptr; | |||
| } | |||
| SmallVector<ComputePath*> users; | |||
| }; | |||
| struct Put { | |||
| TensorInfo* dest; | |||
| HostTensorND value; | |||
| bool no_cache = false; | |||
| std::string to_string() const { return ssprintf("Command: Put %p", dest); } | |||
| }; | |||
| struct ApplyOp { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<TensorInfo*> inputs; | |||
| SmallVector<TensorInfo*> outputs; | |||
| SmallVector<TensorInfo*> dels; | |||
| std::string to_string() const { | |||
| std::string builder{"Command: ApplyOp {"}; | |||
| builder += "inputs ["; | |||
| for (auto* input : inputs) { | |||
| builder += ssprintf("%p, ", input); | |||
| } | |||
| builder += "], outputs ["; | |||
| for (auto* output : outputs) { | |||
| builder += ssprintf("%p, ", output); | |||
| } | |||
| builder += "], dels ["; | |||
| for (auto* del : dels) { | |||
| builder += ssprintf("%p, ", del); | |||
| } | |||
| builder += "]"; | |||
| return builder; | |||
| } | |||
| }; | |||
| struct Del { | |||
| TensorInfo* dest; | |||
| std::string to_string() const { return ssprintf("Command: Del %p", dest); } | |||
| }; | |||
| struct GetValue { | |||
| TensorInfo* dest; | |||
| std::string to_string() const { | |||
| return ssprintf("Command: GetValue %p", dest); | |||
| } | |||
| }; | |||
| struct SwapIn { | |||
| TensorInfo* dest; | |||
| std::string to_string() const { | |||
| return ssprintf("Command: SwapIn %p", dest); | |||
| } | |||
| }; | |||
| struct SwapOut { | |||
| TensorInfo* dest; | |||
| std::string to_string() const { | |||
| return ssprintf("Command: SwapOut %p", dest); | |||
| } | |||
| }; | |||
| struct Drop { | |||
| TensorInfo* dest; | |||
| std::string to_string() const { | |||
| return ssprintf("Command: Drop %p", dest); | |||
| } | |||
| }; | |||
| struct Move { | |||
| TensorInfo* src; | |||
| TensorInfo* dest; | |||
| std::string to_string() const { | |||
| return ssprintf("Command: Move %s to %s", | |||
| src->desc.layout.to_string().c_str(), | |||
| dest->desc.layout.to_string().c_str()); | |||
| } | |||
| }; | |||
| struct Flush { | |||
| TensorInfo* dest = nullptr; | |||
| std::string to_string() const { | |||
| return ssprintf("Command: Flush %p", dest); | |||
| } | |||
| }; | |||
| struct Nop { | |||
| std::string to_string() const { return "Command: Nop"; } | |||
| }; | |||
| using Command = std::variant<Put, | |||
| ApplyOp, | |||
| Del, | |||
| GetValue, | |||
| SwapIn, | |||
| SwapOut, | |||
| Drop, | |||
| Move, | |||
| Flush, | |||
| Nop>; | |||
| struct ChannelImpl : Interpreter::Channel { | |||
| ChannelImpl() : m_worker(this), m_buffer(this) {} | |||
| ~ChannelImpl() override; | |||
| Handle put(const HostTensorND& value, bool no_cache) override; | |||
| Handle put(const DeviceTensorND& value) override; | |||
| void del(Handle) override; | |||
| void swap_in(Handle) override; | |||
| void swap_out(Handle) override; | |||
| void drop(Handle) override; | |||
| SmallVector<Handle> apply_op( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<Handle>& inputs) override; | |||
| HostTensorND get_value(Handle) override; | |||
| TensorShape get_shape(Handle) override; | |||
| DType get_dtype(Handle) override; | |||
| CompNode get_device(Handle) override; | |||
| DeviceTensorND get_dev_tensor(Handle) override; | |||
| void sync() override; | |||
| void close() override; | |||
| void set_swap_flag(bool) override; | |||
| void set_drop_flag(bool) override; | |||
| void set_buffer_length(int) override; | |||
| void config_async_level(int level) override; | |||
| int get_async_level() override; | |||
| private: | |||
| TensorInfo* alloc(); | |||
| void free(TensorInfo*); | |||
| void detach_users(TensorInfo*); | |||
| void process_one_task(Command&); | |||
| void check_worker_exc_unsafe(); | |||
| void produce_tensor(TensorInfo* dest, TensorPtr ptr); | |||
| void release_tensor(TensorInfo* dest); | |||
| void regenerate(TensorInfo* dest); | |||
| void recompute(TensorInfo::ComputePath* path); | |||
| void dispatch_default_cpu( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs); | |||
| void dispatch_kernel( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs); | |||
| std::mutex m_mutex; | |||
| std::condition_variable m_cv; | |||
| MemPool<TensorInfo> m_pool; | |||
| std::unordered_set<Handle> m_valid_handle; | |||
| TensorInfo* m_waitee = nullptr; | |||
| std::exception_ptr m_worker_exc; | |||
| size_t m_enable_evict = 0; | |||
| struct WorkQueue : AsyncQueueSC<Command, WorkQueue> { | |||
| // set max_spin=0 to prevent the queue from fetching tasks in a busy-wait manner. | |||
| // this doesn't affect throughput when the python interpreter is sending enough tasks, | |||
| // but significantly saves CPU time when waiting for a task, e.g. waiting for data input | |||
| WorkQueue(ChannelImpl* owner) | |||
| : AsyncQueueSC<Command, WorkQueue>(0), m_owner(owner) { | |||
| sys::set_thread_name("interpreter"); | |||
| } | |||
| void process_one_task(Command& cmd) { | |||
| m_owner->process_one_task(cmd); | |||
| } | |||
| void on_async_queue_worker_thread_start() override { | |||
| sys::set_thread_name("worker"); | |||
| } | |||
| private: | |||
| ChannelImpl* m_owner; | |||
| } m_worker; | |||
| /** | |||
| * Buffer a command window for subsequent fusion | |||
| * example: | |||
| * --------------------------------------------------------------------- | |||
| * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} | | |||
| * --------------------------------------------------------------------- | |||
| * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} | | |||
| * --------------------------------------------------------------------- | |||
| * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... | | |||
| * --------------------------------------------------------------------- | |||
| * Then the fused Apply may be invoked in place; see ChannelImpl::process_one_task | |||
| */ | |||
| struct CommandBuffer { | |||
| CommandBuffer(ChannelImpl* owner) : m_owner(owner) { | |||
| int capacity = 3; | |||
| if(const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) { | |||
| capacity = atoi(capacity_str); | |||
| } | |||
| set_capacity(capacity); | |||
| } | |||
| void enqueue(Command cmd); | |||
| bool empty() const { | |||
| return m_commands.empty(); | |||
| } | |||
| void set_capacity(int capacity) { | |||
| mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length"); | |||
| m_capacity = capacity; | |||
| } | |||
| private: | |||
| ChannelImpl* m_owner; | |||
| size_t m_capacity; | |||
| std::deque<Command> m_commands; | |||
| using Handle = decltype(m_commands)::iterator; | |||
| // [begin, end) | |||
| using Range = std::array<Handle, 2>; | |||
| // Launch commands in range [m_commands.begin(), pos) | |||
| void flush(Handle pos); | |||
| // Select flush position for incoming cmd | |||
| Handle flush_pos_for(const Command& cmd); | |||
| // Fuse del command into suitable ApplyOp | |||
| bool fuse_del(const Del& cmd); | |||
| // Returns the last handle that dest is used within range. If dest is not used, returns range[1] | |||
| Handle find_last_usage(TensorInfo* dest, Range range); | |||
| // Returns the produce position of dest. If not found, returns range[1] | |||
| Handle find_produce(TensorInfo* dest, Range range); | |||
| } m_buffer; | |||
| //! configures whether errors are raised exactly when an op is invoked. | |||
| //! level 2: both device and user side errors are async; | |||
| //! level 1: user side errors are sync; | |||
| //! level 0: both sync. | |||
| int m_async_level = 2; | |||
| int m_max_recompute_time = 1; | |||
| }; | |||
| } // namespace mgb::imperative::interpreter::intl | |||
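Each command in the removed header carries its own to_string, and the Command variant lets a single std::visit dispatch cover them all. A minimal sketch of that pattern, with toy Put/Nop alternatives standing in for the full list:

    #include <iostream>
    #include <string>
    #include <variant>

    // two toy commands following the shape used above: each knows how to print itself
    struct Put { int dest; std::string to_string() const { return "Put " + std::to_string(dest); } };
    struct Nop {           std::string to_string() const { return "Nop"; } };
    using Command = std::variant<Put, Nop>;

    // one visitor serves every alternative, so new commands only need a to_string()
    std::string to_string(const Command& cmd) {
        return std::visit([](auto&& c) { return c.to_string(); }, cmd);
    }

    int main() {
        Command a = Put{42}, b = Nop{};
        std::cout << to_string(a) << "\n"; // Put 42
        std::cout << to_string(b) << "\n"; // Nop
    }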
| @@ -70,6 +70,26 @@ BackwardGraphResult OpDef::make_backward_graph( | |||
| return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | |||
| } | |||
| std::vector<std::pair<const char*, std::string>> OpDef::props( | |||
| const OpDef& def) { | |||
| return def.trait()->props(def); | |||
| } | |||
| const char* OpDef::name() const { | |||
| return trait()->name; | |||
| } | |||
| std::string OpDef::to_string() const { | |||
| std::string builder = "{"; | |||
| for (auto&& [name, value]: props(*this)) { | |||
| builder += name; | |||
| builder += ": "; | |||
| builder += value; | |||
| builder += ","; | |||
| } | |||
| return builder + "}"; | |||
| } | |||
| size_t OpDef::hash() const { | |||
| return trait()->hash(*this); | |||
| } | |||
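OpDef::to_string above simply folds the (name, value) props into a braced list. A self-contained sketch of the same rendering (render_props is an illustrative helper, not a MegEngine API); note the trailing comma that the original keeps:

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // mirrors OpDef::to_string above: fold (name, value) props into "{k: v,}"
    std::string render_props(const std::vector<std::pair<const char*, std::string>>& props) {
        std::string builder = "{";
        for (auto&& [name, value] : props) {
            builder += name;
            builder += ": ";
            builder += value;
            builder += ",";
        }
        return builder + "}";
    }

    int main() {
        std::cout << render_props({{"kernel", "3x3"}, {"stride", "1"}}) << "\n";
        // prints {kernel: 3x3,stride: 1,}
    }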
| @@ -72,6 +72,7 @@ using InferOutputAttrsFallible = detail::OpMeth< | |||
| decltype(OpDef::infer_output_attrs_fallible)>; | |||
| using GradMaker = detail::OpMeth< | |||
| decltype(OpDef::make_backward_graph)>; | |||
| using Props = detail::OpMeth<decltype(OpDef::props)>; | |||
| using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | |||
| using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
| @@ -84,6 +85,7 @@ struct OpTrait { | |||
| ApplyOnVarNode apply_on_var_node; | |||
| InferOutputAttrsFallible infer_output_attrs_fallible; | |||
| GradMaker make_backward_graph; | |||
| Props props; | |||
| HashFunc hash; | |||
| IsSame is_same_st; | |||
| OpTrait(const char* name); | |||
| @@ -100,6 +102,7 @@ struct OpTrait { | |||
| cb(apply_on_var_node) \ | |||
| cb(infer_output_attrs_fallible) \ | |||
| cb(make_backward_graph) \ | |||
| cb(props) \ | |||
| cb(hash) \ | |||
| cb(is_same_st) | |||
| @@ -148,9 +148,15 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs( | |||
| .graph().infer_attrs(inputs); | |||
| } | |||
| std::vector<std::pair<const char*, std::string>> props( | |||
| const OpDef& backward_graph) { | |||
| return {}; | |||
| } | |||
| OP_TRAIT_REG(BackwardGraph, BackwardGraph) | |||
| .apply_on_physical_tensor(backward_impl) | |||
| .infer_output_attrs_fallible(infer_tensor_attrs) | |||
| .props(props) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| @@ -95,9 +95,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
| return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config()); | |||
| } | |||
| std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { | |||
| return {}; | |||
| } | |||
| OP_TRAIT_REG(OprAttr, OprAttr) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .props(props) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| @@ -11,12 +11,14 @@ | |||
| #include "megbrain/imperative/profiler.h" | |||
| #include "./function_hook.h" | |||
| #include <chrono> | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| #include "megbrain/plugin/opr_footprint.h" | |||
| #include "./function_hook.h" | |||
| #include "./event_pool.h" | |||
| #include "./op_trait.h" | |||
| @@ -25,200 +27,42 @@ namespace imperative { | |||
| namespace { | |||
| CompNode::UnorderedSet collect_comp_nodes( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| CompNode::UnorderedSet comp_nodes; | |||
| SmallVector<LogicalTensorDesc> inp_descs; | |||
| for (auto&& i : inputs) { | |||
| comp_nodes.insert(i->comp_node()); | |||
| inp_descs.push_back({i->layout(), i->comp_node(), {}}); | |||
| } | |||
| SmallVector<LogicalTensorDesc> oup_descs = std::get<0>(def.infer_output_attrs_fallible(def, inp_descs)); | |||
| for (auto&& output_attr : oup_descs) { | |||
| comp_nodes.insert(output_attr.comp_node); | |||
| } | |||
| return comp_nodes; | |||
| } | |||
| DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) { | |||
| auto event = EventPool::with_timer().alloc_shared(device); | |||
| event->record(); | |||
| return event; | |||
| } | |||
| OprFootprint footprint{}; | |||
| } // namespace | |||
| void DeviceTimer::reset(thin_function<double()> host_timer) { | |||
| CompNode::foreach ([this, host_timer](CompNode device) { | |||
| m_base_event_table[device] = {alloc_recorded_event(device), host_timer()}; | |||
| }); | |||
| m_host_timer = host_timer; | |||
| } | |||
| thin_function<double()> DeviceTimer::get_device_time(CompNode device) { | |||
| auto event = EventPool::with_timer().alloc_shared(device); | |||
| event->record(); | |||
| if(m_base_event_table.count(device) == 0) { | |||
| m_base_event_table[device] = {alloc_recorded_event(device), m_host_timer()}; | |||
| } | |||
| auto base = m_base_event_table[device]; | |||
| return [base, event] { | |||
| auto [base_event, host_time] = base; | |||
| // TODO: sync once for each compnode | |||
| event->host_wait(); | |||
| return base_event->elapsed_time_until(*event) * 1000 + host_time; | |||
| }; | |||
| } | |||
| DeviceTimer::SharedEvent DeviceTimer::get_device_time(CompNode device) { | |||
| return alloc_recorded_event(device); | |||
| } | |||
| SmallVector<DeviceTimer::SharedEvent> DeviceTimer::get_all(SmallVector<CompNode> device_list) { | |||
| SmallVector<DeviceTimer::SharedEvent> results; | |||
| for (auto&& device: device_list) { | |||
| results.push_back(alloc_recorded_event(device)); | |||
| } | |||
| return results; | |||
| } | |||
| void DeviceTimer::clear() { | |||
| m_base_event_table.clear(); | |||
| } | |||
| double HostTimer::get_msecs() { | |||
| using namespace std::chrono; | |||
| auto finish = steady_clock::now(); | |||
| auto duration = duration_cast<microseconds>(finish - m_start); | |||
| return (double)duration.count() / 1e3; | |||
| } | |||
| size_t TensorRecorder::record_tensor(const TensorPtr& tensor) { | |||
| if (m_tensor_map.count(tensor.get()) > 0) { | |||
| auto& [prev, id] = m_tensor_map[tensor.get()]; | |||
| if (prev.lock() != tensor) { | |||
| prev = tensor; | |||
| id = m_next_id++; | |||
| } | |||
| return id; | |||
| } else { | |||
| auto id = m_next_id++; | |||
| m_tensor_map.insert( | |||
| {tensor.get(), {std::weak_ptr<Tensor>{tensor}, id}}); | |||
| return id; | |||
| } | |||
| } | |||
| void TensorRecorder::clear() { | |||
| m_next_id = 0; | |||
| m_tensor_map.clear(); | |||
| } | |||
| Profile& Profiler::get_profile() { | |||
| for (auto& entry : m_profile) { | |||
| for (auto& [device, device_begin, device_end] : entry.device_list) { | |||
| MGB_MARK_USED_VAR(device); | |||
| device_begin = [value = device_begin()] { return value; }; | |||
| device_end = [value = device_end()] { return value; }; | |||
| } | |||
| } | |||
| return m_profile; | |||
| } | |||
| void Profiler::start(uint32_t flags) { | |||
| m_host_timer.reset(); | |||
| m_device_timer.reset([&] { return m_host_timer.get_msecs(); }); | |||
| OpTrait::for_each_trait([this, flags](OpTrait& trait) { | |||
| auto hook_apply_on_physical_tensor = | |||
| make_shared_hook(&trait.apply_on_physical_tensor); | |||
| auto hook_apply_on_var_node = | |||
| make_shared_hook(&trait.apply_on_var_node); | |||
| hook_apply_on_physical_tensor->apply_hook([this, flags] | |||
| (auto&& apply, const OpDef& def, SmallVector<TensorPtr> inputs) { | |||
| auto shape2vector = [](const TensorShape& shape) { | |||
| std::vector<size_t> vector_shape; | |||
| for (size_t i = 0; i < shape.ndim; i++) { | |||
| vector_shape.push_back(shape[i]); | |||
| } | |||
| return vector_shape; | |||
| }; | |||
| ProfileEntry entry; | |||
| entry.id = m_entry_count++; | |||
| // TODO: assign parent | |||
| entry.parent = 0; | |||
| // Record apply context and save to m_profile | |||
| entry.op = const_cast<OpDef&>(def).shared_from_this(); | |||
| for (auto&& input : inputs) { | |||
| entry.inputs.push_back({m_tensor_recorder.record_tensor(input), | |||
| shape2vector(input->layout()), | |||
| input->comp_node()}); | |||
| } | |||
| double host_begin = m_host_timer.get_msecs(); | |||
| auto&& comp_nodes = collect_comp_nodes(def, inputs); | |||
| for (auto&& comp_node : comp_nodes) { | |||
| entry.device_list.push_back( | |||
| {comp_node, | |||
| m_device_timer.get_device_time(comp_node), | |||
| {}}); | |||
| } | |||
| if (flags & PROFILE_FOOTPRINT) { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| m_entry_stack.push({&def, &entry, std::this_thread::get_id()}); | |||
| } | |||
| // Do real apply | |||
| auto outputs = apply(def, inputs); | |||
| for (auto& [cn, dev_begin, dev_end] : entry.device_list) { | |||
| MGB_MARK_USED_VAR(cn); | |||
| MGB_MARK_USED_VAR(dev_begin); | |||
| dev_end = m_device_timer.get_device_time(cn); | |||
| } | |||
| entry.host = {host_begin, m_host_timer.get_msecs()}; | |||
| for (auto&& output : outputs) { | |||
| entry.outputs.push_back( | |||
| {m_tensor_recorder.record_tensor(output), | |||
| shape2vector(output->layout()), output->comp_node()}); | |||
| } | |||
| if (flags & PROFILE_FOOTPRINT) { | |||
| mgb_assert(std::get<1>(m_entry_stack.top()) == &entry); | |||
| MGB_LOCK_GUARD(m_lock); | |||
| m_entry_stack.pop(); | |||
| } | |||
| m_profile.push_back(std::move(entry)); | |||
| return outputs; | |||
| }); | |||
| if (flags & PROFILE_FOOTPRINT) { | |||
| hook_apply_on_var_node->apply_hook( | |||
| [this](auto&& apply, const OpDef& def, | |||
| VarNodeArray inputs) -> VarNodeArray { | |||
| auto vars = apply(def, std::move(inputs)); | |||
| std::remove_reference_t<decltype(m_entry_stack.top())> | |||
| top; | |||
| { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| if (m_entry_stack.empty()) { | |||
| return vars; | |||
| } | |||
| top = m_entry_stack.top(); | |||
| } | |||
| auto [current_op, current_entry, thread_id] = top; | |||
| if (current_op != &def || | |||
| thread_id != std::this_thread::get_id()) { | |||
| return vars; | |||
| } | |||
| auto&& footprint_result = | |||
| footprint.calc_footprint(vars[0]->owner_opr()); | |||
| current_entry->memory = footprint_result.memory; | |||
| current_entry->computation = | |||
| footprint_result.computation; | |||
| #if MGB_ENABLE_JSON | |||
| current_entry->param = footprint_result.param; | |||
| #endif | |||
| return vars; | |||
| }); | |||
| } | |||
| m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); | |||
| m_hooker_list.push_back(std::move(hook_apply_on_var_node)); | |||
| }); | |||
| } | |||
| void Profiler::stop() { | |||
| m_hooker_list.clear(); | |||
| for (auto& entry : m_profile) { | |||
| entry.wait_device(); | |||
| } | |||
| } | |||
| void Profiler::clear() { | |||
| mgb_assert(m_entry_stack.empty(), | |||
| "entry_stack should be empty after profile"); | |||
| mgb_assert(m_hooker_list.empty(), "hooks should be released"); | |||
| m_profile.clear(); | |||
| m_entry_count = 0; | |||
| m_device_timer.clear(); | |||
| m_tensor_recorder.clear(); | |||
| } | |||
| double HostTimer::get_started_at() { | |||
| return m_started_at; | |||
| } | |||
| void HostTimer::reset() { | |||
| using namespace std::chrono; | |||
| m_start = steady_clock::now(); | |||
| auto now_us = duration_cast<microseconds>(std::chrono::system_clock::now().time_since_epoch()); | |||
| m_started_at = (double)(now_us.count()) / 1e3; | |||
| } | |||
| } // namespace imperative | |||
| @@ -471,6 +471,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph { | |||
| } | |||
| if (can_pop) { | |||
| for (auto _ : comp_node_trackers) { | |||
| MGB_MARK_USED_VAR(_); | |||
| busy_oprs.pop_front(); | |||
| } | |||
| m_opr = busy_oprs.front().opr; | |||
| @@ -10,6 +10,7 @@ | |||
| */ | |||
| #include <atomic> | |||
| #include <any> | |||
| #include "megbrain/imperative/op_def.h" | |||
| @@ -42,12 +43,15 @@ struct Interpreter { | |||
| virtual void sync() = 0; | |||
| virtual void close() = 0; | |||
| virtual void set_swap_flag(bool) = 0; | |||
| virtual void set_drop_flag(bool) = 0; | |||
| virtual void set_buffer_length(int) = 0; | |||
| virtual void config_async_level(int level) = 0; | |||
| virtual int get_async_level() = 0; | |||
| virtual int get_option(std::string name) = 0; | |||
| virtual void set_option(std::string name, int value) = 0; | |||
| virtual void start_profile(std::unordered_map<std::string, int> option) = 0; | |||
| virtual void stop_profile(std::string basename, std::string format) = 0; | |||
| virtual void push_scope(std::string name) = 0; | |||
| virtual void pop_scope(std::string name) = 0; | |||
| }; | |||
| virtual std::unique_ptr<Channel> create_channel() = 0; | |||
| @@ -13,6 +13,7 @@ | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -80,8 +81,15 @@ public: | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad); | |||
| static std::vector<std::pair<const char*, std::string>> props( | |||
| const OpDef& def); | |||
| const OpTrait* trait() const; | |||
| const char* name() const; | |||
| std::string to_string() const; | |||
| virtual size_t hash() const; | |||
| virtual bool is_same_st(const Hashable&) const; | |||
| @@ -96,6 +104,16 @@ public: | |||
| } | |||
| }; | |||
| template <> | |||
| struct ToStringTrait<OpDef*>{ | |||
| std::string operator()(OpDef* op) const { | |||
| if (op == nullptr) { | |||
| return "nullptr"; | |||
| } | |||
| return op->to_string(); | |||
| } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
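The ToStringTrait&lt;OpDef*&gt; specialization follows the project's general pattern: a primary trait template plus per-type specializations, with a free to_string that defers to the trait (tensor_info.h specializes it the same way for TensorInfo::Prop). A minimal sketch using a hypothetical Phase enum:

    #include <iostream>
    #include <string>

    // generic hook following the pattern above: to_string defers to a trait object
    template <typename T>
    struct ToStringTrait;

    template <typename T>
    std::string to_string(const T& value) {
        return ToStringTrait<T>{}(value);
    }

    // specialization for a hypothetical enum, in the style of ToStringTrait<OpDef*>
    enum class Phase { Forward, Backward };
    template <>
    struct ToStringTrait<Phase> {
        std::string operator()(Phase p) const {
            return p == Phase::Forward ? "forward" : "backward";
        }
    };

    int main() {
        std::cout << to_string(Phase::Backward) << "\n"; // backward
    }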
| @@ -11,10 +11,12 @@ | |||
| #pragma once | |||
| #include <any> | |||
| #include <optional> | |||
| #include <stack> | |||
| #include <list> | |||
| #include <map> | |||
| #include <variant> | |||
| #include <fstream> | |||
| #include <chrono> | |||
| #include <bitset> | |||
| #include "megbrain/comp_node.h" | |||
| #include "megbrain/graph/event.h" | |||
| @@ -27,89 +29,298 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| using ProfileTensor = std::tuple<size_t, std::vector<size_t>, CompNode>; | |||
| struct ProfileEntry { | |||
| using TimeClosure = std::function<double()>; | |||
| size_t id; | |||
| size_t parent; | |||
| std::shared_ptr<OpDef> op; | |||
| //(host_begin, host_end) | |||
| std::tuple<double, double> host; | |||
| //[(device, device_begin, device_end)] | |||
| std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list; | |||
| std::vector<ProfileTensor> inputs; | |||
| std::vector<ProfileTensor> outputs; | |||
| long long memory = 0; | |||
| long long computation = 0; | |||
| #if MGB_ENABLE_JSON | |||
| std::shared_ptr<json::Value> param; | |||
| #endif | |||
| void wait_device() { | |||
| for (auto& [cn, begin, end] : device_list) { | |||
| MGB_MARK_USED_VAR(cn); | |||
| begin = [begin = begin()] { return begin; }; | |||
| end = [end = end()] { return end; }; | |||
| } | |||
| } | |||
| }; | |||
| using Profile = std::list<ProfileEntry>; | |||
| class DeviceTimer { | |||
| public: | |||
| using SharedEvent = std::shared_ptr<CompNode::Event>; | |||
| DeviceTimer() = default; | |||
| void reset(thin_function<double()> host_timer); | |||
| thin_function<double()> get_device_time(CompNode device); | |||
| void clear(); | |||
| SharedEvent get_device_time(CompNode device); | |||
| SmallVector<SharedEvent> get_all(SmallVector<CompNode> device_list); | |||
| private: | |||
| CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table; | |||
| thin_function<double()> m_host_timer; | |||
| }; | |||
| class HostTimer { | |||
| public: | |||
| void reset(); | |||
| double get_msecs(); | |||
| double get_started_at(); | |||
| private: | |||
| decltype(std::chrono::steady_clock::now()) m_start; | |||
| double m_started_at; | |||
| }; | |||
| class TensorRecorder { | |||
| private: | |||
| // active tensors | |||
| std::unordered_map<Tensor*, std::tuple<std::weak_ptr<Tensor>, size_t>> m_tensor_map; | |||
| size_t m_next_id; | |||
| public: | |||
| size_t record_tensor(const TensorPtr& tensor); | |||
| void clear(); | |||
| }; | |||
| class ProfilerBase { | |||
| public: | |||
| using Host = std::thread::id; | |||
| using Device = CompNode; | |||
| struct HostInstant { | |||
| Host tid; | |||
| double time; | |||
| void wait() {} | |||
| }; | |||
| struct DeviceInstant { | |||
| double before; | |||
| std::shared_ptr<CompNode::Event> event; | |||
| double after; | |||
| void wait() { | |||
| event->host_wait(); | |||
| } | |||
| }; | |||
| using Instant = std::variant<HostInstant, DeviceInstant>; | |||
| template <typename TEvent> | |||
| struct EventRecord { | |||
| Instant instant; | |||
| TEvent data; | |||
| HostInstant& host() { | |||
| return std::get<HostInstant>(instant); | |||
| } | |||
| DeviceInstant device() { | |||
| return std::get<DeviceInstant>(instant); | |||
| } | |||
| void wait() { | |||
| std::visit([&](auto& instant){ instant.wait(); }, instant); | |||
| } | |||
| }; | |||
| protected: | |||
| HostInstant record_host() { | |||
| return {std::this_thread::get_id(), m_host_timer.get_msecs()}; | |||
| } | |||
| DeviceInstant record_device(Device device) { | |||
| auto before = m_host_timer.get_msecs(); | |||
| auto event = m_device_timer.get_device_time(device); | |||
| auto after = m_host_timer.get_msecs(); | |||
| return {before, event, after}; | |||
| } | |||
| protected: | |||
| std::atomic_int64_t m_last_id = 0; | |||
| HostTimer m_host_timer; | |||
| DeviceTimer m_device_timer; | |||
| Spinlock m_lock; | |||
| }; | |||
| class Profiler { | |||
| public: | |||
| enum Flags { | |||
| PROFILE_FOOTPRINT = 1, | |||
| }; | |||
| template <typename... TEvents> | |||
| class Profiler: public ProfilerBase { | |||
| public: | |||
| using Record = std::variant<EventRecord<TEvents>...>; | |||
| using Mask = std::bitset<sizeof...(TEvents)>; | |||
| struct Data { | |||
| std::vector<Record> records; | |||
| double started_at; | |||
| }; | |||
| template <typename TEvent, size_t index = 0> | |||
| static constexpr size_t index_of() { | |||
| if constexpr (index == std::variant_size_v<Record>) { | |||
| return index; | |||
| } else if constexpr (std::is_same_v<EventRecord<TEvent>, std::variant_alternative_t<index, Record>>) { | |||
| return index; | |||
| } else { | |||
| return index_of<TEvent, index+1>(); | |||
| } | |||
| }; | |||
| template <typename... TEvents2> | |||
| static Mask mask_of() { | |||
| return Mask{} | (Mask{}.set(index_of<TEvents2>()) |...); | |||
| } | |||
| enum Status { | |||
| NotStarted, Profiling, Stopped | |||
| }; | |||
| public: | |||
| Profiler() = default; | |||
| // Start profiler by hook OpTrait | |||
| void start(uint32_t flags); | |||
| // Stop profiler and clean environment | |||
| void stop(); | |||
| void clear(); | |||
| Profile& get_profile(); | |||
| template <typename TEvent, typename... TArgs> | |||
| void record_host(TArgs&&... args) { | |||
| auto instant = HostInstant{std::this_thread::get_id(), m_host_timer.get_msecs()}; | |||
| MGB_LOCK_GUARD(m_lock); | |||
| if (!m_event_mask.test(index_of<TEvent>())) { | |||
| return; | |||
| } | |||
| mgb_assert(m_status != Stopped, "record after stop"); | |||
| m_record_list.emplace_back(EventRecord<TEvent>{std::move(instant), {std::forward<TArgs>(args)...}}); | |||
| } | |||
| template <typename TEvent, typename... TArgs> | |||
| void record_device(Device device, TArgs&&... args) { | |||
| auto before = m_host_timer.get_msecs(); | |||
| auto event = m_device_timer.get_device_time(device); | |||
| auto after = m_host_timer.get_msecs(); | |||
| auto instant = DeviceInstant{before, event, after}; | |||
| MGB_LOCK_GUARD(m_lock); | |||
| if (!m_event_mask.test(index_of<TEvent>())) { | |||
| return; | |||
| } | |||
| mgb_assert(m_status != Stopped, "record after stop"); | |||
| m_record_list.emplace_back(EventRecord<TEvent>{std::move(instant), {std::forward<TArgs>(args)...}}); | |||
| } | |||
| void start(Mask mask) { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| mgb_assert(m_status == NotStarted, "profiler already started"); | |||
| m_status = Profiling; | |||
| m_event_mask = mask; | |||
| m_host_timer.reset(); | |||
| } | |||
| Data stop() { | |||
| MGB_LOCK_GUARD(m_lock); | |||
| mgb_assert(m_status == Profiling, "profiler not active"); | |||
| m_status = Stopped; | |||
| for (auto&& record: m_record_list) { | |||
| std::visit([&](auto& record){ | |||
| record.wait(); | |||
| }, record); | |||
| } | |||
| auto records = std::move(m_record_list); | |||
| return { records, m_host_timer.get_started_at() }; | |||
| } | |||
| protected: | |||
| std::vector<Record> m_record_list; | |||
| Mask m_event_mask; | |||
| Status m_status = NotStarted; | |||
| }; | |||
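A usage sketch of the index_of/mask_of machinery above, reduced to a MiniProfiler with three toy event types (names illustrative): index_of walks the variant alternatives at compile time, and mask_of ORs the selected bits together.

    #include <bitset>
    #include <cstddef>
    #include <iostream>
    #include <type_traits>
    #include <variant>

    struct OpStart {}; struct OpFinish {}; struct SyncEvent {};

    template <typename... TEvents>
    struct MiniProfiler {
        using Record = std::variant<TEvents...>;
        using Mask = std::bitset<sizeof...(TEvents)>;

        // compile-time linear search over the variant, mirroring index_of above
        template <typename TEvent, size_t index = 0>
        static constexpr size_t index_of() {
            if constexpr (index == std::variant_size_v<Record>) {
                return index;  // not found
            } else if constexpr (std::is_same_v<TEvent, std::variant_alternative_t<index, Record>>) {
                return index;
            } else {
                return index_of<TEvent, index + 1>();
            }
        }
        template <typename... TSelected>
        static Mask mask_of() {
            return Mask{} | (Mask{}.set(index_of<TSelected>()) | ...);
        }
    };

    int main() {
        using P = MiniProfiler<OpStart, OpFinish, SyncEvent>;
        auto mask = P::mask_of<OpStart, OpFinish>();     // enable the two op events
        std::cout << mask << "\n";                                // 011
        std::cout << mask.test(P::index_of<SyncEvent>()) << "\n"; // 0: masked out
    }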
| class ChromeTraceEvent { | |||
| public: | |||
| ChromeTraceEvent& name(std::string name) { | |||
| m_name = std::move(name); | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& tid(uint64_t tid) { | |||
| m_tid = std::move(tid); | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& cat(std::string cat) { | |||
| m_cat = std::move(cat); | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& pid(uint64_t pid) { | |||
| m_pid = pid; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& id(uint64_t id) { | |||
| m_id = id; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& idx(uint64_t idx) { | |||
| m_idx = idx; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& ts(double ts) { | |||
| m_ts = ts; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& dur(double dur) { | |||
| m_dur = dur; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& ph(char ph) { | |||
| m_ph = ph; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& bp(char bp) { | |||
| m_bp = bp; | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& args(std::shared_ptr<json::Object> args) { | |||
| m_args = std::move(args); | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& arg(std::string key, std::string value) { | |||
| if (!m_args) { | |||
| m_args = json::Object::make(); | |||
| } | |||
| (*m_args)[key] = json::String::make(value); | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& arg(std::string key, double value) { | |||
| if (!m_args) { | |||
| m_args = json::Object::make(); | |||
| } | |||
| (*m_args)[key] = json::Number::make(value); | |||
| return *this; | |||
| } | |||
| ChromeTraceEvent& arg(std::string key, std::shared_ptr<json::Value> value) { | |||
| if (!m_args) { | |||
| m_args = json::Object::make(); | |||
| } | |||
| (*m_args)[key] = value; | |||
| return *this; | |||
| } | |||
| std::shared_ptr<json::Object> to_json() const { | |||
| auto result = json::Object::make(); | |||
| auto prop_str = [&](auto key, auto value) { | |||
| if (value.empty()) { | |||
| return; | |||
| } | |||
| (*result)[key] = json::String::make(value); | |||
| }; | |||
| auto prop_num = [&](auto key, auto value) { | |||
| if (!value) { | |||
| return; | |||
| } | |||
| (*result)[key] = json::Number::make(value.value()); | |||
| }; | |||
| auto prop_char = [&](auto key, auto value) { | |||
| if (!value) { | |||
| return; | |||
| } | |||
| (*result)[key] = json::String::make(std::string{} + value.value()); | |||
| }; | |||
| prop_str("name", m_name); | |||
| prop_num("tid", m_tid); | |||
| prop_str("cat", m_cat); | |||
| prop_num("pid", m_pid); | |||
| prop_num("id", m_id); | |||
| prop_num("idx", m_idx); | |||
| prop_num("ts", m_ts); | |||
| prop_num("dur", m_dur); | |||
| prop_char("ph", m_ph); | |||
| prop_char("bp", m_bp); | |||
| if (m_args) { | |||
| (*result)["args"] = m_args; | |||
| } | |||
| return result; | |||
| } | |||
| private: | |||
| std::string m_name; | |||
| std::string m_cat; | |||
| std::optional<uint64_t> m_tid; | |||
| std::optional<uint64_t> m_pid; | |||
| std::optional<uint64_t> m_id; | |||
| std::optional<uint64_t> m_idx; | |||
| std::optional<double> m_ts; | |||
| std::optional<double> m_dur; | |||
| std::optional<char> m_ph; | |||
| std::optional<char> m_bp; | |||
| std::shared_ptr<json::Object> m_args; | |||
| }; | |||
| class ChromeTraceEventList { | |||
| public: | |||
| ChromeTraceEvent& new_event() { | |||
| m_content.emplace_back(); | |||
| return m_content.back(); | |||
| } | |||
| std::shared_ptr<json::Array> to_json() { | |||
| auto result = json::Array::make(); | |||
| for (auto&& event: m_content) { | |||
| result->add(event.to_json()); | |||
| } | |||
| return result; | |||
| } | |||
| private: | |||
| std::vector<ChromeTraceEvent> m_content; | |||
| }; | |||
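| // Illustrative sketch (editorial, not part of the diff): emitting one | |||
| // complete-phase ("X") event via the builders above; "conv2d" and "gpu0" | |||
| // are hypothetical values. | |||
| // | |||
| // ChromeTraceEventList events; | |||
| // events.new_event() | |||
| //         .name("conv2d") | |||
| //         .ph('X')           // complete event: requires ts + dur | |||
| //         .ts(12.5) | |||
| //         .dur(3.0) | |||
| //         .pid(1) | |||
| //         .tid(42) | |||
| //         .arg("device", "gpu0"); | |||
| // auto trace = events.to_json();  // array loadable by chrome://tracing | |||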
| } // namespace imperative | |||
| @@ -0,0 +1,125 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/utils/to_string.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include <string> | |||
| #include <type_traits> | |||
| #include <memory> | |||
| #include <tuple> | |||
| #include "megbrain/utils/small_vector.h" | |||
| #include "megbrain/tensor.h" | |||
| namespace mgb::imperative { | |||
| template <typename T> | |||
| struct ToStringTrait; | |||
| template <typename T> | |||
| std::string to_string(const T& value) { | |||
| return ToStringTrait<T>{}(value); | |||
| } | |||
| template <typename T> | |||
| struct ToStringTrait{ | |||
| std::string operator()(const T& value) const { | |||
| return std::to_string(value); | |||
| } | |||
| }; | |||
| template <> | |||
| struct ToStringTrait<std::string>{ | |||
| std::string operator()(const std::string& value) const { | |||
| return value; | |||
| } | |||
| }; | |||
| template <typename T, unsigned N> | |||
| struct ToStringTrait<SmallVector<T, N>>{ | |||
| std::string operator()(const SmallVector<T, N>& sv) const { | |||
| if (sv.empty()) { | |||
| return "[]"; | |||
| } | |||
| std::string result = "["; | |||
| result += to_string(sv[0]); | |||
| for (size_t i = 1; i < sv.size(); ++i) { | |||
| result += ", "; | |||
| result += to_string(sv[i]); | |||
| } | |||
| return result + "]"; | |||
| } | |||
| }; | |||
| template <typename T> | |||
| struct ToStringTrait<std::shared_ptr<T>>{ | |||
| std::string operator()(const std::shared_ptr<T>& sp) const { | |||
| return to_string(sp.get()); | |||
| } | |||
| }; | |||
| template <typename TKey, typename TValue> | |||
| struct ToStringTrait<std::pair<TKey, TValue>>{ | |||
| std::string operator()(const std::pair<TKey, TValue>& pr) const { | |||
| return "(" + to_string(pr.first) + ", " + to_string(pr.second) + ")"; | |||
| } | |||
| }; | |||
| template <typename TItem, typename... TItems> | |||
| struct ToStringTrait<std::tuple<TItem, TItems...>>{ | |||
| std::string operator()(const std::tuple<TItem, TItems...>& tp) const { | |||
| auto folder = [&](auto... item){ return ( ...+ ("," + to_string(item))); }; | |||
| return "(" + std::apply(folder, tp) + ")"; | |||
| } | |||
| }; | |||
| template <typename T> | |||
| struct ToStringTrait<T*>{ | |||
| std::string operator()(T* p) const { | |||
| return ssprintf("%p", p); | |||
| } | |||
| }; | |||
| template <> | |||
| struct ToStringTrait<TensorShape>{ | |||
| std::string operator()(TensorShape shape) const { | |||
| if (shape.ndim > TensorShape::MAX_NDIM) { | |||
| // uninitialized or invalid shape: nothing meaningful to print | |||
| return "[]"; | |||
| } | |||
| if (shape.ndim == 0) { | |||
| return "[ ]"; | |||
| } | |||
| std::string result = "[ " + std::to_string(shape[0]); | |||
| for (size_t i = 1; i < shape.ndim; i++) { | |||
| result += ", "; | |||
| result += std::to_string(shape[i]); | |||
| } | |||
| return result + " ]"; | |||
| } | |||
| }; | |||
| template <> | |||
| struct ToStringTrait<DType>{ | |||
| std::string operator()(DType dtype) const { | |||
| return dtype.name(); | |||
| } | |||
| }; | |||
| template <> | |||
| struct ToStringTrait<CompNode>{ | |||
| std::string operator()(CompNode device) const { | |||
| return device.to_string(); | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative | |||
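| // Illustrative sketch (editorial, not part of the header): expected output | |||
| // from the traits above; the SmallVector size hint is arbitrary. | |||
| // | |||
| // using namespace mgb::imperative; | |||
| // SmallVector<int, 4> sv{1, 2, 3}; | |||
| // to_string(sv);                        // "[1, 2, 3]" | |||
| // to_string(std::make_pair(4, 5));      // "(4, 5)" | |||
| // to_string(std::make_tuple(1, 2, 3));  // "(1, 2, 3)" with the fixed fold | |||
| // TensorShape shape{2, 3}; | |||
| // to_string(shape);                     // "[ 2, 3 ]" | |||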
| @@ -222,10 +222,25 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
| os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); | |||
| os << "}\n"; | |||
| // generate props() | |||
| os << formatv( | |||
| "std::vector<std::pair<const char*, std::string>> {0}(const OpDef& def_) {{\n", | |||
| formatMethImpl("props") | |||
| ); | |||
| os << formatv( | |||
| " auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
| " static_cast<void>(op_);\n", | |||
| className | |||
| ); | |||
| ctx.withSelf("op_"); | |||
| os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
| os << "}\n"; | |||
| os << "} // anonymous namespace\n"; | |||
| methods.push_back("hash"); | |||
| methods.push_back("is_same_st"); | |||
| methods.push_back("props"); | |||
| } | |||
| if (!methods.empty()) { | |||
| os << formatv( | |||
| @@ -423,7 +438,7 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
| std::vector<std::string> getsetters; | |||
| for (auto &&i : op.getMgbAttributes()) { | |||
| getsetters.push_back(formatv( | |||
| "{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},", | |||
| "{{const_cast<char*>(\"{1}\"), py_get_generic({0}, {1}), py_set_generic({0}, {1}), const_cast<char*>(\"{1}\"), NULL},", | |||
| className, i.name)); | |||
| } | |||
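| // Editorial note on the const_cast change above: older CPython headers | |||
| // declare PyGetSetDef's `name` and `doc` fields as plain `char*`, so a C++ | |||
| // string literal (a `const char[]`) cannot initialize them directly. The | |||
| // generated entry now reads, for a hypothetical attribute "axis": | |||
| // | |||
| // {const_cast<char*>("axis"), py_get_generic(OpClass, axis), | |||
| //  py_set_generic(OpClass, axis), const_cast<char*>("axis"), NULL}, | |||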
| @@ -66,7 +66,7 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase { | |||
| } | |||
| llvm::StringRef getParentNamespace() const { | |||
| return getBaseRecord()->getValueAsString("parentNamespce"); | |||
| return getBaseRecord()->getValueAsString("parentNamespace"); | |||
| } | |||
| llvm::StringRef getEnumName() const { | |||
| return getBaseRecord()->getValueAsString("enumName"); | |||
| @@ -87,6 +87,9 @@ struct MgbHashableAttrMixin : public MgbAttrWrapperBase { | |||
| llvm::StringRef getCmpFunctionTemplate() const { | |||
| return getBaseRecord()->getValueAsString("cmpFunction"); | |||
| } | |||
| llvm::StringRef getReprFunctionTemplate() const { | |||
| return getBaseRecord()->getValueAsString("reprFunction"); | |||
| } | |||
| }; | |||
| struct MgbAliasAttrMixin : public MgbAttrWrapperBase { | |||
| @@ -205,6 +208,39 @@ private: | |||
| body += " return true;\n"; | |||
| return body; | |||
| } | |||
| std::string getDefaultPropsFunction() const { | |||
| std::string body = " std::vector<std::pair<const char*, std::string>> props_;\n"; | |||
| if (!getMgbAttributes().empty()) { | |||
| mlir::tblgen::FmtContext ctx; | |||
| for (auto&& it : getMgbAttributes()) { | |||
| if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) { | |||
| body += formatv(" switch ({0}){{\n", "$_self." + it.name); | |||
| for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||
| body += formatv( | |||
| " case {0}::{1}::{2}:\n", | |||
| getCppClassName(), enumAttr->getEnumName(), enumMember | |||
| ); | |||
| body += formatv( | |||
| " props_.emplace_back(\"{0}\", \"{1}\");\n", | |||
| it.name, enumMember | |||
| ); | |||
| body += " break;\n"; | |||
| } | |||
| body += " default: break;\n"; | |||
| body += " }\n"; | |||
| } else { | |||
| auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | |||
| body += formatv( | |||
| " props_.emplace_back(\"{0}\", {1});\n", it.name, | |||
| mlir::tblgen::tgfmt(attr.getReprFunctionTemplate(), | |||
| &ctx, "$_self." + it.name) | |||
| ); | |||
| } | |||
| } | |||
| } | |||
| body += " return props_;\n"; | |||
| return body; | |||
| } | |||
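| // Editorial sketch of the generated output (not part of the diff): for a | |||
| // hypothetical op `Pool` with an enum attribute `mode` {AVG, MAX} and a | |||
| // plain attribute `stride`, the emitted props body would read roughly: | |||
| // | |||
| //     std::vector<std::pair<const char*, std::string>> props_; | |||
| //     switch (op_.mode){ | |||
| //         case Pool::Mode::AVG: | |||
| //             props_.emplace_back("mode", "AVG"); | |||
| //             break; | |||
| //         case Pool::Mode::MAX: | |||
| //             props_.emplace_back("mode", "MAX"); | |||
| //             break; | |||
| //         default: break; | |||
| //     } | |||
| //     props_.emplace_back("stride", std::to_string(op_.stride)); | |||
| //     return props_; | |||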
| public: | |||
| static bool classof(const Operator* op) { | |||
| return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
| @@ -222,7 +258,13 @@ public: | |||
| } | |||
| return getDefaultCmpFunction(); | |||
| } | |||
| std::string getPropsFunctionTemplate() const { | |||
| if (auto f = getDef().getValueAsOptionalString("propsFunction")) { | |||
| return f.getValue().str(); | |||
| } | |||
| return getDefaultPropsFunction(); | |||
| } | |||
| }; | |||
| } // namespace tblgen | |||
| } // namespace mlir | |||
| @@ -30,6 +30,7 @@ class MgbHashableAttrMixin { | |||
| string hashFunction = "mgb::hash($0)"; | |||
| // return 0 for eq, else for ne | |||
| string cmpFunction = "$0 != $1"; | |||
| string reprFunction = "std::to_string($0)"; | |||
| } | |||
| class MgbEnumAttrMixin<string namespace, string name, list<string> members> { | |||
| @@ -98,6 +99,7 @@ def MgbStringAttr : HashableAttr<"std::string"> { | |||
| let storageType = "::mlir::StringAttr"; | |||
| let convertFromStorage = "$_self.getValue().str()"; | |||
| let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor | |||
| string reprFunction = "$0"; | |||
| } | |||
| class MgbArrayAttr<MgbAttrWrapper elem>: | |||
| @@ -123,6 +125,7 @@ class MgbArrayAttr<MgbAttrWrapper elem>: | |||
| " });\n" | |||
| " return $_builder.getArrayAttr(ret" # recursionDepth # ");" | |||
| "}()"; | |||
| let reprFunction = "\"{std::vector}\""; | |||
| } | |||
| defvar EmptyStrList = !listsplat("", 0); | |||
| @@ -168,6 +171,7 @@ class MgbEnumAttr<string namespace, string enumName, list<string> members>: | |||
| let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; | |||
| let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))"; | |||
| let hashFunction = "mgb::enumhash()($0)"; | |||
| string reprFunction = "std::to_string((int)$0)"; | |||
| } | |||
| class MgbEnumAliasAttr<string namespace, string enumName, MgbEnumAttr base>: | |||
| @@ -179,12 +183,14 @@ def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { | |||
| let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))"; | |||
| let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0.enumv()))"; | |||
| let hashFunction = "mgb::hash($0.handle())"; | |||
| let reprFunction = "$0.name()"; | |||
| } | |||
| def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> { | |||
| let storageType = "::mlir::StringAttr"; | |||
| let convertFromStorage = underlyingType # "::load($_self.getValue().str())"; | |||
| let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())"; | |||
| string reprFunction = "$0.to_string()"; | |||
| } | |||
| def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { | |||
| @@ -209,6 +215,7 @@ def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { | |||
| " }\n" | |||
| " return $_builder.getArrayAttr(ret);" | |||
| "}()"; | |||
| let reprFunction = "$0.to_string()"; | |||
| } | |||
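| // Editorial note (not part of the diff): inside the generated props() body, | |||
| // each reprFunction template is instantiated with $0 bound to "$_self.<attr>" | |||
| // and $_self later bound to "op_", e.g. for a DType attribute `dtype` and a | |||
| // CompNode attribute `comp_node`: | |||
| // | |||
| //     props_.emplace_back("dtype", op_.dtype.name()); | |||
| //     props_.emplace_back("comp_node", op_.comp_node.to_string()); | |||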
| class MgbDefaultValuedAttr<MgbAttrWrapper attr, string value>: | |||