GitOrigin-RevId: dd82b53faf
tags/v1.5.0
| @@ -1,6 +1,4 @@ | |||||
| import weakref | import weakref | ||||
| from collections import defaultdict | |||||
| from contextlib import contextmanager | |||||
| from typing import Callable, Iterable | from typing import Callable, Iterable | ||||
| from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option | ||||
| @@ -1125,10 +1125,6 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||||
| def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): | def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): | ||||
| if skip_tracing: | if skip_tracing: | ||||
| args = [ | |||||
| RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||||
| for x in args | |||||
| ] | |||||
| unset_tracing() | unset_tracing() | ||||
| ret = RawTensor(value, dtype, device, False, name) | ret = RawTensor(value, dtype, device, False, name) | ||||
| set_tracing() | set_tracing() | ||||
| @@ -50,29 +50,20 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) | |||||
| #undef REGISTE_APPLY_FUNC | #undef REGISTE_APPLY_FUNC | ||||
| bool is_tracing = false; | |||||
| #define SET_UNSET_PROP(mode) \ | |||||
| void set_##mode() { \ | |||||
| is_##mode = true; \ | |||||
| } \ | |||||
| void unset_##mode() { \ | |||||
| is_##mode = false; \ | |||||
| } \ | |||||
| SET_UNSET_PROP(tracing) | |||||
| Tensor::flags_t ApplyContext::global_disable = 0; | |||||
| Tensor::flags_t ApplyContext::global_enable = 0; | |||||
| #undef SET_UNSET_PROP | |||||
| void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; } | |||||
| void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; } | |||||
| bool skip_tracing = false; | bool skip_tracing = false; | ||||
| Tensor::flags_t ApplyContext::global_disable = 0; | |||||
| apply_result_t apply(ApplyContext& ctx) { | apply_result_t apply(ApplyContext& ctx) { | ||||
| // emulating scalar should be put to specific op's apply, e.g., | // emulating scalar should be put to specific op's apply, e.g., | ||||
| // elementwise, reduce, typecvt. Currently it's still handled at python | // elementwise, reduce, typecvt. Currently it's still handled at python | ||||
| // side. It could be move to C++ side if it has an impact on performance | // side. It could be move to C++ side if it has an impact on performance | ||||
| auto flags = ctx.flags & ~ApplyContext::global_disable; | auto flags = ctx.flags & ~ApplyContext::global_disable; | ||||
| flags = flags | ApplyContext::global_enable; | |||||
| if (flags & Tensor::Flags::SCALAR) { | if (flags & Tensor::Flags::SCALAR) { | ||||
| // TODO: emulate scalar | // TODO: emulate scalar | ||||
| @@ -190,10 +181,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| } | } | ||||
| } | } | ||||
| if (is_tracing) { | |||||
| ctx.flags |= Tensor::Flags::TRACE; | |||||
| } | |||||
| auto outputs = apply(ctx); | auto outputs = apply(ctx); | ||||
| size_t nout = outputs.size(); | size_t nout = outputs.size(); | ||||
| auto ret = py::tuple(nout); | auto ret = py::tuple(nout); | ||||
| @@ -255,7 +242,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||||
| if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>(); | if (tup[nargs - 1].ptr() != Py_None) name = tup[nargs - 1].cast<std::string>(); | ||||
| // const op | // const op | ||||
| if (is_const && is_tracing) { | |||||
| if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) { | |||||
| auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); | auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr); | ||||
| if (!py_ret) throw py::error_already_set(); | if (!py_ret) throw py::error_already_set(); | ||||
| auto py_list = py::reinterpret_steal<py::list>(py_ret); | auto py_list = py::reinterpret_steal<py::list>(py_ret); | ||||
| @@ -193,8 +193,9 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
| struct ApplyContext { | struct ApplyContext { | ||||
| static Tensor::flags_t global_disable; | static Tensor::flags_t global_disable; | ||||
| static Tensor::flags_t global_enable; | |||||
| Tensor::flags_t flags; | |||||
| Tensor::flags_t flags = 0; | |||||
| std::shared_ptr<OpDef> op; | std::shared_ptr<OpDef> op; | ||||
| Tensor*const* args; | Tensor*const* args; | ||||
| size_t nargs; | size_t nargs; | ||||
| @@ -236,14 +237,11 @@ decltype(auto) resolve_arrow(T&& p) { | |||||
| template <typename... Args> | template <typename... Args> | ||||
| constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); | constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>); | ||||
| extern bool is_tracing; // FIXME: should use ApplyContext::global_enable | |||||
| template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> | template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0> | ||||
| apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | ||||
| ApplyContext ctx; | ApplyContext ctx; | ||||
| Tensor* arg_arr[] = {resolve_arrow(args)...}; | Tensor* arg_arr[] = {resolve_arrow(args)...}; | ||||
| ctx.flags = (0 | ... | args->m_flags); | ctx.flags = (0 | ... | args->m_flags); | ||||
| ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0; | |||||
| ctx.args = arg_arr; | ctx.args = arg_arr; | ||||
| ctx.nargs = sizeof...(args); | ctx.nargs = sizeof...(args); | ||||
| ctx.op = std::move(op); | ctx.op = std::move(op); | ||||
| @@ -256,7 +254,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||||
| apply_result_t> { | apply_result_t> { | ||||
| ApplyContext ctx; | ApplyContext ctx; | ||||
| ctx.op = std::move(op); | ctx.op = std::move(op); | ||||
| ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; | |||||
| ctx.nargs = tensors.size(); | ctx.nargs = tensors.size(); | ||||
| Tensor* args[ctx.nargs]; | Tensor* args[ctx.nargs]; | ||||
| ctx.args = args; | ctx.args = args; | ||||
| @@ -270,7 +267,6 @@ auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||||
| inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { | inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { | ||||
| ApplyContext ctx; | ApplyContext ctx; | ||||
| ctx.op = std::move(op); | ctx.op = std::move(op); | ||||
| ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; | |||||
| ctx.nargs = nargs; | ctx.nargs = nargs; | ||||
| ctx.args = args; | ctx.args = args; | ||||
| for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
| @@ -28,12 +28,12 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| for (size_t i = 0; i < ctx.nargs; i++) { | for (size_t i = 0; i < ctx.nargs; i++) { | ||||
| args[i + 1] = py::cast(ctx.args[i]->m_var); | args[i + 1] = py::cast(ctx.args[i]->m_var); | ||||
| } | } | ||||
| py::object ret = py::reinterpret_steal<py::object>( | |||||
| py::object pyout = py::reinterpret_steal<py::object>( | |||||
| PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); | PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr)); | ||||
| if (!ret) throw py::error_already_set(); | |||||
| if (!pyout) throw py::error_already_set(); | |||||
| // assumption: python function always returns PyList | // assumption: python function always returns PyList | ||||
| auto tup = py::reinterpret_borrow<py::list>(ret); | |||||
| auto tup = py::reinterpret_borrow<py::list>(pyout); | |||||
| for (size_t i = 0; i < tup.size(); i++) { | for (size_t i = 0; i < tup.size(); i++) { | ||||
| auto pitem = tup[i].cast<cg::VarNode*>(); | auto pitem = tup[i].cast<cg::VarNode*>(); | ||||
| outputs.emplace_back(std::make_shared<Tensor>(pitem)); | outputs.emplace_back(std::make_shared<Tensor>(pitem)); | ||||
| @@ -48,6 +48,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
| } | } | ||||
| auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr); | auto pyout = PyObject_Call(cpp_apply_with_tracing, args.ptr(), nullptr); | ||||
| if (!pyout) throw py::error_already_set(); | if (!pyout) throw py::error_already_set(); | ||||
| // assumption: python function always returns PyList | // assumption: python function always returns PyList | ||||
| auto tup = py::reinterpret_steal<py::list>(pyout); | auto tup = py::reinterpret_steal<py::list>(pyout); | ||||
| for (size_t i = 0; i < tup.size(); i++) { | for (size_t i = 0; i < tup.size(); i++) { | ||||