GitOrigin-RevId: bbe3ae3fa3
tags/v1.2.0
| @@ -569,3 +569,9 @@ class AttrOutputNode(OpNode): | |||||
| def reset(self): | def reset(self): | ||||
| self._rendezvous.reset() | self._rendezvous.reset() | ||||
| class VirtualDepNode(OpNode): | |||||
| def __init__(self, vars, device=""): | |||||
| out = _imperative_rt.virtual_dep(_unwrap(vars), device) | |||||
| super().__init__(out) | |||||
| @@ -25,7 +25,6 @@ from ..core._imperative_rt.ops import ( | |||||
| RemoteRecv, | RemoteRecv, | ||||
| RemoteSend, | RemoteSend, | ||||
| UniformRNG, | UniformRNG, | ||||
| VirtualDep, | |||||
| ) | ) | ||||
| from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
| from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
| @@ -548,9 +547,10 @@ class trace: | |||||
| need_reset_nodes.append(opnode) | need_reset_nodes.append(opnode) | ||||
| info.varnode, *in_out_links = opnode.outputs | info.varnode, *in_out_links = opnode.outputs | ||||
| if require_links and i == 0 and len(io_links) > 0: | if require_links and i == 0 and len(io_links) > 0: | ||||
| info.varnode = apply( | |||||
| VirtualDep(str(io_links[0].device)), info.varnode, *io_links | |||||
| )[0] | |||||
| opnode = G.VirtualDepNode( | |||||
| [info.varnode, *io_links], str(io_links[0].device) | |||||
| ) | |||||
| info.varnode = opnode.outputs[0] | |||||
| io_links = (info.varnode,) | io_links = (info.varnode,) | ||||
| ivars.append(info.varnode) | ivars.append(info.varnode) | ||||
| @@ -1112,11 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||||
| if require_links and active_trace._lazy_eval_links: | if require_links and active_trace._lazy_eval_links: | ||||
| assert len(ivars) > 0, "op should has at least one input" | assert len(ivars) > 0, "op should has at least one input" | ||||
| ivars[0] = apply( | |||||
| VirtualDep(str(active_trace._lazy_eval_links[0].device)), | |||||
| ivars[0], | |||||
| *active_trace._lazy_eval_links, | |||||
| )[0] | |||||
| opnode = G.VirtualDepNode( | |||||
| [ivars[0], *active_trace._lazy_eval_links], | |||||
| str(active_trace._lazy_eval_links[0].device), | |||||
| ) | |||||
| ivars[0] = opnode.outputs[0] | |||||
| active_trace._lazy_eval_links = (ivars[0],) | active_trace._lazy_eval_links = (ivars[0],) | ||||
| ovars = apply(op, *ivars) | ovars = apply(op, *ivars) | ||||
| @@ -15,6 +15,7 @@ | |||||
| #include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
| #include "megbrain/imperative/opr_utility.h" | #include "megbrain/imperative/opr_utility.h" | ||||
| #include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
| #include "megbrain/opr/utility.h" | |||||
| #include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
| #include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
| #include "./helper.h" | #include "./helper.h" | ||||
| @@ -562,4 +563,16 @@ void init_graph_rt(py::module m) { | |||||
| }; | }; | ||||
| return output_callback(std::move(f), std::move(inputs), p, true); | return output_callback(std::move(f), std::move(inputs), p, true); | ||||
| }); | }); | ||||
| m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) { | |||||
| auto&& graph = inputs[0]->owner_graph(); | |||||
| VarNodeArray inps(inputs.begin(), inputs.end()); | |||||
| cg::OperatorNodeConfig config; | |||||
| if (device.length() > 0) { | |||||
| config.comp_node(CompNode::load(device)); | |||||
| } | |||||
| cg::OperatorNodeBase* opr = graph->insert_opr( | |||||
| std::make_unique<mgb::opr::VirtualDep>(inps, config)); | |||||
| return opr; | |||||
| }); | |||||
| } | } | ||||
| @@ -10,12 +10,10 @@ | |||||
| */ | */ | ||||
| #include "./ops.h" | #include "./ops.h" | ||||
| #include <string> | |||||
| #include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
| #include "megbrain/imperative/ops/backward_graph.h" | #include "megbrain/imperative/ops/backward_graph.h" | ||||
| #include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
| #include "megbrain/imperative/ops/utility.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| @@ -45,9 +43,5 @@ void init_ops(py::module m) { | |||||
| return self.graph().interpret<py::object>(f, c, inputs); | return self.graph().interpret<py::object>(f, c, inputs); | ||||
| }); | }); | ||||
| py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") | |||||
| .def(py::init<>()) | |||||
| .def(py::init<std::string>()); | |||||
| #include "opdef.py.inl" | #include "opdef.py.inl" | ||||
| } | } | ||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/ops/utility.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megbrain/imperative/ops/utility.h" | |||||
| #include <string> | |||||
| #include "megbrain/comp_node.h" | |||||
| #include "megbrain/imperative/ops/opr_attr.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| #include "../op_trait.h" | |||||
| namespace mgb::imperative { | |||||
| namespace { | |||||
| cg::OperatorNodeBase* virtual_dep_apply_on_var_node( | |||||
| const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& graph = inputs[0]->owner_graph(); | |||||
| auto&& op = def.cast_final_safe<VirtualDep>(); | |||||
| VarNodeArray inps(inputs.begin(), inputs.end()); | |||||
| cg::OperatorNodeConfig config; | |||||
| if (op.device.length() > 0) { | |||||
| config.comp_node(CompNode::load(op.device)); | |||||
| } | |||||
| cg::OperatorNodeBase* opr = | |||||
| graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>( | |||||
| inps, config)); | |||||
| return opr; | |||||
| } | |||||
| OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep) | |||||
| .apply_on_var_node(virtual_dep_apply_on_var_node) | |||||
| .fallback(); | |||||
| } // namespace | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep); | |||||
| } // namespace mgb::imperative | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * \file imperative/src/include/megbrain/imperative/ops/utility.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include <string> | |||||
| #include "megbrain/graph/operator_node.h" | |||||
| #include "megbrain/imperative/op_def.h" | |||||
| #include "megbrain/utils/hash.h" | |||||
| namespace mgb::imperative { | |||||
| class VirtualDep : public OpDefImplBase<VirtualDep> { | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
| public: | |||||
| VirtualDep() = default; | |||||
| VirtualDep(std::string dev) : device(dev) {} | |||||
| std::string device; | |||||
| size_t hash() const override { | |||||
| return reinterpret_cast<size_t>(dyn_typeinfo()); | |||||
| } | |||||
| bool is_same_st(const Hashable& rhs) const override { | |||||
| return true; | |||||
| } | |||||
| }; | |||||
| } // namespace mgb::imperative | |||||