GitOrigin-RevId: 844b7e8d39
tags/v1.6.0-rc1
@@ -19,6 +19,7 @@ using namespace gopt;
 using Dimension = megdnn::Dimension;
 using NamedTensorShape = megdnn::NamedTensorShape;
+// =================== ModifyShapeMixin ====================*/
 ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
 ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
         const Pattern& pattern) const {
     auto src = m_src;
-    auto checker = [src, pattern](VarNode* var) {
+    auto checker = [src, pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() >= 1);
+        const auto& var = input.front();
         const auto& shp = var->shape();
         if (shp.ndim != src.ndim)
             return false;
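
The relaxed assertion input.size() >= 1 (rather than == 1) lets the same mixin
checker serve emitters with different arities: only the leading tensor is
validated against the source pattern, so ReshapeEmitter below, whose builder
takes two inputs (the tensor and its target shape var), can reuse it
unchanged. A sketch of both call shapes, with x and tshp as hypothetical
VarNode pointers:

    checker({x});        // single-input case (e.g. MakeShape)
    checker({x, tshp});  // Reshape case: the trailing shape var is not inspected
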
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
     return checker;
 }
-ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+// =================== MakeShapeEmitter ====================*/
+MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const {
     auto pattern = mixin_analyze();
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of MakeShapeBuilder should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         auto shp = opr::GetVarShape::make(sym_var);
         auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
         auto sub = [&shp, &cv](int ax) {
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
             }
         }
         auto tshp = opr::Concat::make(axs, 0);
-        auto ovar = opr::Reshape::make(sym_var, tshp);
+        return tshp.node();
+    };
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
+}
+// =================== ReshapeEmitter ====================*/
+ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+    auto pattern = mixin_analyze();
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 2,
+                   "number of input of Reshape should be 2(got:%zu)",
+                   input.size());
+        auto ovar = opr::Reshape::make(input[0], input[1]);
         return ovar.node();
     };
     auto checker = mixin_emit_checker(pattern);
     return std::make_tuple(builder, checker);
 }
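
With MakeShapeEmitter split out, a reshape is now expressed as two cooperating
builders: one materializes the destination shape as a shape var, the other
consumes it as a second input. A minimal composition sketch, assuming src and
dest are megdnn::NamedTensorShape values and x is a VarNode* whose shape
matches src:

    auto make_shape = std::get<0>(MakeShapeEmitter(src, dest).emit());
    auto reshape = std::get<0>(ReshapeEmitter(src, dest).emit());
    VarNode* tshp = make_shape({x});  // shape expression derived from x
    VarNode* y = reshape({x, tshp});  // two inputs: tensor and target shape
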
+// =================== DimshuffleEmitter ====================*/
 DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
     auto&& pattern = m_pattern;
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         return opr::Dimshuffle::make(sym_var, pattern).node();
     };
-    auto checker = [pattern](VarNode* var) {
-        return var->shape().ndim == pattern.size();
+    auto checker = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        return input.front()->shape().ndim == pattern.size();
     };
     return std::make_tuple(builder, checker);
 }
+// =================== ReformatEmitter ====================*/
 ReformatEmitter::EmitResult ReformatEmitter::emit() const {
-    auto ops = analyze();
-    auto builder = [ops](VarNode* var) {
-        VarNode* ovar = var;
-        for (const auto& i : ops) {
-            ovar = i(ovar);
+    auto builders = analyze();
+    auto builder = [builders](const VarNodeArray& input) {
+        VarNode *var, *ovar;
+        var = ovar = input.front();
+        if (builders.make_shape1) {
+            auto shp1 = builders.make_shape1({var});
+            ovar = builders.reshape1({ovar, shp1});
+        }
+        ovar = builders.dimshuffle({ovar});
+        if (builders.make_shape2) {
+            auto shp2 = builders.make_shape2({var});
+            ovar = builders.reshape2({ovar, shp2});
         }
         return ovar;
     };
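
The emitted builder is thus a fixed reshape -> dimshuffle -> reshape pipeline
in which either reshape stage is dropped when analyze() below finds it
redundant. Note that both make_shape builders are fed {var}, the original
input, not the intermediate ovar: the target shapes are symbolic expressions
over the source var's shape. A dataflow sketch (not verbatim code):

    // shp1 = shape-expr(x); t = reshape(x, shp1);   -- optional dim split
    // t = dimshuffle(t);                            -- permute named dims
    // shp2 = shape-expr(x); y = reshape(t, shp2);   -- optional dim merge
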
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
     return std::make_tuple(builder, checker);
 }
-SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
+ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     struct Dim {
         Dimension dim;
         int index;
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
         i1[i] = src_dims[src_perm[i]].dim;
         i2[i] = src_dims[src_perm[permute[i]]].dim;
     }
-    SmallVector<Builder> ops;
-    if (!m_src.eq_shape(i1))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
-    ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
-    if (!m_dest.eq_shape(i2))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
-    return ops;
+    UnderlyingBuilders builders;
+    if (!m_src.eq_shape(i1)) {
+        builders.make_shape1 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
+        builders.reshape1 =
+                std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+    }
+    builders.dimshuffle =
+            std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
+    if (!m_dest.eq_shape(i2)) {
+        builders.make_shape2 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
+        builders.reshape2 =
+                std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
+    }
+    return builders;
 }
 // vim: syntax=cpp.doxygen
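
For a pure permutation of named dims such as NCHW -> NHWC, the intermediate
shapes i1 and i2 already equal m_src and m_dest, so neither reshape pair is
populated and the emitted graph is a single Dimshuffle; the
EliminateRedudantReshape test below checks exactly this. The two cases, as
exercised by the tests (shapes written informally):

    NCHW   -> NHWC : dimshuffle only (nr_reshape == 0)
    NCHW32 -> NCHW4: reshape -> dimshuffle -> reshape (nr_shapeof == 1, nr_reshape == 2)
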
@@ -20,8 +20,8 @@ namespace gopt {
 class Emitter {
 public:
-    using Builder = thin_function<VarNode*(VarNode*)>;
-    using Checker = thin_function<bool(VarNode*)>;
+    using Builder = thin_function<VarNode*(const VarNodeArray&)>;
+    using Checker = thin_function<bool(const VarNodeArray&)>;
     using EmitResult = std::tuple<Builder, Checker>;
     virtual ~Emitter() = default;
     virtual EmitResult emit() const = 0;
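
A minimal usage sketch of the array-based contract, mirroring the tests below
(src and dest are assumed to be megdnn::NamedTensorShape values and x a
SymbolVar in some computing graph):

    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
    auto reformat = std::get<0>(tuple);  // Builder
    auto checker = std::get<1>(tuple);   // Checker
    if (checker({x.node()})) {
        auto y = SymbolVar(reformat({x.node()}));
    }
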
@@ -39,6 +39,14 @@ protected:
     megdnn::NamedTensorShape m_src, m_dest;
 };
+class MakeShapeEmitter final : public Emitter, ModifyShapeMixin {
+public:
+    MakeShapeEmitter(const megdnn::NamedTensorShape& src,
+                     const megdnn::NamedTensorShape& dest)
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
+};
 class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
 public:
     ReshapeEmitter(const megdnn::NamedTensorShape& src,
@@ -64,7 +72,10 @@ public:
     EmitResult emit() const override;
 private:
-    SmallVector<Builder> analyze() const;
+    struct UnderlyingBuilders {
+        Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle;
+    };
+    UnderlyingBuilders analyze() const;
 };
 } // namespace gopt
 } // namespace mgb
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
     constexpr size_t N = 12, C = 64, H = 7, W = 7;
     HostTensorGenerator<> gen;
     using NamedTensorShape = megdnn::NamedTensorShape;
-    auto dest = NamedTensorShape::make_named_tensor_shape(
-            NamedTensorShape::Format::NCHW4);
     auto src = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW32);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
     auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
     auto reformat = std::get<0>(tuple);
     auto checker = std::get<1>(tuple);
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 32, H, W, 32});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y1 = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_shapeof = 0;
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::GetVarShape>())
+            nr_shapeof++;
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_shapeof, 1);
+    ASSERT_EQ(nr_reshape, 2);
     auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
     HostTensorND t1, t2;
     auto func1 = graph->compile({make_callback_copy(y1, t1)});
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 64, H, W, 64});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y = SymbolVar(reformat({x.node()}));
     HostTensorND t;
     auto func = graph->compile({make_callback_copy(y, t)});
     func->execute();
 }
+TEST(TestReformatEmitter, EliminateRedudantReshape) {
+    constexpr size_t N = 16, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NHWC);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw_to_nhwc = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1});
+        return y.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C, H, W});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_reshape, 0);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+TEST(TestReformatEmitter, Nchw4ToNchw) {
+    constexpr size_t N = 12, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw4_to_nchw = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp);
+        return y1.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C / 4, H, W, 4});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    SmallVector<VarNode*> reshapes;
+    VarNode* dimshuffle = nullptr;  // stays null if no Dimshuffle opr is found
+    cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>()) {
+            reshapes.push_back(o->output(0));
+        }
+        if (o->same_type<opr::Dimshuffle>())
+            dimshuffle = o->output(0);
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(reshapes.size(), 1);
+    {
+        gopt::SubGraph graph({y1});
+        gopt::UniqReaderCheck check(graph);
+        EXPECT_TRUE(check(reshapes[0]));
+        EXPECT_TRUE(dimshuffle);
+    }
+    auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}