GitOrigin-RevId: 844b7e8d39
tags/v1.6.0-rc1
@@ -19,6 +19,7 @@ using namespace gopt;
 using Dimension = megdnn::Dimension;
 using NamedTensorShape = megdnn::NamedTensorShape;
+// =================== ModifyShapeMixin ====================*/
 ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
@@ -50,7 +51,9 @@ ModifyShapeMixin::Pattern ModifyShapeMixin::mixin_analyze() const {
 ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
         const Pattern& pattern) const {
     auto src = m_src;
-    auto checker = [src, pattern](VarNode* var) {
+    auto checker = [src, pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() >= 1);
+        const auto& var = input.front();
         const auto& shp = var->shape();
         if (shp.ndim != src.ndim)
             return false;
@@ -73,10 +76,14 @@ ModifyShapeMixin::Checker ModifyShapeMixin::mixin_emit_checker(
     return checker;
 }
 
-ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+// =================== MakeShapeEmitter ====================*/
+MakeShapeEmitter::EmitResult MakeShapeEmitter::emit() const {
     auto pattern = mixin_analyze();
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of MakeShapeBuilder should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         auto shp = opr::GetVarShape::make(sym_var);
         auto cv = [&sym_var](int c) { return sym_var.make_scalar(c); };
         auto sub = [&shp, &cv](int ax) {
@@ -97,31 +104,59 @@ ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
             }
         }
         auto tshp = opr::Concat::make(axs, 0);
-        auto ovar = opr::Reshape::make(sym_var, tshp);
-        return ovar.node();
+        return tshp.node();
     };
     auto checker = mixin_emit_checker(pattern);
     return std::make_tuple(builder, checker);
 }
 
+// =================== ReshapeEmitter ====================*/
+ReshapeEmitter::EmitResult ReshapeEmitter::emit() const {
+    auto pattern = mixin_analyze();
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 2,
+                   "number of input of Reshape should be 2(got:%zu)",
+                   input.size());
+        auto ovar = opr::Reshape::make(input[0], input[1]);
+        return ovar.node();
+    };
+    auto checker = mixin_emit_checker(pattern);
+    return std::make_tuple(builder, checker);
+}
+
+// =================== DimshuffleEmitter ====================*/
 DimshuffleEmitter::EmitResult DimshuffleEmitter::emit() const {
     auto&& pattern = m_pattern;
-    auto builder = [pattern](VarNode* var) {
-        auto sym_var = SymbolVar(var);
+    auto builder = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        auto sym_var = SymbolVar(input.front());
         return opr::Dimshuffle::make(sym_var, pattern).node();
     };
-    auto checker = [pattern](VarNode* var) {
-        return var->shape().ndim == pattern.size();
+    auto checker = [pattern](const VarNodeArray& input) {
+        mgb_assert(input.size() == 1,
+                   "number of input of Dimshuffle should be 1(got:%zu)",
+                   input.size());
+        return input.front()->shape().ndim == pattern.size();
     };
     return std::make_tuple(builder, checker);
 }
 
+// =================== ReformatEmitter ====================*/
 ReformatEmitter::EmitResult ReformatEmitter::emit() const {
-    auto ops = analyze();
-    auto builder = [ops](VarNode* var) {
-        VarNode* ovar = var;
-        for (const auto& i : ops) {
-            ovar = i(ovar);
+    auto builders = analyze();
+    auto builder = [builders](const VarNodeArray& input) {
+        VarNode *var, *ovar;
+        var = ovar = input.front();
+        if (builders.make_shape1) {
+            auto shp1 = builders.make_shape1({var});
+            ovar = builders.reshape1({ovar, shp1});
         }
+        ovar = builders.dimshuffle({ovar});
+        if (builders.make_shape2) {
+            auto shp2 = builders.make_shape2({var});
+            ovar = builders.reshape2({ovar, shp2});
+        }
         return ovar;
     };
@@ -130,7 +165,7 @@ ReformatEmitter::EmitResult ReformatEmitter::emit() const {
     return std::make_tuple(builder, checker);
 }
 
-SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
+ReformatEmitter::UnderlyingBuilders ReformatEmitter::analyze() const {
     struct Dim {
         Dimension dim;
         int index;
@@ -196,12 +231,21 @@ SmallVector<ReformatEmitter::Builder> ReformatEmitter::analyze() const {
         i1[i] = src_dims[src_perm[i]].dim;
         i2[i] = src_dims[src_perm[permute[i]]].dim;
     }
-    SmallVector<Builder> ops;
-    if (!m_src.eq_shape(i1))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
-    ops.emplace_back(std::get<0>(DimshuffleEmitter(permute).emit()));
-    if (!m_dest.eq_shape(i2))
-        ops.emplace_back(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
-    return ops;
+    UnderlyingBuilders builders;
+    if (!m_src.eq_shape(i1)) {
+        builders.make_shape1 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, i1).emit()));
+        builders.reshape1 =
+                std::move(std::get<0>(ReshapeEmitter(m_src, i1).emit()));
+    }
+    builders.dimshuffle =
+            std::move(std::get<0>(DimshuffleEmitter(permute).emit()));
+    if (!m_dest.eq_shape(i2)) {
+        builders.make_shape2 =
+                std::move(std::get<0>(MakeShapeEmitter(m_src, m_dest).emit()));
+        builders.reshape2 =
+                std::move(std::get<0>(ReshapeEmitter(i2, m_dest).emit()));
+    }
+    return builders;
 }
 
 // vim: syntax=cpp.doxygen
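Taken together, analyze() now splits a reformat into an optional make_shape1/reshape1 pair, the dimshuffle, and an optional make_shape2/reshape2 pair, and the composed builder feeds both shape subgraphs from the original input var rather than from the intermediate result. As a concrete reading of that composition, here is a hand-built equivalent of the NCHW4 -> NCHW case, lifted from the reference lambda in the Nchw4ToNchw test below. It is a sketch only, assuming x is a SymbolVar of shape (N, C/4, H, W, 4) with C divisible by 4; for this pair of formats the leading reshape pair is skipped because the source already matches the first intermediate shape:

    // Sketch, not the emitter itself: equivalent hand-built graph for
    // NCHW4 -> NCHW, assuming x: (N, C/4, H, W, 4).
    auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});  // (N, C/4, 4, H, W)
    // make_shape2 equivalent: derive the target shape (N, C, H, W) from x itself
    auto xshp = opr::GetVarShape::make(x);
    auto cv = [&x](int v) { return x.make_scalar(v); };
    auto sub = [&xshp, &cv](int idx) {
        return opr::IndexAt::make(xshp, {{0, cv(idx)}});
    };
    auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
    // reshape2 equivalent: (N, C/4, 4, H, W) -> (N, C, H, W)
    auto y1 = opr::Reshape::make(y0, tshp);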
@@ -20,8 +20,8 @@ namespace gopt {
 class Emitter {
 public:
-    using Builder = thin_function<VarNode*(VarNode*)>;
-    using Checker = thin_function<bool(VarNode*)>;
+    using Builder = thin_function<VarNode*(const VarNodeArray&)>;
+    using Checker = thin_function<bool(const VarNodeArray&)>;
     using EmitResult = std::tuple<Builder, Checker>;
     virtual ~Emitter() = default;
     virtual EmitResult emit() const = 0;
@@ -39,6 +39,14 @@ protected:
     megdnn::NamedTensorShape m_src, m_dest;
 };
 
+class MakeShapeEmitter final : public Emitter, ModifyShapeMixin {
+public:
+    MakeShapeEmitter(const megdnn::NamedTensorShape& src,
+                     const megdnn::NamedTensorShape& dest)
+            : ModifyShapeMixin(src, dest) {}
+    EmitResult emit() const override;
+};
+
 class ReshapeEmitter final : public Emitter, ModifyShapeMixin {
 public:
     ReshapeEmitter(const megdnn::NamedTensorShape& src,
@@ -64,7 +72,10 @@ public:
     EmitResult emit() const override;
 
 private:
-    SmallVector<Builder> analyze() const;
+    struct UnderlyingBuilders {
+        Builder make_shape1, make_shape2, reshape1, reshape2, dimshuffle;
+    };
+    UnderlyingBuilders analyze() const;
 };
 }  // namespace gopt
 }  // namespace mgb
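With Builder and Checker now taking a VarNodeArray, data vars and auxiliary shape vars travel through one uniform interface: the reshape builders above receive {data, shape}, while the builder returned by ReformatEmitter::emit() still takes just the data var. A minimal usage sketch of the public entry point, mirroring the tests below and assuming x is a var whose layout matches src:

    auto src = megdnn::NamedTensorShape::make_named_tensor_shape(
            megdnn::NamedTensorShape::Format::NCHW32);
    auto dest = megdnn::NamedTensorShape::make_named_tensor_shape(
            megdnn::NamedTensorShape::Format::NCHW4);
    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
    auto reformat = std::get<0>(tuple);  // Builder: VarNode* (const VarNodeArray&)
    auto checker = std::get<1>(tuple);   // Checker: bool (const VarNodeArray&)
    if (checker({x.node()})) {  // ndim and all known extents match src
        auto y = SymbolVar(reformat({x.node()}));  // dimshuffle/reshape chain
    }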
@@ -21,10 +21,10 @@ TEST(TestReformatEmitter, Basic) {
     constexpr size_t N = 12, C = 64, H = 7, W = 7;
     HostTensorGenerator<> gen;
     using NamedTensorShape = megdnn::NamedTensorShape;
-    auto dest = NamedTensorShape::make_named_tensor_shape(
-            NamedTensorShape::Format::NCHW4);
     auto src = NamedTensorShape::make_named_tensor_shape(
             NamedTensorShape::Format::NCHW32);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
     auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
     auto reformat = std::get<0>(tuple);
     auto checker = std::get<1>(tuple);
@@ -53,10 +53,21 @@ TEST(TestReformatEmitter, Basic) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 32, H, W, 32});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y1 = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_shapeof = 0;
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::GetVarShape>())
+            nr_shapeof++;
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_shapeof, 1);
+    ASSERT_EQ(nr_reshape, 2);
     auto y2 = SymbolVar(nchw32_to_nchw4(x.node()));
     HostTensorND t1, t2;
     auto func1 = graph->compile({make_callback_copy(y1, t1)});
@@ -84,12 +95,116 @@ TEST(TestReformatEmitter, MoreComplicated) {
         return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
     };
     auto x = mkvar("x", {N, C / 64, H, W, 64});
-    EXPECT_TRUE(checker(x.node()));
+    EXPECT_TRUE(checker({x.node()}));
     auto x_ = mkvar("x", {N, H, W, C});
-    EXPECT_FALSE(checker(x_.node()));
-    auto y = SymbolVar(reformat(x.node()));
+    EXPECT_FALSE(checker({x_.node()}));
+    auto y = SymbolVar(reformat({x.node()}));
     HostTensorND t;
     auto func = graph->compile({make_callback_copy(y, t)});
     func->execute();
 }
 
+TEST(TestReformatEmitter, EliminateRedudantReshape) {
+    constexpr size_t N = 16, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NHWC);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw_to_nhwc = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1});
+        return y.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C, H, W});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    size_t nr_reshape = 0;
+    cg::DepOprIter{[&nr_reshape](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>())
+            nr_reshape++;
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(nr_reshape, 0);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto y2 = SymbolVar(nchw_to_nhwc(x.node()));
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+
+TEST(TestReformatEmitter, Nchw4ToNchw) {
+    constexpr size_t N = 12, C = 64, H = 7, W = 7;
+    HostTensorGenerator<> gen;
+    using NamedTensorShape = megdnn::NamedTensorShape;
+    auto src = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW4);
+    auto dest = NamedTensorShape::make_named_tensor_shape(
+            NamedTensorShape::Format::NCHW);
+    auto&& tuple = gopt::ReformatEmitter(src, dest).emit();
+    auto reformat = std::get<0>(tuple);
+    auto checker = std::get<1>(tuple);
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto nchw4_to_nchw = [](VarNode* in) {
+        auto x = SymbolVar(in);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp);
+        return y1.node();
+    };
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
+    };
+    auto x = mkvar("x", {N, C / 4, H, W, 4});
+    EXPECT_TRUE(checker({x.node()}));
+    auto y1 = SymbolVar(reformat({x.node()}));
+    SmallVector<VarNode*> reshapes;
+    VarNode* dimshuffle = nullptr;
+    cg::DepOprIter{[&dimshuffle, &reshapes](cg::OperatorNodeBase* o) {
+        if (o->same_type<opr::Reshape>()) {
+            reshapes.push_back(o->output(0));
+        }
+        if (o->same_type<opr::Dimshuffle>())
+            dimshuffle = o->output(0);
+    }}
+            .add(y1.node()->owner_opr());
+    ASSERT_EQ(reshapes.size(), 1);
+    {
+        gopt::SubGraph graph({y1});
+        gopt::UniqReaderCheck check(graph);
+        EXPECT_TRUE(check(reshapes[0]));
+        EXPECT_TRUE(dimshuffle);
+    }
+    auto y2 = SymbolVar(nchw4_to_nchw(x.node()));
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y1, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y2, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
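A note on the nr_shapeof == 1 assertion in the Basic test above: make_shape1 and make_shape2 both start from opr::GetVarShape::make on the same input var, so the two shape subgraphs begin with structurally identical oprs. Expecting a single GetVarShape relies on the computing graph merging such duplicates at insertion time, independent of graph_opt_level; a sketch under that assumption:

    // Assumed behavior: identical oprs are deduplicated when inserted, so the
    // shape subgraphs built by make_shape1 and make_shape2 share one node.
    auto shp_a = opr::GetVarShape::make(SymbolVar(x));  // built by make_shape1
    auto shp_b = opr::GetVarShape::make(SymbolVar(x));  // built by make_shape2
    // shp_a.node() == shp_b.node() under dedup, hence nr_shapeof counts 1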