From 7da00a1f103aa25197b3ded4ee0502c4f6f2275f Mon Sep 17 00:00:00 2001 From: dylangeng Date: Thu, 6 May 2021 10:16:57 +0800 Subject: [PATCH] support matmul transdata prefusion feature in AutoPoly Pass --- src/poly/dma_inject.cc | 29 +++++++++++++++++++++++++ src/poly/dsa_mgr_strategy.cc | 4 +--- src/poly/schedule_pass/reschedule.cc | 8 +++---- src/poly/schedule_pass/transfer_stmt.cc | 2 +- tests/common/test_run/matmul_run.py | 13 ++++++----- 5 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/poly/dma_inject.cc b/src/poly/dma_inject.cc index 544fdf9..aa332ce 100644 --- a/src/poly/dma_inject.cc +++ b/src/poly/dma_inject.cc @@ -1584,6 +1584,8 @@ void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tr } CHECK(node.isa()) << "must find a mark node." << std::endl; auto tag = node.as().get_id().get_name(); + // Realize_L1 mark + // id has _local_ key word if (tag == REALIZE_C1) { isl::map stmt_extension = read_extension.range().unwrap(); isl::id stmt_tensor_id = cluster_id; @@ -1612,6 +1614,20 @@ void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tr } } } +isl::schedule_node FindChildExtension(const isl::schedule_node &node) { + isl::schedule_node res = node; + while (res.has_children()) { + if (res.isa()) { + res = res.get_child(0); + if (res.isa()) { + res = res.get_child(0); + } + return res; + } + res = res.get_child(0); + } + return res; +} isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_node tree, const TensorFootprintCluster &cluster, const isl::map &footprint, @@ -1666,6 +1682,19 @@ isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_nod PlaceDataCopyBelowImplReadWrite(scop_info, tree, cluster, footprint, tensor_id, original_elements, exact_writes, read_extension, buffered_footprint, cluster_id, extension_map, read_id); + auto fake_copyin = scop_info.analysis_result_.GetFakeCopyin(); + bool in_copyin_map = false; + fake_copyin.foreach_map([&in_copyin_map, &tensor_id](const isl::map &fake_map) -> void { + fake_map.foreach_basic_map([&in_copyin_map, &tensor_id](const isl::basic_map &basic_map) -> void { + const isl::map m = basic_map; + if (m.range().tuple_id().name() == tensor_id.name()) { + in_copyin_map = true; + } + }); + }); + if (in_copyin_map && scop_info.mmu_info_.IsGemm()) { + tree = FindChildExtension(tree); + } PlaceDataCopyBelowImplFakeReads(scop_info, tree, cluster, read_extension, cluster_id, sch); return tree; diff --git a/src/poly/dsa_mgr_strategy.cc b/src/poly/dsa_mgr_strategy.cc index b43bdde..df48698 100644 --- a/src/poly/dsa_mgr_strategy.cc +++ b/src/poly/dsa_mgr_strategy.cc @@ -42,9 +42,7 @@ namespace akg { namespace ir { namespace poly { -void DsaMgrStrategy::RegisterTilingPasses() { - RegisterPass(std::make_shared(pass_info_, scop_info_)); -} +void DsaMgrStrategy::RegisterTilingPasses() { RegisterPass(std::make_shared(pass_info_, scop_info_)); } void DsaMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared(scop_info_)); } diff --git a/src/poly/schedule_pass/reschedule.cc b/src/poly/schedule_pass/reschedule.cc index 5b1560f..a0b6b1e 100644 --- a/src/poly/schedule_pass/reschedule.cc +++ b/src/poly/schedule_pass/reschedule.cc @@ -26,7 +26,7 @@ namespace poly { bool Reschedule::IsL1OrUbMark(const isl::schedule_node &node) { if (node.isa()) { auto tag = node.as().get_id().get_name(); - if (tag == REALIZE_C1 || tag == REALIZE_BUF) return true; + if (tag == REALIZE_C1 || tag == REALIZE_BUF || tag == REALIZE_C1BUFC1) return true; } return false; } @@ -34,7 +34,7 @@ bool Reschedule::IsL1OrUbMark(const isl::schedule_node &node) { bool Reschedule::IsL0OrUbL0Mark(const isl::schedule_node &node) { if (node.isa()) { auto tag = node.as().get_id().get_name(); - if (tag == REALIZE_C0 || tag == REALIZE_BUFC0) return true; + if (tag == REALIZE_C0 || tag == REALIZE_BUFC0 || tag == REALIZE_BUFC1) return true; } return false; } @@ -57,10 +57,10 @@ void Reschedule::CollectTileBandData(const isl::schedule_node &node, struct Tile if (tile_band_data->mark.isa()) { auto marktag = tile_band_data->mark.as().get_id().get_name(); - if (marktag == REALIZE_C0 || marktag == REALIZE_BUFC0) { + if (marktag == REALIZE_C0 || marktag == REALIZE_BUFC0 || marktag == REALIZE_BUFC1) { tile_band_data->l0_tiled = true; l0_build_options_.push_back(tile_band_data->ast_build_options); - } else if (marktag == REALIZE_C1 || marktag == REALIZE_BUF) { + } else if (marktag == REALIZE_C1 || marktag == REALIZE_BUF || marktag == REALIZE_C1BUFC1) { l1_build_options_.push_back(tile_band_data->ast_build_options); } tile_band_data->gemm_mark = node.parent().parent(); diff --git a/src/poly/schedule_pass/transfer_stmt.cc b/src/poly/schedule_pass/transfer_stmt.cc index 4976226..8ab059b 100644 --- a/src/poly/schedule_pass/transfer_stmt.cc +++ b/src/poly/schedule_pass/transfer_stmt.cc @@ -38,9 +38,9 @@ isl::schedule TransferStmt::Run(isl::schedule curr_schedule) { filter = filter.subtract(pass_info_.transfer_stmt_); child = isl::manage(isl_schedule_node_filter_set_filter(child.copy(), filter.copy())); node = child.parent(); - return node.get_schedule(); } } + return node.get_schedule(); } return curr_schedule; } diff --git a/tests/common/test_run/matmul_run.py b/tests/common/test_run/matmul_run.py index dd4ee5a..8f1c59d 100644 --- a/tests/common/test_run/matmul_run.py +++ b/tests/common/test_run/matmul_run.py @@ -118,14 +118,16 @@ def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_ res = out.reshape(out_shape).transpose(trans).copy() return res +def gen_data(batch_tuple, M, K, N, trans_data=False, trans_weight=False, + dtype="float16", bias_dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"): + fractal_a, fractal_b, out, bias_data, _, _ = gen_data_all(batch_tuple, M, K, N, trans_data, trans_weight, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format) + return fractal_a, fractal_b, out, bias_data -def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False, +def gen_data_all(batch_tuple, M, K, N, trans_data=False, trans_weight=False, dtype="float16", bias_dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"): shape_x, shape_y = get_shapes(batch_tuple, M, K, N, trans_data, trans_weight) matrix_a = random_gaussian(shape_x, miu=0.1, sigma=0.01).astype(dtype) matrix_b = random_gaussian(shape_y, miu=0.1, sigma=0.01).astype(dtype) - # matrix_a = np.ones(shape_x, dtype=np.float16) - # matrix_b = np.ones(shape_y, dtype=np.float16) # this change is for gen data speed matrix_a_for_np = matrix_a.astype(np.float32) @@ -187,8 +189,7 @@ def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False, trans_x = tuple(range(batch_len)) + (batch_len + 2, batch_len + 0, batch_len + 1, batch_len + 3) fractal_a = matrix_b.reshape(batch_tuple + shape_y).transpose(trans_y).copy() fractal_b = matrix_a.reshape(batch_tuple + shape_x).transpose(trans_x).copy() - return fractal_a, fractal_b, out, bias_data - + return fractal_a, fractal_b, out, bias_data, matrix_a, matrix_b def matmul_data(batch_tuple, M, K, N, dtype, bias_dtype, out_dtype, bias, adj_x, adj_y, left_format=None, right_format=None, output_format=None, debug_logging=False): m_x = () @@ -197,7 +198,7 @@ def matmul_data(batch_tuple, M, K, N, dtype, bias_dtype, out_dtype, bias, adj_x, bias_data = () logging.debug("gen data start!") a = datetime.now() - m_x, m_y, bench_mark, bias_data = genData(batch_tuple, M, K, N, adj_x, adj_y, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format) + m_x, m_y, bench_mark, bias_data = gen_data(batch_tuple, M, K, N, adj_x, adj_y, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format) b = datetime.now() logging.debug((b - a).seconds) logging.debug("gen data end!")