GitOrigin-RevId: 5d47ed263f
tags/v1.0.0
| @@ -483,13 +483,13 @@ void init_graph_rt(py::module m) { | |||||
| }, | }, | ||||
| py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | ||||
| auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) { | |||||
| auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false, bool prefer_host_value = false) { | |||||
| SymbolVarArray sinputs; | SymbolVarArray sinputs; | ||||
| for (auto i : inputs) { | for (auto i : inputs) { | ||||
| sinputs.emplace_back(i); | sinputs.emplace_back(i); | ||||
| } | } | ||||
| static_assert(!std::is_reference<decltype(callback)>::value); | static_assert(!std::is_reference<decltype(callback)>::value); | ||||
| opr::OutputCallback::Param param{std::move(callback), borrow}; | |||||
| opr::OutputCallback::Param param{std::move(callback), borrow, prefer_host_value}; | |||||
| auto output = opr::OutputCallback::make(std::move(param), sinputs); | auto output = opr::OutputCallback::make(std::move(param), sinputs); | ||||
| return output.node(); | return output.node(); | ||||
| }; | }; | ||||
| @@ -519,7 +519,7 @@ void init_graph_rt(py::module m) { | |||||
| hv_with_event.second->record(); | hv_with_event.second->record(); | ||||
| p->set(std::move(hv_with_event)); | p->set(std::move(hv_with_event)); | ||||
| }; | }; | ||||
| return output_callback(std::move(f), std::move(inputs), true); | |||||
| return output_callback(std::move(f), std::move(inputs), true, true); | |||||
| }); | }); | ||||
| m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) { | m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) { | ||||
| @@ -144,13 +144,24 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const { | |||||
| prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP); | prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP); | ||||
| SmallVector<NodeProp::DepType> dep_types(input().size(), | SmallVector<NodeProp::DepType> dep_types(input().size(), | ||||
| NodeProp::DepType::DEV_COMP_ORDER); | NodeProp::DepType::DEV_COMP_ORDER); | ||||
| dep_types[0] = NodeProp::DepType::DEV_VALUE; | |||||
| using IT = cg::static_infer::InferType; | |||||
| auto host_value_avail = [&]() -> bool { | |||||
| auto inp = input(0); | |||||
| auto it = owner_graph()->static_infer_manager().get_infer_type(inp).value; | |||||
| return it & (IT::CONST | IT::RT_STATIC | IT::MISSING_INP); | |||||
| }; | |||||
| m_use_host_value = m_param.prefer_host_value && host_value_avail(); | |||||
| dep_types[0] = m_use_host_value ? NodeProp::DepType::HOST_VALUE : NodeProp::DepType::DEV_VALUE; | |||||
| prop->reset_dep_type(input(), dep_types); | prop->reset_dep_type(input(), dep_types); | ||||
| return prop; | return prop; | ||||
| } | } | ||||
| void OutputCallback::scn_do_execute() { | void OutputCallback::scn_do_execute() { | ||||
| m_param.callback(input(0)->dev_tensor()); | |||||
| if (m_use_host_value) { | |||||
| m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); | |||||
| } else { | |||||
| m_param.callback(input(0)->dev_tensor()); | |||||
| } | |||||
| } | } | ||||
| cg::OperatorNodeBase* OutputCallback::shallow_copy( | cg::OperatorNodeBase* OutputCallback::shallow_copy( | ||||
| @@ -60,7 +60,8 @@ public: | |||||
| using callback_t = thin_function<void(DeviceTensorND)>; | using callback_t = thin_function<void(DeviceTensorND)>; | ||||
| struct Param { | struct Param { | ||||
| callback_t callback; | callback_t callback; | ||||
| bool borrow = false; | |||||
| bool borrow = false; // do not obtain shared ownership on DeviceTensorND | |||||
| bool prefer_host_value = false; // use host value when possible | |||||
| }; | }; | ||||
| OutputCallback(Param param, | OutputCallback(Param param, | ||||
| const VarNodeArray& inputs, | const VarNodeArray& inputs, | ||||
| @@ -81,6 +82,7 @@ protected: | |||||
| NodeProp* do_make_node_prop() const override; | NodeProp* do_make_node_prop() const override; | ||||
| private: | private: | ||||
| Param m_param; | Param m_param; | ||||
| mutable bool m_use_host_value; | |||||
| }; | }; | ||||
| MGB_DEFINE_OPR_CLASS(NopCallback, cg::OperatorNodeBase) // { | MGB_DEFINE_OPR_CLASS(NopCallback, cg::OperatorNodeBase) // { | ||||
| @@ -13,6 +13,7 @@ | |||||
| #include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
| #include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
| using namespace mgb; | using namespace mgb; | ||||
| @@ -50,6 +51,27 @@ TEST(TestOprUtility, OutputCallback) { | |||||
| MGB_ASSERT_TENSOR_EQ(hy, *hx); | MGB_ASSERT_TENSOR_EQ(hy, *hx); | ||||
| } | } | ||||
| TEST(TestOprUtility, OutputCallbackPreferHost) { | |||||
| HostTensorGenerator<> gen; | |||||
| auto hx = gen({2, 3}); | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto x = opr::Host2DeviceCopy::make(*graph, hx); | |||||
| x = opr::GetVarShape::make(x); | |||||
| HostTensorND hy; | |||||
| auto callback = [&hy](DeviceTensorND dv) {hy.copy_from(dv);}; | |||||
| opr::OutputCallback::Param param{callback}; | |||||
| param.prefer_host_value = true; | |||||
| auto dummy = opr::OutputCallback::make(param, x); | |||||
| auto y = opr::VirtualDep::make({x, dummy}); | |||||
| ComputingGraph::OutputSpec outspec{{y, [](DeviceTensorND&){}}}; | |||||
| auto func = graph->compile(outspec); | |||||
| func->execute(); | |||||
| ASSERT_TRUE(hy.comp_node() == CompNode::default_cpu()); | |||||
| ASSERT_EQ(hy.ptr<int>()[0], 2); | |||||
| ASSERT_EQ(hy.ptr<int>()[1], 3); | |||||
| } | |||||
| TEST(TestOprUtility, NopCallback) { | TEST(TestOprUtility, NopCallback) { | ||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| auto hx = gen({2, 3}); | auto hx = gen({2, 3}); | ||||