GitOrigin-RevId: 6dbcb67009
tags/v1.6.0
| @@ -11,17 +11,214 @@ | |||||
| */ | */ | ||||
| #include "megbrain/gopt/subgraph_extractor.h" | #include "megbrain/gopt/subgraph_extractor.h" | ||||
| #include <atomic> | |||||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace cg; | using namespace cg; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| /* ================== GraphPartition::InputPlaceholder =================*/ | |||||
| // clang-format off | |||||
| MGB_DEFINE_OPR_CLASS(GraphPartition::InputPlaceholder, | |||||
| cg::SingleCNOperatorNodeBase) // { | |||||
| public: | |||||
| InputPlaceholder(VarNode* src_var, const TensorShape& infer_shp, | |||||
| std::unique_ptr<HostTensorND> infer_val = nullptr); | |||||
| static SymbolVar make(VarNode* src_var, const TensorShape& infer_shp, | |||||
| std::unique_ptr<HostTensorND> infer_val = nullptr); | |||||
| size_t input_id() const { return m_id; } | |||||
| private: | |||||
| void init_output_static_infer_desc() override; | |||||
| void scn_do_execute() override; | |||||
| void init_output_comp_node() override; | |||||
| const size_t m_id; | |||||
| TensorShape m_infer_shp; | |||||
| std::unique_ptr<HostTensorND> m_infer_val; | |||||
| static std::atomic_size_t sm_id; | |||||
| }; | |||||
| // clang-format on | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(GraphPartition::InputPlaceholder); | |||||
| std::atomic_size_t GraphPartition::InputPlaceholder::sm_id{0}; | |||||
| GraphPartition::InputPlaceholder::InputPlaceholder( | |||||
| VarNode* src_var, const TensorShape& infer_shp, | |||||
| std::unique_ptr<HostTensorND> infer_val) | |||||
| : Super(src_var->owner_graph(), {}, {}, {}), | |||||
| m_id{sm_id.fetch_add(1, std::memory_order_relaxed)}, | |||||
| m_infer_shp{infer_shp}, | |||||
| m_infer_val{std::move(infer_val)} { | |||||
| name(ssprintf("InputPlaceholder@%zu", m_id)); | |||||
| add_equivalence_component<ScalarHash<DTypeEnum>>(src_var->dtype().enumv()); | |||||
| add_equivalence_component<ScalarHash<size_t>>(m_id); | |||||
| add_output(None)->dtype(src_var->dtype()); | |||||
| } | |||||
| void GraphPartition::InputPlaceholder::init_output_comp_node() { | |||||
| output(0)->comp_node(CompNode::default_cpu()); | |||||
| } | |||||
| void GraphPartition::InputPlaceholder::scn_do_execute() { | |||||
| mgb_throw(InternalError, "InputPlaceholder opr can not be executed"); | |||||
| } | |||||
| void GraphPartition::InputPlaceholder::init_output_static_infer_desc() { | |||||
| using namespace cg::static_infer; | |||||
| auto&& mgr = owner_graph()->static_infer_manager(); | |||||
| if (m_infer_shp.ndim == 0) { | |||||
| auto infer_shape = [](TensorShape&, const InpVal&) { return false; }; | |||||
| mgr.register_shape_infer(output(0), | |||||
| {SourceType::MUTABLE, {}, infer_shape}); | |||||
| } else { | |||||
| mgr.register_shape_infer(output(0), | |||||
| ShapeInferDesc::make_const(m_infer_shp)); | |||||
| } | |||||
| if (m_infer_val == nullptr) { | |||||
| auto infer_value = [](DeviceTensorND&, const InpVal&) { return false; }; | |||||
| mgr.register_value_infer(output(0), | |||||
| {SourceType::MUTABLE, {}, infer_value}); | |||||
| } else { | |||||
| auto infer_value = [this](DeviceTensorND& dest, const InpVal&) { | |||||
| dest.copy_from(*m_infer_val).sync(); | |||||
| return true; | |||||
| }; | |||||
| mgr.register_value_infer(output(0), | |||||
| {SourceType::CONSTANT, {}, infer_value}); | |||||
| } | |||||
| } | |||||
| SymbolVar GraphPartition::InputPlaceholder::make( | |||||
| VarNode* src_var, const TensorShape& infer_shp, | |||||
| std::unique_ptr<HostTensorND> infer_val) { | |||||
| return src_var->owner_graph() | |||||
| ->insert_opr(std::make_unique<InputPlaceholder>( | |||||
| src_var, infer_shp, std::move(infer_val))) | |||||
| ->output(0); | |||||
| } | |||||
| /* ================== GraphPartition =================*/ | |||||
#if MGB_ENABLE_JSON
/*!
 * \brief dump the placeholder version of this partition as json
 *
 * The result is an object with three entries: "operator" and "var" map the
 * id string of each visited operator / var to its own json dump, and
 * "comp_seq" lists operator id strings in the order DepOprIter visited them.
 */
