Browse Source

support matmul transdata prefusion feature in AutoPoly Pass

pull/86/head
dylangeng 4 years ago
parent
commit
7da00a1f10
5 changed files with 42 additions and 14 deletions
  1. +29
    -0
      src/poly/dma_inject.cc
  2. +1
    -3
      src/poly/dsa_mgr_strategy.cc
  3. +4
    -4
      src/poly/schedule_pass/reschedule.cc
  4. +1
    -1
      src/poly/schedule_pass/transfer_stmt.cc
  5. +7
    -6
      tests/common/test_run/matmul_run.py

+ 29
- 0
src/poly/dma_inject.cc View File

@@ -1584,6 +1584,8 @@ void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tr
}
CHECK(node.isa<isl::schedule_node_mark>()) << "must find a mark node." << std::endl;
auto tag = node.as<isl::schedule_node_mark>().get_id().get_name();
// Realize_L1 mark
// id has _local_ key word
if (tag == REALIZE_C1) {
isl::map stmt_extension = read_extension.range().unwrap();
isl::id stmt_tensor_id = cluster_id;
@@ -1612,6 +1614,20 @@ void PlaceDataCopyBelowImplFakeReads(ScopInfo &scop_info, isl::schedule_node &tr
}
}
}
/*
 * Walk down the first-child chain starting at `node`.
 *
 * When an extension node is met, return its first child; if that child is a
 * sequence node, return the sequence's first child instead.
 * If no extension node is met before reaching a node without children,
 * that last node is returned.
 */
isl::schedule_node FindChildExtension(const isl::schedule_node &node) {
  isl::schedule_node cursor = node;
  for (; cursor.has_children(); cursor = cursor.get_child(0)) {
    if (!cursor.isa<isl::schedule_node_extension>()) {
      continue;  // keep descending through the first child
    }
    isl::schedule_node below = cursor.get_child(0);
    // Step past a sequence node sitting directly under the extension.
    return below.isa<isl::schedule_node_sequence>() ? below.get_child(0) : below;
  }
  return cursor;
}

isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_node tree,
const TensorFootprintCluster &cluster, const isl::map &footprint,
@@ -1666,6 +1682,19 @@ isl::schedule_node PlaceDataCopyBelowImpl(ScopInfo &scop_info, isl::schedule_nod
PlaceDataCopyBelowImplReadWrite(scop_info, tree, cluster, footprint, tensor_id, original_elements, exact_writes,
read_extension, buffered_footprint, cluster_id, extension_map, read_id);

auto fake_copyin = scop_info.analysis_result_.GetFakeCopyin();
bool in_copyin_map = false;
fake_copyin.foreach_map([&in_copyin_map, &tensor_id](const isl::map &fake_map) -> void {
fake_map.foreach_basic_map([&in_copyin_map, &tensor_id](const isl::basic_map &basic_map) -> void {
const isl::map m = basic_map;
if (m.range().tuple_id().name() == tensor_id.name()) {
in_copyin_map = true;
}
});
});
if (in_copyin_map && scop_info.mmu_info_.IsGemm()) {
tree = FindChildExtension(tree);
}
PlaceDataCopyBelowImplFakeReads(scop_info, tree, cluster, read_extension, cluster_id, sch);

return tree;


+ 1
- 3
src/poly/dsa_mgr_strategy.cc View File

@@ -42,9 +42,7 @@ namespace akg {
namespace ir {
namespace poly {

// Registers the tiling pass: a TileOuterBand built from pass_info_ and scop_info_.
// Defined exactly once; a second identical definition of this member would be a
// redefinition error.
void DsaMgrStrategy::RegisterTilingPasses() {
  RegisterPass(std::make_shared<TileOuterBand>(pass_info_, scop_info_));
}

void DsaMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared<MemoryManager>(scop_info_)); }



+ 4
- 4
src/poly/schedule_pass/reschedule.cc View File

@@ -26,7 +26,7 @@ namespace poly {
// Returns true when `node` is a mark node whose id name equals REALIZE_C1,
// REALIZE_BUF or REALIZE_C1BUFC1; returns false for any other node.
bool Reschedule::IsL1OrUbMark(const isl::schedule_node &node) {
  // Only mark nodes carry a tag to inspect.
  if (!node.isa<isl::schedule_node_mark>()) {
    return false;
  }
  const auto tag = node.as<isl::schedule_node_mark>().get_id().get_name();
  // Single combined test: the narrower two-tag check was fully covered by this one.
  return tag == REALIZE_C1 || tag == REALIZE_BUF || tag == REALIZE_C1BUFC1;
}
@@ -34,7 +34,7 @@ bool Reschedule::IsL1OrUbMark(const isl::schedule_node &node) {
// Returns true when `node` is a mark node whose id name equals REALIZE_C0,
// REALIZE_BUFC0 or REALIZE_BUFC1; returns false for any other node.
bool Reschedule::IsL0OrUbL0Mark(const isl::schedule_node &node) {
  // Only mark nodes carry a tag to inspect.
  if (!node.isa<isl::schedule_node_mark>()) {
    return false;
  }
  const auto tag = node.as<isl::schedule_node_mark>().get_id().get_name();
  // Single combined test: the narrower two-tag check was fully covered by this one.
  return tag == REALIZE_C0 || tag == REALIZE_BUFC0 || tag == REALIZE_BUFC1;
}
@@ -57,10 +57,10 @@ void Reschedule::CollectTileBandData(const isl::schedule_node &node, struct Tile

if (tile_band_data->mark.isa<isl::schedule_node_mark>()) {
auto marktag = tile_band_data->mark.as<isl::schedule_node_mark>().get_id().get_name();
if (marktag == REALIZE_C0 || marktag == REALIZE_BUFC0) {
if (marktag == REALIZE_C0 || marktag == REALIZE_BUFC0 || marktag == REALIZE_BUFC1) {
tile_band_data->l0_tiled = true;
l0_build_options_.push_back(tile_band_data->ast_build_options);
} else if (marktag == REALIZE_C1 || marktag == REALIZE_BUF) {
} else if (marktag == REALIZE_C1 || marktag == REALIZE_BUF || marktag == REALIZE_C1BUFC1) {
l1_build_options_.push_back(tile_band_data->ast_build_options);
}
tile_band_data->gemm_mark = node.parent().parent();


+ 1
- 1
src/poly/schedule_pass/transfer_stmt.cc View File

@@ -38,9 +38,9 @@ isl::schedule TransferStmt::Run(isl::schedule curr_schedule) {
filter = filter.subtract(pass_info_.transfer_stmt_);
child = isl::manage(isl_schedule_node_filter_set_filter(child.copy(), filter.copy()));
node = child.parent();
return node.get_schedule();
}
}
return node.get_schedule();
}
return curr_schedule;
}


+ 7
- 6
tests/common/test_run/matmul_run.py View File

@@ -118,14 +118,16 @@ def np_matmul(matrix_a, matrix_b, batch_tuple, M, K, N, trans_data=False, trans_
res = out.reshape(out_shape).transpose(trans).copy()
return res

def gen_data(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
             dtype="float16", bias_dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"):
    """Thin wrapper around gen_data_all that returns only its first four results.

    Every argument is forwarded positionally, unchanged, to gen_data_all.
    gen_data_all's result is unpacked into exactly six values; the last two
    are discarded.

    Returns:
        tuple: (fractal_a, fractal_b, out, bias_data)
    """
    fractal_a, fractal_b, out, bias_data, _unused_a, _unused_b = gen_data_all(
        batch_tuple, M, K, N, trans_data, trans_weight,
        dtype, bias_dtype, out_dtype, bias,
        left_format, right_format, output_format)
    return fractal_a, fractal_b, out, bias_data

def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
def gen_data_all(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
dtype="float16", bias_dtype="float16", out_dtype="float16", bias=0, left_format="zZ", right_format="nZ", output_format="zN"):
shape_x, shape_y = get_shapes(batch_tuple, M, K, N, trans_data, trans_weight)
matrix_a = random_gaussian(shape_x, miu=0.1, sigma=0.01).astype(dtype)
matrix_b = random_gaussian(shape_y, miu=0.1, sigma=0.01).astype(dtype)
# matrix_a = np.ones(shape_x, dtype=np.float16)
# matrix_b = np.ones(shape_y, dtype=np.float16)

# this change is for gen data speed
matrix_a_for_np = matrix_a.astype(np.float32)
@@ -187,8 +189,7 @@ def genData(batch_tuple, M, K, N, trans_data=False, trans_weight=False,
trans_x = tuple(range(batch_len)) + (batch_len + 2, batch_len + 0, batch_len + 1, batch_len + 3)
fractal_a = matrix_b.reshape(batch_tuple + shape_y).transpose(trans_y).copy()
fractal_b = matrix_a.reshape(batch_tuple + shape_x).transpose(trans_x).copy()
return fractal_a, fractal_b, out, bias_data

return fractal_a, fractal_b, out, bias_data, matrix_a, matrix_b

def matmul_data(batch_tuple, M, K, N, dtype, bias_dtype, out_dtype, bias, adj_x, adj_y, left_format=None, right_format=None, output_format=None, debug_logging=False):
m_x = ()
@@ -197,7 +198,7 @@ def matmul_data(batch_tuple, M, K, N, dtype, bias_dtype, out_dtype, bias, adj_x,
bias_data = ()
logging.debug("gen data start!")
a = datetime.now()
m_x, m_y, bench_mark, bias_data = genData(batch_tuple, M, K, N, adj_x, adj_y, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format)
m_x, m_y, bench_mark, bias_data = gen_data(batch_tuple, M, K, N, adj_x, adj_y, dtype, bias_dtype, out_dtype, bias, left_format, right_format, output_format)
b = datetime.now()
logging.debug((b - a).seconds)
logging.debug("gen data end!")


Loading…
Cancel
Save