| @@ -649,6 +649,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( | |||||
| add_pass<ReorderArithChainPass>(cv_type); | add_pass<ReorderArithChainPass>(cv_type); | ||||
| add_pass<FinalArithTransformPass>(); | add_pass<FinalArithTransformPass>(); | ||||
| add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
| add_pass<RemoveRedundantCopyPass>(); | |||||
| #if MGB_JIT | #if MGB_JIT | ||||
| bool need_jit = false; | bool need_jit = false; | ||||
| @@ -682,6 +682,69 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { | |||||
| MIDOUT_E | MIDOUT_E | ||||
| } | } | ||||
| /* ======================= RemoveRedundantCopyPass ====================== */ | |||||
| const char* RemoveRedundantCopyPass::name() const { | |||||
| return "remove_redundant_copy"; | |||||
| } | |||||
| bool RemoveRedundantCopyPass::should_remove(const CompNode& A, | |||||
| const CompNode& B) { | |||||
| //! if A and B has the same memnode and cpu <-> atlas/cpu <-> cuda, as only | |||||
| //! these two compnode support crosscncopy | |||||
| if (A.mem_node() == B.mem_node() || | |||||
| ((A.device_type() == CompNode::DeviceType::CPU || | |||||
| A.device_type() == CompNode::DeviceType::MULTITHREAD) && | |||||
| (B.device_type() == CompNode::DeviceType::ATLAS || | |||||
| B.device_type() == CompNode::DeviceType::CUDA)) || | |||||
| ((B.device_type() == CompNode::DeviceType::CPU || | |||||
| B.device_type() == CompNode::DeviceType::MULTITHREAD) && | |||||
| (A.device_type() == CompNode::DeviceType::ATLAS || | |||||
| A.device_type() == CompNode::DeviceType::CUDA))) { | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| void RemoveRedundantCopyPass::apply(OptState& opt) const { | |||||
| MIDOUT_B("RemoveRedundantCopyPass::apply") | |||||
| auto rewriter = opt.graph().make_rewriter(); | |||||
| auto on_opr = [&](OperatorNodeBase* opr) { | |||||
| if (auto copy0 = try_cast_as_op<opr::Copy>(opr)) { | |||||
| auto inp0 = rewriter.get_var(copy0->input(0)); | |||||
| if (auto copy1= try_cast_as_op<opr::Copy>(inp0)) { | |||||
| auto inp1 = copy1->input(0); | |||||
| if (should_remove(inp1->comp_node(), | |||||
| copy0->output(0)->comp_node())) { | |||||
| mgb_assert(!rewriter.has_manual_replace(inp1)); | |||||
| if (inp1->comp_node() == copy0->output(0)->comp_node()) { | |||||
| rewriter.replace_var( | |||||
| copy0->output(0), inp1, | |||||
| mgb_cstr_log("copy(copy(a0, a1), a0) -> " | |||||
| "a0")); | |||||
| return; | |||||
| } else { | |||||
| auto fold = opr::Copy::make( | |||||
| inp1, copy0->output(0)->comp_node()); | |||||
| rewriter.replace_var( | |||||
| copy0->output(0), fold.node(), | |||||
| mgb_cstr_log("copy(copy(a0, a1), a2) -> " | |||||
| "copy(a0, a2)")); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| rewriter.auto_replace_outputs(opr); | |||||
| }; | |||||
| opt.graph().iter(on_opr); | |||||
| rewriter.apply_inplace(); | |||||
| MIDOUT_E | |||||
| } | |||||
| #if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
| #include "megbrain/opr/collective_comm.h" | #include "megbrain/opr/collective_comm.h" | ||||
| @@ -85,6 +85,16 @@ namespace gopt { | |||||
| void apply(OptState &opt) const override; | void apply(OptState &opt) const override; | ||||
| }; | }; | ||||
| class RemoveRedundantCopyPass final : public Pass { | |||||
| private: | |||||
| //! Remove the copy chain of form cpu -> cpu -> cpu, | |||||
| //! cpu -> gpu -> cpu | |||||
| static bool should_remove(const CompNode& A, const CompNode& B); | |||||
| public: | |||||
| const char * name() const override; | |||||
| void apply(OptState &opt) const override; | |||||
| }; | |||||
| //! remove execution mask for const PPVs in conditional execution | //! remove execution mask for const PPVs in conditional execution | ||||
| class CondExecConstPredicateFolding final : public Pass { | class CondExecConstPredicateFolding final : public Pass { | ||||
| public: | public: | ||||
| @@ -26,14 +26,16 @@ namespace mgb { | |||||
| HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
| std::shared_ptr<ComputingGraph> graph = ComputingGraph::make(); | std::shared_ptr<ComputingGraph> graph = ComputingGraph::make(); | ||||
| SymbolVar mkvar(const char *name, const TensorShape &shp = {1}) { | |||||
| return opr::Host2DeviceCopy::make( | |||||
| *graph, gen(shp)).rename(name); | |||||
| SymbolVar mkvar(const char* name, const TensorShape& shp = {1}, | |||||
| CompNode cn = CompNode::load("xpu0")) { | |||||
| return opr::Host2DeviceCopy::make(*graph, gen(shp), cn) | |||||
| .rename(name); | |||||
| } | } | ||||
| SymbolVar mkcvar(const char *name, const TensorShape &shp = {1}) { | |||||
| SymbolVar mkcvar(const char* name, const TensorShape& shp = {1}, | |||||
| CompNode cn = CompNode::load("xpu0")) { | |||||
| return opr::SharedDeviceTensor::make( | return opr::SharedDeviceTensor::make( | ||||
| *graph, *gen(shp)).rename(name); | |||||
| *graph, *gen(shp), cn).rename(name); | |||||
| } | } | ||||
| template<typename ...Args> | template<typename ...Args> | ||||
| @@ -73,4 +75,3 @@ namespace mgb { | |||||
| TEST_F(TestGopt##pass, name) | TEST_F(TestGopt##pass, name) | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -16,6 +16,7 @@ | |||||
| #include "megbrain/opr/basic_arith_wrapper.h" | #include "megbrain/opr/basic_arith_wrapper.h" | ||||
| #include "megbrain/opr/blas.h" | #include "megbrain/opr/blas.h" | ||||
| #include "megbrain/opr/cond.h" | #include "megbrain/opr/cond.h" | ||||
| #include "megbrain/opr/io.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
| @@ -411,6 +412,71 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { | |||||
| check(x_q8_q8, x_q8_fp32_q8_); | check(x_q8_q8, x_q8_fp32_q8_); | ||||
| } | } | ||||
| TEST_PASS(RemoveRedundantCopyPass, Basic) { | |||||
| auto x = mkvar("x", {2, 3, 3}, CompNode::load("cpu0")); | |||||
| { | |||||
| auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu1")); | |||||
| auto x_cpu0 = opr::Copy::make(x_cpu1, CompNode::load("cpu0")); | |||||
| auto x_cpu2 = opr::Copy::make(x_cpu0, CompNode::load("cpu2")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("cpu2")); | |||||
| check(x, x_cpu0); | |||||
| check(x_expected, x_cpu2); | |||||
| } | |||||
| { | |||||
| auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu1")); | |||||
| auto x_cpu2 = opr::Copy::make(x_cpu1, CompNode::load("cpu2")); | |||||
| auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu3")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("cpu3")); | |||||
| check(x_expected, x_cpu3); | |||||
| } | |||||
| { | |||||
| auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu0:1")); | |||||
| auto x_cpu2 = opr::Copy::make(x_cpu1, CompNode::load("cpu0:2")); | |||||
| auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu0:3")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3")); | |||||
| check(x_expected, x_cpu3); | |||||
| } | |||||
| { | |||||
| auto x_cpu1 = opr::Copy::make(x, CompNode::load("cpu0:1")); | |||||
| auto x_mt = opr::Copy::make(x_cpu1, CompNode::load("multithread8:0")); | |||||
| auto x_cpu3 = opr::Copy::make(x_mt, CompNode::load("cpu0:3")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3")); | |||||
| check(x_expected, x_cpu3); | |||||
| } | |||||
| #if MGB_ATLAS | |||||
| { | |||||
| auto x_atlas0 = opr::Copy::make(x, CompNode::load("atlas0")); | |||||
| auto x_cpu2 = opr::Copy::make(x_atlas0, CompNode::load("cpu0:2")); | |||||
| auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu0:3")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3")); | |||||
| check(x_expected, x_cpu3); | |||||
| } | |||||
| #endif | |||||
| #if MGB_CUDA | |||||
| { | |||||
| auto x_cuda0 = opr::Copy::make(x, CompNode::load("gpu0")); | |||||
| auto x_cpu2 = opr::Copy::make(x_cuda0, CompNode::load("cpu0:2")); | |||||
| auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("cpu0:3")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("cpu0:3")); | |||||
| check(x_expected, x_cpu3); | |||||
| } | |||||
| { | |||||
| auto x_mt = opr::Copy::make(x, CompNode::load("multithread8:0")); | |||||
| auto x_cpu2 = opr::Copy::make(x_mt , CompNode::load("gpu0:1")); | |||||
| auto x_cpu3 = opr::Copy::make(x_cpu2, CompNode::load("multithread8:0")); | |||||
| auto x_expected = opr::Copy::make(x, CompNode::load("multithread8:0")); | |||||
| check(x_expected, x_cpu3); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| #if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
| #include "megbrain/opr/collective_comm.h" | #include "megbrain/opr/collective_comm.h" | ||||
| #include "../../opr-mm/test/mock_client.h" | #include "../../opr-mm/test/mock_client.h" | ||||