std::shared_ptr<json::Value> GraphPartition::to_json() const {
    // only the replaced outputs are needed to walk the rebuilt graph
    const auto endpoints = replace_graph_by_placeholder().second;

    ThinHashSet<VarNode*> vars;
    ThinHashSet<OperatorNodeBase*> oprs;
    auto seq = json::Array::make();

    auto on_opr = [&](OperatorNodeBase* opr) {
        seq->add(json::String::make(opr->id_str()));
        oprs.insert(opr);
        for (auto* inp : opr->input())
            vars.insert(inp);
        for (auto* out : opr->output())
            vars.insert(out);
    };
    cg::DepOprIter iter{on_opr};
    for (auto* endpoint : endpoints)
        iter.add(endpoint->owner_opr());

    // id string -> json dump, for any collection of nodes
    auto dump_nodes = [](auto&& nodes) {
        auto result = json::Object::make();
        for (auto&& node : nodes)
            (*result)[node->id_str()] = node->to_json();
        return result;
    };
    return json::Object::make({{"operator", dump_nodes(oprs)},
                               {"var", dump_nodes(vars)},
                               {"comp_seq", seq}});
}
#endif
| std::pair<VarNodeArray, VarNodeArray> | |||||
| GraphPartition::replace_graph_by_placeholder() const { | |||||
| ThinHashMap<VarNode*, VarNode*> old2new; | |||||
| auto graph_partition_copy_opr_shallow = [](OperatorNodeBase* opr, | |||||
| const VarNodeArray& inps) { | |||||
| OperatorNodeConfig config = opr->config(); | |||||
| return serialization::copy_opr_shallow(*opr, inps, config)->output(0); | |||||
| }; | |||||
| OperatorNodeSet input_opr_set; | |||||
| for (const auto& i : m_inputs) | |||||
| input_opr_set.insert(i->owner_opr()); | |||||
| VarNodeArray placeholders; | |||||
| VarNodeArray replaced_outputs; | |||||
| VarNodeArray new_i; | |||||
| auto cb = [&](OperatorNodeBase* opr) { | |||||
| for (const auto& o : opr->output()) { | |||||
| if (o->contain_flag(VarNode::Flag::VOLATILE_CONTENT) || | |||||
| (input_opr_set.count(opr) && !m_inputs.count(o))) { | |||||
| continue; | |||||
| } | |||||
| VarNode* new_o; | |||||
| if (m_inputs.count(o)) { | |||||
| auto&& mgr = opr->owner_graph()->static_infer_manager(); | |||||
| const TensorShape* shp_ptr = nullptr; | |||||
| if (cg::is_static_var_shape(o)) { | |||||
| shp_ptr = mgr.infer_shape_fallible(o); | |||||
| } | |||||
| TensorShape infer_shp; | |||||
| if (shp_ptr) | |||||
| infer_shp = *shp_ptr; | |||||
| std::unique_ptr<HostTensorND> hval = nullptr; | |||||
| const DeviceTensorND* dval_ptr = nullptr; | |||||
| if (cg::is_static_var_value(o)) { | |||||
| dval_ptr = mgr.infer_value_fallible(o); | |||||
| } | |||||
| if (dval_ptr) { | |||||
| hval.reset(new HostTensorND(CompNode::default_cpu(), | |||||
| dval_ptr->dtype())); | |||||
| hval->resize(dval_ptr->shape()).copy_from(*dval_ptr).sync(); | |||||
| } | |||||
| new_o = InputPlaceholder::make(o, infer_shp, std::move(hval)) | |||||
| .node(); | |||||
| placeholders.push_back(new_o); | |||||
| } else { | |||||
| new_i.clear(); | |||||
| for (const auto& i : opr->input()) { | |||||
| new_i.push_back(old2new.at(i)); | |||||
| } | |||||
| new_o = graph_partition_copy_opr_shallow(o->owner_opr(), new_i); | |||||
| } | |||||
| old2new[o] = new_o; | |||||
| } | |||||
| }; | |||||
| cg::DepOprIter iter{cb}; | |||||
| for (auto&& i : m_inputs) { | |||||
| for (auto&& j : i->owner_opr()->input()) { | |||||
| if (!input_opr_set.count(j->owner_opr()) && | |||||
| !m_opr_set.count(j->owner_opr())) { | |||||
| iter.set_visited(j->owner_opr()); | |||||
| } | |||||
| } | |||||
| } | |||||
| for (auto&& o : m_outputs) | |||||
| iter.add(o->owner_opr()); | |||||
| for (auto&& o : m_outputs) { | |||||
| replaced_outputs.push_back(old2new.at(o)); | |||||
| } | |||||
| return std::make_pair(placeholders, replaced_outputs); | |||||
| } | |||||
| /* ================== SubGraphExtractor =================*/ | /* ================== SubGraphExtractor =================*/ | ||||
| std::vector<InternalGraph> SubGraphExtractor::extract( | |||||
| std::vector<GraphPartition> SubGraphExtractor::extract( | |||||
| const SymbolVarArray& endpoint_vars) const { | const SymbolVarArray& endpoint_vars) const { | ||||
| ThinHashMap<OperatorNodeBase*, std::pair<OperatorNodeBase*, int>> parent; | ThinHashMap<OperatorNodeBase*, std::pair<OperatorNodeBase*, int>> parent; | ||||
| thin_function<OperatorNodeBase*(OperatorNodeBase*)> union_find; | thin_function<OperatorNodeBase*(OperatorNodeBase*)> union_find; | ||||
| auto union_find = [&parent, &union_find](OperatorNodeBase* o) { | |||||
| union_find = [&parent, &union_find](OperatorNodeBase* o) { | |||||
| if (parent[o].first == o) | if (parent[o].first == o) | ||||
| return o; | return o; | ||||
| else { | else { | ||||
| @@ -34,7 +231,7 @@ std::vector<InternalGraph> SubGraphExtractor::extract( | |||||
| OperatorNodeBase* y) { | OperatorNodeBase* y) { | ||||
| auto root_x = union_find(x), root_y = union_find(y); | auto root_x = union_find(x), root_y = union_find(y); | ||||
| if (root_x != root_y) { | if (root_x != root_y) { | ||||
| OperatorNodeBase *large, small; | |||||
| OperatorNodeBase *large, *small; | |||||
| if (parent[root_x].second < parent[root_y].second) { | if (parent[root_x].second < parent[root_y].second) { | ||||
| small = root_x, large = root_y; | small = root_x, large = root_y; | ||||
| } else { | } else { | ||||
| @@ -42,25 +239,23 @@ std::vector<InternalGraph> SubGraphExtractor::extract( | |||||
| } | } | ||||
| parent[small].first = large; | parent[small].first = large; | ||||
| if (parent[large].second == parent[small].second) { | if (parent[large].second == parent[small].second) { | ||||
| parend[large].second += 1; | |||||
| parent[large].second += 1; | |||||
| } | } | ||||
| } | } | ||||
| }; | }; | ||||
| std::vector<OperatorNodeBase*> topo; | std::vector<OperatorNodeBase*> topo; | ||||
| auto cb = [&topo](OperatorNodeBase* opr) { | |||||
| auto cb = [this, &parent, &union_merge, &topo](OperatorNodeBase* opr) { | |||||
| topo.push_back(opr); | topo.push_back(opr); | ||||
| if (opr_list.count(opr->dyn_typeinfo()) == 0) | |||||
| if (m_opr_list.count(opr->dyn_typeinfo()) == 0) | |||||
| return; | return; | ||||
| auto find = parent.find(opr); | auto find = parent.find(opr); | ||||
| if (find == parent.end()) { | if (find == parent.end()) { | ||||
| auto insert = | |||||
| parent.insert(std::make_pair(opr, std::make_pair(opr, 0))); | |||||
| find = insert.first; | |||||
| parent.insert(std::make_pair(opr, std::make_pair(opr, 0))); | |||||
| } | } | ||||
| for (auto&& i : opr->input()) { | for (auto&& i : opr->input()) { | ||||
| auto&& o = i->owner_opr(); | auto&& o = i->owner_opr(); | ||||
| if (opr_list.count(o->dyn_typeinfo()) == 0) | |||||
| if (m_opr_list.count(o->dyn_typeinfo()) == 0) | |||||
| continue; | continue; | ||||
| union_merge(opr, o); | union_merge(opr, o); | ||||
| } | } | ||||
| @@ -69,33 +264,51 @@ std::vector<InternalGraph> SubGraphExtractor::extract( | |||||
| for (const auto& v : endpoint_vars) | for (const auto& v : endpoint_vars) | ||||
| iter.add(v.node()->owner_opr()); | iter.add(v.node()->owner_opr()); | ||||
| std::vector<InternalGraph> partitions; | |||||
| ThinHashMap<OperatorNodeBase*, InternalGraph*> roots; | |||||
| std::vector<GraphPartition> partitions; | |||||
| partitions.reserve(topo.size()); | |||||
| ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; | |||||
| for (const auto& opr : reverse_adaptor(topo)) { | for (const auto& opr : reverse_adaptor(topo)) { | ||||
| auto root = union_find(opr); | |||||
| auto find = roots.find(root); | |||||
| InternalGraph* internal_graph = nullptr; | |||||
| if (find == roots.end()) { | |||||
| partitions.emplace_back(InternalGraph{}); | |||||
| auto insert = | |||||
| roots.insert(std::make_pair(root, &partitions.back())); | |||||
| internal_graph = insert.first->second; | |||||
| internal_graph->m_outputs.insert(opr->output(0)); | |||||
| if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||||
| for (const auto& i : opr->input()) { | |||||
| if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||||
| auto root = union_find(i->owner_opr()); | |||||
| GraphPartition* partition; | |||||
| auto find = roots.find(root); | |||||
| if (find != roots.end()) { | |||||
| partition = find->second; | |||||
| partition->output().insert(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| internal_graph = find->second; | |||||
| auto erase = internal_graph->m_inputs.erase(opr->output(0)); | |||||
| if (erase > 0) { | |||||
| internal_graph->m_internals.insert(opr->output(0)); | |||||
| auto root = union_find(opr); | |||||
| auto find = roots.find(root); | |||||
| GraphPartition* partition = nullptr; | |||||
| if (find == roots.end()) { | |||||
| partitions.emplace_back(GraphPartition{}); | |||||
| auto insert = | |||||
| roots.insert(std::make_pair(root, &partitions.back())); | |||||
| partition = insert.first->second; | |||||
| for (auto&& o : opr->output()) { | |||||
| if (!o->contain_flag(cg::VarNode::Flag::VOLATILE_CONTENT)) | |||||
| partition->output().insert(o); | |||||
| } | |||||
| } else { | } else { | ||||
| internal_graph->m_outputs.insert(opr->output(0)); | |||||
| partition = find->second; | |||||
| for (auto&& o : opr->output()) { | |||||
| if (!o->contain_flag(cg::VarNode::Flag::VOLATILE_CONTENT)) { | |||||
| auto erase = partition->input().erase(o); | |||||
| if (erase == 0) | |||||
| partition->output().insert(o); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| partition->opr_set().insert(opr); | |||||
| for (const auto& i : opr->input()) | |||||
| partition->input().insert(i); | |||||
| } | } | ||||
| for (const auto& i : opr->input()) | |||||
| internal_graph->m_inputs.insert(i); | |||||
| } | } | ||||
| return partitions; | return partitions; | ||||
| } | } | ||||
| /* ============= SubGraphExtractor =================*/ | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -16,17 +16,37 @@ | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace gopt { | namespace gopt { | ||||
| struct InternalGraph { | |||||
| ThinHashSet<VarNode*> m_internals; | |||||
| ThinHashSet<VarNode*> m_inputs; | |||||
| ThinHashSet<VarNode*> m_outputs; | |||||
| class GraphPartition { | |||||
| public: | |||||
| using VarNodeSet = ThinHashSet<VarNode*>; | |||||
| using OperatorNodeSet = ThinHashSet<cg::OperatorNodeBase*>; | |||||
| class InputPlaceholder; | |||||
| GraphPartition() = default; | |||||
| #if MGB_ENABLE_JSON | |||||
| std::shared_ptr<json::Value> to_json() const; | |||||
| #endif | |||||
| const OperatorNodeSet& opr_set() const { return m_opr_set; } | |||||
| const VarNodeSet& input() const { return m_inputs; } | |||||
| const VarNodeSet& output() const { return m_outputs; } | |||||
| OperatorNodeSet& opr_set() { return m_opr_set; } | |||||
| VarNodeSet& input() { return m_inputs; } | |||||
| VarNodeSet& output() { return m_outputs; } | |||||
| private: | |||||
| OperatorNodeSet m_opr_set; | |||||
| VarNodeSet m_inputs; | |||||
| VarNodeSet m_outputs; | |||||
| std::pair<VarNodeArray, VarNodeArray> replace_graph_by_placeholder() const; | |||||
| }; | }; | ||||
| class SubGraphExtractor { | class SubGraphExtractor { | ||||
| public: | public: | ||||
| using OprList = ThinHashSet<Typeinfo*>; | using OprList = ThinHashSet<Typeinfo*>; | ||||
| SubGraphExtractor(OprList opr_list) : m_opr_list{opr_list} {}; | SubGraphExtractor(OprList opr_list) : m_opr_list{opr_list} {}; | ||||
| std::vector<InternalGraph> extract( | |||||
| std::vector<GraphPartition> extract( | |||||
| const SymbolVarArray& endpoint_vars) const; | const SymbolVarArray& endpoint_vars) const; | ||||
| private: | private: | ||||
| @@ -0,0 +1,275 @@ | |||||
| /** | |||||
| * \file src/gopt/test/subgraph_extractor.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "./helper.h" | |||||
| #include "megbrain/gopt/subgraph_extractor.h" | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/blas.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | |||||
| #include "megbrain/opr/imgproc.h" | |||||
| #include "megbrain/opr/internal/identical_fwd.h" | |||||
| #include "megbrain/opr/nn_int.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megbrain/serialization/serializer.h" | |||||
| using namespace mgb; | |||||
| using namespace gopt; | |||||
| using namespace serialization; | |||||
| namespace { | |||||
| // clang-format off | |||||
| MGB_DEFINE_OPR_CLASS(MultipleInputOutput, | |||||
| cg::SingleCNOperatorNodeBase) // { | |||||
| public: | |||||
| MultipleInputOutput(const VarNodeArray& inputs, const OperatorNodeConfig& config); | |||||
| static SymbolVarArray make(const SymbolVarArray& inputs, const OperatorNodeConfig& config = {}); | |||||
| private: | |||||
| void scn_do_execute() override { } | |||||
| void init_output_static_infer_desc() override { } | |||||
| }; | |||||
| // clang-format on | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(MultipleInputOutput); | |||||
| MultipleInputOutput::MultipleInputOutput(const VarNodeArray& inputs, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super(inputs[0]->owner_graph(), config, "multiple_input_output", | |||||
| inputs) { | |||||
| for (auto&& i : inputs) | |||||
| add_input({i}); | |||||
| if (inputs.size() == 1) { | |||||
| add_output(None); | |||||
| } else { | |||||
| for (size_t i = 0; i < inputs.size(); ++i) | |||||
| add_output(ssprintf("o%zu", i)); | |||||
| } | |||||
| cg::add_workspace_output(this); | |||||
| } | |||||
| SymbolVarArray MultipleInputOutput::make(const SymbolVarArray& inputs, | |||||
| const OperatorNodeConfig& config) { | |||||
| auto src = cg::to_var_node_array(inputs); | |||||
| auto multiple_io = std::make_unique<MultipleInputOutput>(src, config); | |||||
| auto ret = | |||||
| cg::to_symbol_var_array(src[0]->owner_graph() | |||||
| ->insert_opr(std::move(multiple_io)) | |||||
| ->output()); | |||||
| ret.pop_back(); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| TEST(TestSubGraphExtractor, MultipleOutputs) { | |||||
| HostTensorGenerator<> gen; | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||||
| }; | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name); | |||||
| }; | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto x = mkvar("x", {8, 8, 8, 8}), w1 = mkcvar("w1", {4, 8, 3, 3}); | |||||
| auto y = mkvar("y", {1, 8, 1, 1}); | |||||
| auto add = x + y; | |||||
| opr::Convolution::Param param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| auto c1 = opr::Convolution::make(add, w1, param); | |||||
| auto w2 = mkcvar("w2", {8, 4, 3, 3}); | |||||
| auto c2 = opr::ConvolutionBackwardData::make(w2, add, param, {}, {}); | |||||
| auto sym_var_arr = MultipleInputOutput::make({c1, c2}); | |||||
| auto z = sym_var_arr[1]; | |||||
| z = z + (-128); | |||||
| using OprList = SubGraphExtractor::OprList; | |||||
| static const OprList opr_list = { | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| opr::Elemwise::typeinfo(), | |||||
| opr::TypeCvt::typeinfo(), | |||||
| MultipleInputOutput::typeinfo(), | |||||
| }; | |||||
| SubGraphExtractor extractor(opr_list); | |||||
| auto partitions = extractor.extract({z}); | |||||
| ASSERT_EQ(partitions.size(), 1u); | |||||
| // outputs: sym_var_arr[0], z, add | |||||
| ASSERT_EQ(partitions[0].output().size(), 3u); | |||||
| ASSERT_TRUE(partitions[0].output().count(add.node()) > 0); | |||||
| ASSERT_TRUE(partitions[0].output().count(z.node()) > 0); | |||||
| ASSERT_TRUE(partitions[0].output().count(sym_var_arr[0].node()) > 0); | |||||
| ASSERT_TRUE(partitions[0].output().count(sym_var_arr[1].node()) == 0); | |||||
| // inputs: x, y, w1, c2, (-128) | |||||
| ASSERT_EQ(partitions[0].input().size(), 5u); | |||||
| ASSERT_TRUE(partitions[0].input().count(x.node()) > 0); | |||||
| ASSERT_TRUE(partitions[0].input().count(c2.node()) > 0); | |||||
| // opr: (x + y) conv1 multi_io, (z - 128) | |||||
| ASSERT_EQ(partitions[0].opr_set().size(), 4u); | |||||
| ASSERT_TRUE(partitions[0].opr_set().count(add.node()->owner_opr()) > 0); | |||||
| ASSERT_TRUE(partitions[0].opr_set().count(c1.node()->owner_opr()) > 0); | |||||
| ASSERT_TRUE(partitions[0].opr_set().count( | |||||
| sym_var_arr[0].node()->owner_opr()) > 0); | |||||
| ASSERT_TRUE(partitions[0].opr_set().count(z.node()->owner_opr()) > 0); | |||||
| } | |||||
| TEST(TestSubGraphExtractor, MultipleReaders) { | |||||
| HostTensorGenerator<> gen; | |||||
| auto graph = ComputingGraph::make(); | |||||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||||
| }; | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||||
| return opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name); | |||||
| }; | |||||
| graph->options().graph_opt_level = 0; | |||||
| auto x = mkvar("x", {8, 8, 8, 8}), w1 = mkcvar("w1", {4, 8, 3, 3}); | |||||
| auto y = mkvar("y", {1, 8, 1, 1}); | |||||
| auto add = x + y; | |||||
| opr::Convolution::Param param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| auto c1 = opr::Convolution::make(add, w1, param); | |||||
| auto w2 = mkcvar("w2", {8, 4, 3, 3}); | |||||
| auto c2 = opr::ConvolutionBackwardData::make(w2, add, param, {}, {}); | |||||
| auto z = c1 + c2; | |||||
| using OprList = SubGraphExtractor::OprList; | |||||
| static const OprList opr_list = { | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| opr::Elemwise::typeinfo(), | |||||
| opr::TypeCvt::typeinfo(), | |||||
| }; | |||||
| SubGraphExtractor extractor(opr_list); | |||||
| auto partitions = extractor.extract({z}); | |||||
| ASSERT_EQ(partitions.size(), 1u); | |||||
| ASSERT_EQ(partitions[0].output().size(), 2u); | |||||
| ASSERT_TRUE(partitions[0].output().count(add.node()) > 0); | |||||
| ASSERT_TRUE(partitions[0].output().count(z.node()) > 0); | |||||
| ASSERT_EQ(partitions[0].input().size(), 4u); | |||||
| ASSERT_TRUE(partitions[0].input().count(x.node()) > 0); | |||||
| partitions[0].to_json()->writeto_fpath( | |||||
| output_file("TestSubGraphExtractor.MultipleReaders.json")); | |||||
| } | |||||
| TEST(TestSubGraphExtractor, Complicated) { | |||||
| const size_t N = 16, C = 3, H = 768, W = 1280; | |||||
| HostTensorGenerator<dtype::Uint8> gen; | |||||
| auto graph = ComputingGraph::make(); | |||||
| /* h2d | |||||
| | | |||||
| v | |||||
| astype(f32) | |||||
| | | |||||
| add(-128) | |||||
| | | |||||
| v | |||||
| astype(q8) | |||||
| | | |||||
| v | |||||
| conv1 | |||||
| | | |||||
| v | |||||
| astype(u4) | |||||
| | | |||||
| / \ | |||||
| conv2 conv3 -> astype(q32) -> output | |||||
| \ / | |||||
| qadd | |||||
| | | |||||
| v | |||||
| astype(q8) | |||||
| / \ | |||||
| deconv conv4 | |||||
| \ / | |||||
| concat -> output */ | |||||
| auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W})); | |||||
| auto data = opr::TypeCvt::make(h2d, dtype::Float32()); | |||||
| auto sub_128 = data + (-128); | |||||
| auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f)); | |||||
| auto mkcvar = [&](const char* name, const TensorShape& shp, | |||||
| const DType& dtype) { | |||||
| return opr::TypeCvt::make( | |||||
| opr::SharedDeviceTensor::make(*graph, *gen(shp)).rename(name), | |||||
| dtype); | |||||
| }; | |||||
| auto w1 = mkcvar("w1", {16, 3, 3, 3}, dtype::QuantizedS8(1.f)); | |||||
| auto b1 = mkcvar("b1", {1, 16, 1, 1}, dtype::QuantizedS32(1.f)); | |||||
| opr::ConvBias::Param param; | |||||
| param.stride_h = param.stride_w = 2; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| auto conv1 = opr::ConvBias::make( | |||||
| x, w1, b1, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f))); | |||||
| conv1 = opr::TypeCvt::make( | |||||
| conv1, dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8))); | |||||
| auto w2 = mkcvar("w2", {16, 16, 3, 3}, dtype::QuantizedS4(1.f)); | |||||
| auto b2 = mkcvar("b2", {1, 16, 1, 1}, dtype::QuantizedS32(1.f)); | |||||
| auto conv2 = opr::ConvBias::make(conv1, w2, b2, param, {}, | |||||
| OperatorNodeConfig(dtype::Quantized4Asymm( | |||||
| 1.f, static_cast<uint8_t>(8)))); | |||||
| param.pad_h = param.pad_w = 0; | |||||
| auto w3 = mkcvar("w3", {16, 16, 1, 1}, dtype::QuantizedS4(1.f)); | |||||
| auto b3 = mkcvar("b3", {1, 16, 1, 1}, dtype::QuantizedS32(1.f)); | |||||
| auto conv3 = opr::ConvBias::make(conv1, w3, b3, param, {}, | |||||
| OperatorNodeConfig(dtype::Quantized4Asymm( | |||||
| 1.f, static_cast<uint8_t>(8)))); | |||||
| auto conv3f = opr::TypeCvt::make(conv3, dtype::Float32()); | |||||
| auto qadd = opr::ElemwiseMultiType::make( | |||||
| {conv2, conv3}, {opr::ElemwiseMultiType::Mode::QADD}, | |||||
| OperatorNodeConfig( | |||||
| dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8)))); | |||||
| auto q8 = opr::TypeCvt::make(qadd, dtype::QuantizedS8(1.f)); | |||||
| auto w4 = mkcvar("w4", {16, 16, 3, 3}, dtype::QuantizedS8(1.f)); | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| auto conv4 = opr::ConvBiasForward::make( | |||||
| q8, w4, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f))); | |||||
| conv4 = opr::TypeCvt::make(conv4, dtype::Float32()); | |||||
| opr::Convolution::Param conv_param; | |||||
| conv_param.stride_h = param.stride_w = 1; | |||||
| conv_param.pad_h = param.pad_w = 0; | |||||
| auto w5 = mkcvar("w4", {16, 16, 1, 1}, dtype::QuantizedS8(1.f)); | |||||
| auto deconv = opr::ConvolutionBackwardData::make( | |||||
| w5, q8, conv_param, {}, | |||||
| OperatorNodeConfig(dtype::QuantizedS8(1.f))); | |||||
| deconv = opr::TypeCvt::make(deconv, dtype::Float32()); | |||||
| auto z = opr::Concat::make({conv4, deconv}, 1); | |||||
| using OprList = SubGraphExtractor::OprList; | |||||
| static const OprList opr_list = { | |||||
| opr::ConvBiasForward::typeinfo(), | |||||
| opr::ConvolutionForward::typeinfo(), | |||||
| opr::ConvolutionBackwardData::typeinfo(), | |||||
| opr::ElemwiseMultiType::typeinfo(), | |||||
| opr::Elemwise::typeinfo(), | |||||
| opr::TypeCvt::typeinfo(), | |||||
| opr::PoolingForward::typeinfo(), | |||||
| opr::WarpPerspectiveForward::typeinfo(), | |||||
| }; | |||||
| SubGraphExtractor extractor(opr_list); | |||||
| auto partitions = extractor.extract({conv3f.node(), z.node()}); | |||||
| ASSERT_EQ(partitions.size(), 1u); | |||||
| const char* prefix = "TestSubGraphExtractor.Complicated"; | |||||
| partitions[0].to_json()->writeto_fpath( | |||||
| output_file(ssprintf("%s.json", prefix).c_str())); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||