| @@ -309,7 +309,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| inputs.push_back(args[i]->shared_from_this()); | |||
| } | |||
| auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) { | |||
| auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs, size_t) { | |||
| return apply(op, std::move(inputs)); | |||
| }; | |||
| return graph.apply(inputs, apply_functor, &make_const); | |||
| @@ -317,7 +317,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||
| template <typename T> | |||
| auto apply(Subgraph graph, T&& tensors) | |||
| -> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>, | |||
| -> std::enable_if_t<std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, | |||
| apply_result_t> { | |||
| size_t nargs = tensors.size(); | |||
| Tensor* args[nargs]; | |||
| @@ -0,0 +1,105 @@ | |||
| /** | |||
| * \file imperative/src/impl/subgraph.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/imperative/subgraph.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| void Subgraph::remove_unused_exprs() { | |||
| std::unordered_set<size_t> required_vars = {outputs.begin(), outputs.end()}; | |||
| required_vars.erase(0); | |||
| for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) { | |||
| auto& expr = *iter; | |||
| bool required = false; | |||
| for (auto output : expr.outputs) { | |||
| if (required_vars.count(output)) { | |||
| required = true; | |||
| break; | |||
| } | |||
| } | |||
| if (required) { | |||
| required_vars.insert(expr.inputs.begin(), expr.inputs.end()); | |||
| } else { | |||
| expr.op = nullptr; | |||
| } | |||
| } | |||
| exprs.erase(std::remove_if(exprs.begin(), exprs.end(), | |||
| [](auto expr) { return expr.op == nullptr; }), | |||
| exprs.end()); | |||
| } | |||
| SmallVector<bool> Subgraph::gen_input_mask() { | |||
| std::unordered_set<size_t> unused_inputs = {inputs.begin(), inputs.end()}; | |||
| for (auto&& expr : exprs) { | |||
| for (auto&& input : expr.inputs) { | |||
| unused_inputs.erase(input); | |||
| } | |||
| } | |||
| for (auto&& output : outputs) { | |||
| unused_inputs.erase(output); | |||
| } | |||
| unused_inputs.insert(0); | |||
| SmallVector<bool> mask(inputs.size(), true); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (unused_inputs.count(inputs[i])) { | |||
| mask[i] = false; | |||
| } | |||
| } | |||
| return mask; | |||
| } | |||
| SmallVector<bool> Subgraph::gen_output_mask() { | |||
| std::unordered_set<size_t> invalid_outputs = {outputs.begin(), | |||
| outputs.end()}; | |||
| for (auto&& input : inputs) { | |||
| invalid_outputs.erase(input); | |||
| } | |||
| for (auto&& expr : exprs) { | |||
| for (auto&& output : expr.outputs) { | |||
| invalid_outputs.erase(output); | |||
| } | |||
| } | |||
| for (auto&& constant: constants) { | |||
| invalid_outputs.erase(constant.first); | |||
| } | |||
| invalid_outputs.insert(0); | |||
| SmallVector<bool> mask(outputs.size(), true); | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| if (invalid_outputs.count(outputs[i])) { | |||
| mask[i] = false; | |||
| } | |||
| } | |||
| return mask; | |||
| } | |||
| void Subgraph::replace_vars( | |||
| const std::unordered_map<size_t, size_t>& replace_map) { | |||
| // FIXME: preprocess replace_map | |||
| auto replace_var = [&](var_t& var) { | |||
| // TODO: detect infinite loop | |||
| while (replace_map.count(var)) { | |||
| var = replace_map.at(var); | |||
| } | |||
| }; | |||
| for (auto& expr : exprs) { | |||
| for (auto& input : expr.inputs) { | |||
| replace_var(input); | |||
| } | |||
| } | |||
| for (auto& output : outputs) { | |||
| replace_var(output); | |||
| } | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/graph.h" | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| #include "megbrain/imperative/subgraph.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -28,54 +29,6 @@ enum DispatchMode { | |||
| using SharedOp = std::shared_ptr<OpDef>; | |||
// One operation node: applying `op` to the variables named by `inputs`
// defines the variables named by `outputs`. T is the variable-id type.
// NOTE(review): this copy is the removed side of the diff; the kept
// definition lives in megbrain/imperative/subgraph.h.
template <typename T>
struct Expr {
    std::shared_ptr<OpDef> op;
    SmallVector<T> inputs;   // ids of consumed variables
    SmallVector<T> outputs;  // ids of produced variables
};
// A small dataflow program over variable ids, interpreted by apply().
// NOTE(review): this copy is the removed side of the diff; the kept (and
// extended) definition lives in megbrain/imperative/subgraph.h.
struct Subgraph {
    SmallVector<size_t> inputs;                           // ids bound to caller-supplied values
    SmallVector<std::pair<size_t, TensorPtr>> constants;  // id -> constant tensor
    SmallVector<size_t> outputs;                          // ids returned by apply()
    SmallVector<Expr<size_t>> exprs;                      // operations, executed in order
    // Interpret the graph over value type T.
    //   input_vars: one T per entry of `inputs` (sizes asserted equal).
    //   f: (op, SmallVector<T>) -> SmallVector<T>, invoked once per expr;
    //      must return exactly expr.outputs.size() values (asserted).
    //   c: (TensorPtr) -> T, lifts a constant tensor into a T.
    // Returns one T per entry of `outputs`.
    template <typename T, typename F, typename C>
    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
        std::unordered_map<size_t, T> idx2var;  // var id -> computed value
        mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
        for (size_t i = 0; i < inputs.size(); ++i) {
            idx2var[inputs[i]] = input_vars[i];
        }
        for (auto&& [idx, val]: constants) {
            idx2var[idx] = c(val);
        }
        for (auto& expr: exprs) {
            // Gather already-computed values for this expr's inputs.
            SmallVector<T> expr_inputs;
            for (auto idx: expr.inputs) {
                expr_inputs.push_back(idx2var[idx]);
            }
            SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs));
            mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch");
            for (size_t i = 0; i < expr_outputs.size(); ++i) {
                idx2var[expr.outputs[i]] = expr_outputs[i];
            }
        }
        SmallVector<T> output_vars;
        for (auto idx: outputs) {
            output_vars.push_back(idx2var[idx]);
        }
        return output_vars;
    }
    // A subgraph with no outputs is considered empty.
    bool empty() const {
        return outputs.size() == 0;
    }
    std::string repr() const;
};
| struct BackwardGraphResult { | |||
| Subgraph backward; | |||
| SmallVector<bool> save_for_backward; | |||
| @@ -0,0 +1,100 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/subgraph.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <list> | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| #include "megbrain/imperative/utils/to_string.h" | |||
| #include "megbrain/utils/small_vector.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| class OpDef; | |||
// One operation node of a Subgraph: applying `op` to the variables named
// by `inputs` defines the variables named by `outputs`.
// T is the variable-id type (Subgraph uses Expr<size_t> via expr_t).
template <typename T>
struct Expr {
    // Operator to apply; Subgraph::remove_unused_exprs nulls this out to
    // mark a dead expr before erasing it.
    std::shared_ptr<OpDef> op;
    SmallVector<T> inputs;   // ids of consumed variables
    SmallVector<T> outputs;  // ids of produced variables
};
| template <typename T> | |||
| struct ToStringTrait<Expr<T>> { | |||
| std::string operator()(const Expr<T>& expr) { | |||
| return ssprintf("%s = %s %s\n", to_string(expr.inputs).c_str(), to_string(expr.op.get()).c_str(), to_string(expr.outputs).c_str()); | |||
| } | |||
| }; | |||
// A small dataflow program over integer variable ids: `exprs` run in
// order, reading values bound to `inputs`/`constants` and producing the
// values named by `outputs`. Interpreted generically by apply().
struct Subgraph {
    template <typename TDesc>
    class Builder;
    // Var id type. Id 0 appears to act as a reserved/sentinel id (the .cpp
    // mask/DCE routines special-case it) — TODO(review): confirm.
    using var_t = size_t;
    using vars_t = SmallVector<size_t>;
    using op_t = std::shared_ptr<OpDef>;
    using expr_t = Expr<var_t>;
    template <typename TDesc>
    using builder_t = Builder<TDesc>;
    SmallVector<var_t> inputs;                            // ids bound to caller-supplied values
    SmallVector<std::pair<var_t, TensorPtr>> constants;   // id -> constant tensor
    SmallVector<var_t> outputs;                           // ids whose values apply() returns
    SmallVector<expr_t> exprs;                            // operations, executed in order
    // Interpret the graph over value type T.
    //   input_vars: one T per entry of `inputs` (sizes asserted equal).
    //   f: (op, SmallVector<T> inputs, size_t noutputs) -> SmallVector<T>;
    //      invoked once per expr, must return exactly `noutputs` values
    //      (asserted below).
    //   c: (TensorPtr) -> T, lifts a constant tensor into a T.
    // Returns one T per entry of `outputs`.
    template <typename T, typename F, typename C>
    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
        std::unordered_map<size_t, T> idx2var;  // var id -> computed value
        mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
        for (size_t i = 0; i < inputs.size(); ++i) {
            idx2var[inputs[i]] = input_vars[i];
        }
        for (auto&& [idx, val] : constants) {
            idx2var[idx] = c(val);
        }
        for (auto& expr : exprs) {
            // Gather already-computed values for this expr's inputs.
            // NOTE(review): an id missing from idx2var default-constructs a
            // T here rather than failing loudly — presumably exprs are
            // topologically ordered so this cannot happen; verify upstream.
            SmallVector<T> expr_inputs;
            for (auto idx : expr.inputs) {
                expr_inputs.push_back(idx2var[idx]);
            }
            SmallVector<T> expr_outputs =
                    f(expr.op, std::move(expr_inputs), expr.outputs.size());
            mgb_assert(expr_outputs.size() == expr.outputs.size(),
                       "output size mismatch");
            for (size_t i = 0; i < expr_outputs.size(); ++i) {
                idx2var[expr.outputs[i]] = expr_outputs[i];
            }
        }
        SmallVector<T> output_vars;
        for (auto idx : outputs) {
            output_vars.push_back(idx2var[idx]);
        }
        return output_vars;
    }
    // Drop exprs whose outputs never reach `outputs` (defined in the .cpp).
    void remove_unused_exprs();
    // Per-input "is actually consumed" flags (defined in the .cpp).
    SmallVector<bool> gen_input_mask();
    // Per-output "has a definition site" flags (defined in the .cpp).
    SmallVector<bool> gen_output_mask();
    // A subgraph with no outputs is considered empty.
    bool empty() const { return outputs.size() == 0; }
    // Rewrite var ids through replace_map, following chains (.cpp).
    void replace_vars(const std::unordered_map<size_t, size_t>& replace_map);
    std::string repr() const;
    bool is_single() const;
    std::shared_ptr<OpDef> as_single() const;
    bool operator==(const Subgraph& rhs) const;
};
| } // namespace imperative | |||
| } // namespace mgb | |||