From 395aafbbac873dfd0f7e4e4f8766f03a652f5707 Mon Sep 17 00:00:00 2001 From: dabaiji Date: Tue, 13 Apr 2021 09:52:20 +0800 Subject: [PATCH] support offset case in poly --- src/poly/scop_info.h | 1 + src/poly/tiling/tiling_analyzer.cc | 6 ++-- .../tiling/tiling_strategy_manager_gpu.cc | 28 +++++++++++++++++-- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/poly/scop_info.h b/src/poly/scop_info.h index 9997416..154cfac 100644 --- a/src/poly/scop_info.h +++ b/src/poly/scop_info.h @@ -174,6 +174,7 @@ class UserConfig { void SetAttrs(const Map &attrs) { if (attrs.empty()) return; + mod_schedule_shift_ = target_ == TARGET_CUDA; ParseDynamicShapeAttr(attrs, "dynamic_shape", &dynamic_shape_); ParseIntAttr(attrs, "dynamic_shape_bound", &dynamic_shape_bound_); ParseBoolAttr(attrs, "pragma_tilesize_is_var", &tile_size_is_var_); diff --git a/src/poly/tiling/tiling_analyzer.cc b/src/poly/tiling/tiling_analyzer.cc index ac440b4..47b4371 100644 --- a/src/poly/tiling/tiling_analyzer.cc +++ b/src/poly/tiling/tiling_analyzer.cc @@ -1352,6 +1352,7 @@ void TilingAnalyzer::AddPostTilingConstraints() { if (scop_info_.user_config_.GetTarget() == TARGET_CUDA) { ReduceStrategy reduce_strategy(this); ModStrategy mod_strategy(this); + ModShiftAxisStrategy mod_shift_strategy(this); GemmStrategy gemm_strategy(this); GpuDmaAnalysisStrategy dma_analysis_strategy(this); CustomTilingStrategy custom_strategy(this); @@ -1364,6 +1365,7 @@ void TilingAnalyzer::AddPostTilingConstraints() { } else { actived_strategies.push_back(&reduce_strategy); actived_strategies.push_back(&mod_strategy); + actived_strategies.push_back(&mod_shift_strategy); actived_strategies.push_back(&gemm_strategy); } actived_strategies.push_back(&gpu_strategy); @@ -1411,7 +1413,7 @@ void TilingAnalyzer::AddTilingConstraints() { ReduceStrategy reduce_strategy(this); DmaAlignStrategy dma_align_stratgey(this); - + if (!scop_info_.user_config_.GetIsTuning()) { actived_strategies.push_back(&reduce_strategy); actived_strategies.push_back(&dma_align_stratgey); @@ -1448,7 +1450,7 @@ void TilingAnalyzer::AddTilingConstraints() { bool TilingAnalyzer::Prepare() { logger_ = std::unique_ptr(new (std::nothrow) TileLogger( - scop_info_.AddDumpDir("tiling.log"), !scop_info_.user_config_.GetDumpPolyDir().empty())); + scop_info_.AddDumpDir("tiling.log"), !scop_info_.user_config_.GetDumpPolyDir().empty())); CHECK(logger_) << "memory alloc fail."; // Stage 1: Analyze schedule tree. ScheduleTreeAnalyzer sch_ana(this, this->sch_); diff --git a/src/poly/tiling/tiling_strategy_manager_gpu.cc b/src/poly/tiling/tiling_strategy_manager_gpu.cc index c73d20a..c4c6990 100644 --- a/src/poly/tiling/tiling_strategy_manager_gpu.cc +++ b/src/poly/tiling/tiling_strategy_manager_gpu.cc @@ -29,6 +29,32 @@ void GpuDmaAnalysisStrategy::AddGpuConstraint() { void CastStrategy::AddGpuConstraint() { MarkDataSize(); } +void ModShiftAxisStrategy::AddGpuConstraint() { + auto interested_info = GetInterestedInfo(interested_attr_key); + for (auto it : interested_info) { + TileAxis *axis = it.first; + int64_t const_extent = axis->GetConstExtent(); + if (const_extent == -1) { + continue; + } + for (const auto &attr : it.second) { + axis->forbid_iso = true; + if (!attr.attr_value.empty()) { + auto share_time = static_cast(std::strtol(attr.attr_value.c_str(), nullptr, 10)); + auto whole_extent = const_extent * (share_time + 1); + axis->TileRestrainToSingleValue(CastInt64ToExpr(whole_extent), CACHE1); + axis->thread_constraints.map_min_ = 1; + axis->thread_constraints.map_extent_ = 1; + std::stringstream ss; + ss << "[MODSHIFT] axis " << axis->dim_axis << " tile: " << whole_extent + << ", map :" << axis->thread_constraints.map_extent_; + analyzer_->logger_->AppendLog(GPU_MAPPING, ss); + } + break; + } + } +} + void GemmStrategy::AddGpuConstraint() { if (!analyzer_->scop_info_.user_config_.GetEnableTensorCore()) { return; @@ -1395,8 +1421,6 @@ void DynamicBoundStrategy::AddGpuConstraint() {} void ShiftAxisStrategy::AddGpuConstraint() {} -void ModShiftAxisStrategy::AddGpuConstraint() {} - void ConvStrategy::AddGpuConstraint() {} // end of null constraint