| @@ -20,6 +20,7 @@ | |||||
| #include "./proxy_graph_base.h" | #include "./proxy_graph_base.h" | ||||
| #include <optional> | #include <optional> | ||||
| #include "megbrain/opr/utility.h" | |||||
| #include "range/v3/all.hpp" | #include "range/v3/all.hpp" | ||||
| namespace mgb::imperative::proxy_graph { | namespace mgb::imperative::proxy_graph { | ||||
| @@ -83,7 +84,7 @@ TensorAdaptor(T&) -> TensorAdaptor<T, void>; | |||||
| template <typename T> | template <typename T> | ||||
| TensorAdaptor(T*) -> TensorAdaptor<T, void>; | TensorAdaptor(T*) -> TensorAdaptor<T, void>; | ||||
| SmallVector<Tensor*> to_raw_ptr_array( | |||||
| inline SmallVector<Tensor*> to_raw_ptr_array( | |||||
| const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) { | const SmallVector<TensorPtr>& inputs, bool ensure_storage = true) { | ||||
| SmallVector<Tensor*> ret; | SmallVector<Tensor*> ret; | ||||
| for (auto&& i : inputs) { | for (auto&& i : inputs) { | ||||
| @@ -243,6 +244,13 @@ public: | |||||
| vinputs[i] = opr_ref_keeper.back()->output(0); | vinputs[i] = opr_ref_keeper.back()->output(0); | ||||
| } | } | ||||
| auto ovars = OpDef::apply_on_var_node(opdef, vinputs); | auto ovars = OpDef::apply_on_var_node(opdef, vinputs); | ||||
| if (!m_opr) { | |||||
| // identity | |||||
| mgb_assert(vinputs.size() == 1 && ovars.size() == 1); | |||||
| mgb_assert(ovars[0] == vinputs[0]); | |||||
| auto&& input = vinputs[0]; | |||||
| ovars[0] = opr::Identity::make(input).node(); | |||||
| } | |||||
| mgb_assert(m_opr); | mgb_assert(m_opr); | ||||
| output_data.resize(m_opr->output().size()); | output_data.resize(m_opr->output().size()); | ||||
| for (auto* v : ovars) { | for (auto* v : ovars) { | ||||
| @@ -343,7 +351,6 @@ public: | |||||
| } else { | } else { | ||||
| mgb_assert(j < outputs.size()); | mgb_assert(j < outputs.size()); | ||||
| auto&& tensor = outputs[j]; | auto&& tensor = outputs[j]; | ||||
| auto&& layout = tensor->layout(); | |||||
| if (var->m_mem_plan.chunk().owner_var != var) { | if (var->m_mem_plan.chunk().owner_var != var) { | ||||
| tensor->assign_from_dev_tensor( | tensor->assign_from_dev_tensor( | ||||
| var->m_dev_tensor); // memory forwarding | var->m_dev_tensor); // memory forwarding | ||||
| @@ -613,6 +620,7 @@ class ExecMiniGraph : public ProxyGraph::MiniGraph { | |||||
| busy_oprs.pop_front(); | busy_oprs.pop_front(); | ||||
| return m_opr; | return m_opr; | ||||
| } | } | ||||
| mgb_assert(false); | |||||
| } | } | ||||
| template <bool in_use> | template <bool in_use> | ||||