GitOrigin-RevId: bb3a59380e
tags/v1.0.0-rc1
| @@ -12,6 +12,7 @@ import weakref | |||
| from concurrent.futures import Future, ThreadPoolExecutor | |||
| from .. import _imperative_rt | |||
| from .._imperative_rt.ops import BackwardGraph | |||
| from .._wrap import device as as_device | |||
| from ..ops.builtin import OpDef | |||
| from .core import OpBase, TensorBase, apply | |||
| @@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode): | |||
| return _wrap(outputs) | |||
| @apply.register() | |||
| def _(op: BackwardGraph, *args: VarNode): | |||
| assert args | |||
| graph = args[0].graph | |||
| return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args) | |||
| def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | |||
| outputs = _imperative_rt.input_callback( | |||
| callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | |||
| @@ -40,6 +40,18 @@ void init_ops(py::module m) { | |||
| attr.param.insert(attr.param.end(), s.begin(), s.end()); | |||
| }); | |||
| py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") | |||
| .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, | |||
| const mgb::SmallVector<py::object>& inputs) { | |||
| auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
| return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs)); | |||
| }; | |||
| auto c = [pyc](const TensorPtr& tensor) { | |||
| return pyc(tensor->dev_tensor()); | |||
| }; | |||
| return self.graph().interpret<py::object>(f, c, inputs); | |||
| }); | |||
| py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape") | |||
| .def(py::init()); | |||
| @@ -98,7 +110,6 @@ void init_ops(py::module m) { | |||
| .def(py::init<>()) | |||
| .def_readwrite("offsets", &ParamPackConcat::offsets); | |||
| py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph"); | |||
| py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") | |||
| .def(py::init<>()); | |||
| @@ -18,34 +18,10 @@ namespace imperative { | |||
| SmallVector<TensorPtr> | |||
| BackwardGraph::InternalGraph::apply( | |||
| const SmallVector<TensorPtr>& inputs) const { | |||
| ThinHashMap<size_t, TensorPtr> node2tensor; | |||
| auto&& input_nodes = this->inputs; | |||
| mgb_assert(inputs.size() == input_nodes.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| node2tensor[input_nodes[i]] = inputs[i]; | |||
| } | |||
| for (auto &&i : constants) { | |||
| node2tensor[i.first] = i.second; | |||
| } | |||
| for (size_t i = 0; i < exprs.size(); ++ i) { | |||
| auto&& expr = exprs[i]; | |||
| SmallVector<TensorPtr> inputs; | |||
| for (auto &&in : std::get<1>(expr)) { | |||
| inputs.push_back(node2tensor.at(in)); | |||
| } | |||
| auto outputs = OpDef::apply_on_physical_tensor( | |||
| *std::get<0>(expr), inputs); | |||
| auto output_nodes = std::get<2>(expr); | |||
| mgb_assert(outputs.size() == output_nodes.size()); | |||
| for (size_t i = 0; i < outputs.size(); ++ i) { | |||
| node2tensor[output_nodes[i]] = outputs[i]; | |||
| } | |||
| } | |||
| SmallVector<TensorPtr> ret; | |||
| for (auto &&i : outputs) { | |||
| ret.push_back(node2tensor.at(i)); | |||
| } | |||
| return ret; | |||
| return interpret<TensorPtr>( | |||
| &OpDef::apply_on_physical_tensor, | |||
| [](const TensorPtr& x) {return x;}, | |||
| inputs); | |||
| } | |||
| SmallVector<LogicalTensorDesc> | |||
| @@ -40,6 +40,37 @@ public: | |||
| SmallVector<LogicalTensorDesc> | |||
| infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const; | |||
| template <typename T, typename F, typename C> | |||
| SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const { | |||
| ThinHashMap<size_t, T> node2tensor; | |||
| auto&& input_nodes = this->inputs; | |||
| mgb_assert(inputs.size() == input_nodes.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| node2tensor[input_nodes[i]] = inputs[i]; | |||
| } | |||
| for (auto &&i : constants) { | |||
| node2tensor[i.first] = c(i.second); | |||
| } | |||
| for (size_t i = 0; i < exprs.size(); ++ i) { | |||
| auto&& expr = exprs[i]; | |||
| SmallVector<T> inputs; | |||
| for (auto &&in : std::get<1>(expr)) { | |||
| inputs.push_back(node2tensor.at(in)); | |||
| } | |||
| auto&& outputs = f(*std::get<0>(expr), std::move(inputs)); | |||
| auto&& output_nodes = std::get<2>(expr); | |||
| mgb_assert(outputs.size() == output_nodes.size()); | |||
| for (size_t i = 0; i < outputs.size(); ++ i) { | |||
| node2tensor[output_nodes[i]] = std::move(outputs[i]); | |||
| } | |||
| } | |||
| SmallVector<T> ret; | |||
| for (auto &&i : outputs) { | |||
| ret.push_back(node2tensor.at(i)); | |||
| } | |||
| return ret; | |||
| } | |||
| }; | |||
| const InternalGraph& graph() const { | |||