| @@ -0,0 +1,134 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/graph_builder.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/subgraph.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| template <typename TDesc> | |||
| class Subgraph::Builder { | |||
| using graph_t = Subgraph; | |||
| using var_t = graph_t::var_t; | |||
| using vars_t = graph_t::vars_t; | |||
| using op_t = graph_t::op_t; | |||
| using expr_t = graph_t::expr_t; | |||
| using exprs_t = std::list<expr_t>; | |||
| using expr_iter_t = std::list<expr_t>::iterator; | |||
| using desc_t = TDesc; | |||
| using descs_t = SmallVector<TDesc>; | |||
| using infer_fn_t = std::function<descs_t(op_t, descs_t, size_t)>; | |||
| using encoded_graph_t = EncodedSubraph; | |||
| using var_map_t = std::unordered_map<var_t, var_t>; | |||
| vars_t m_inputs; | |||
| SmallVector<std::pair<var_t, TensorPtr>> m_constants; | |||
| vars_t m_outputs; | |||
| exprs_t m_exprs; | |||
| var_t m_last_var = 0; | |||
| std::unordered_map<var_t, TDesc> m_var2desc; | |||
| infer_fn_t m_infer_fn; | |||
| var_map_t m_var_replace_map; | |||
| private: | |||
| var_t next_var() { return ++m_last_var; } | |||
| public: | |||
| explicit Builder(std::function<descs_t(op_t, descs_t, size_t)> infer_function) | |||
| : m_infer_fn{infer_function} {} | |||
| vars_t write_expr(op_t op, vars_t inputs, size_t nr_outputs) { | |||
| return write_expr_before(m_exprs.end(), std::move(op), | |||
| std::move(inputs), std::move(nr_outputs)); | |||
| } | |||
| vars_t write_expr_before(expr_iter_t iter, op_t op, vars_t inputs, | |||
| size_t nr_outputs) { | |||
| vars_t outputs; | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| outputs.push_back(next_var()); | |||
| } | |||
| m_exprs.insert(iter, {op, inputs, outputs}); | |||
| descs_t input_descs = get_descs(inputs); | |||
| descs_t output_descs = m_infer_fn(op, input_descs, nr_outputs); | |||
| mgb_assert(output_descs.size() == nr_outputs, | |||
| "bad infer_function: output descs size mismatch"); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| m_var2desc[outputs[i]] = output_descs[i]; | |||
| } | |||
| return outputs; | |||
| } | |||
| var_t write_constant(TensorPtr constant, desc_t desc) { | |||
| var_t constant_var = next_var(); | |||
| m_constants.emplace_back(constant_var, constant); | |||
| m_var2desc[constant_var] = std::move(desc); | |||
| return constant_var; | |||
| } | |||
| var_t write_input(desc_t input_desc) { | |||
| var_t input = next_var(); | |||
| m_var2desc[input] = input_desc; | |||
| m_inputs.push_back(input); | |||
| return input; | |||
| } | |||
| vars_t write_inputs(descs_t input_descs) { | |||
| vars_t inputs; | |||
| for (auto&& input_desc: input_descs) { | |||
| inputs.push_back(write_input(input_desc)); | |||
| } | |||
| return inputs; | |||
| } | |||
| void add_output(var_t var) { m_outputs.push_back(var); } | |||
| void add_outputs(vars_t vars) { | |||
| m_outputs.insert(m_outputs.begin(), vars.begin(), vars.end()); | |||
| } | |||
| desc_t get_desc(var_t var) { return m_var2desc.at(var); } | |||
| descs_t get_descs(vars_t vars) { | |||
| descs_t descs; | |||
| for (auto&& var : vars) { | |||
| descs.push_back(get_desc(var)); | |||
| } | |||
| return descs; | |||
| } | |||
| encoded_graph_t encode() const { | |||
| graph_t graph{m_inputs, | |||
| m_constants, | |||
| m_outputs, | |||
| {m_exprs.begin(), m_exprs.end()}}; | |||
| graph.replace_vars(m_var_replace_map); | |||
| graph.remove_unused_exprs(); | |||
| return encoded_graph_t::make(std::move(graph)); | |||
| } | |||
| void replace_var(var_t old_var, var_t new_var) { | |||
| mgb_assert(!m_var_replace_map.count(old_var), | |||
| "var cannot be replaced twice"); | |||
| m_var_replace_map[old_var] = new_var; | |||
| } | |||
| template <typename TFunctor> | |||
| void iterate(TFunctor&& functor) { | |||
| for (expr_iter_t iter = m_exprs.begin(); iter != m_exprs.end(); | |||
| ++iter) { | |||
| functor(iter); | |||
| } | |||
| } | |||
| template <typename TFunctor> | |||
| void reverse_iterate(TFunctor&& functor) { | |||
| for (expr_iter_t iter = --m_exprs.end();; --iter) { | |||
| functor(iter); | |||
| if (iter == m_exprs.begin()) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| expr_iter_t begin() { return m_exprs.begin(); } | |||
| expr_iter_t end() { return m_exprs.end(); } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||