Browse Source

!45 MindTrick: Constraints Injection for AKG Polyhedral Scheduler

From: @harenome
Reviewed-by: @dylangeng
Signed-off-by:
pull/45/MERGE
mindspore-ci-bot Gitee 5 years ago
parent
commit
f63dbff61b
64 changed files with 7478 additions and 77 deletions
  1. +1
    -1
      cmake/external_libs/isl.cmake
  2. +2
    -0
      src/api/api_pass.cc
  3. +7
    -0
      src/codegen/build_module.cc
  4. +1
    -0
      src/codegen/util.h
  5. +2
    -0
      src/include/ir_pass.h
  6. +730
    -0
      src/pass/swizzle_gpu.cc
  7. +1
    -0
      src/poly/dsa_mgr_strategy.cc
  8. +6
    -1
      src/poly/dump_log.cc
  9. +1
    -0
      src/poly/gpu_mgr_strategy.cc
  10. +891
    -0
      src/poly/isl_influence.cc
  11. +110
    -0
      src/poly/isl_influence.h
  12. +1019
    -0
      src/poly/isl_util.cc
  13. +184
    -0
      src/poly/isl_util.h
  14. +90
    -0
      src/poly/log_util.cc
  15. +135
    -0
      src/poly/log_util.h
  16. +6
    -0
      src/poly/pass_mgr_strategy.h
  17. +4
    -0
      src/poly/schedule_pass.h
  18. +454
    -0
      src/poly/schedule_pass/constrain_schedule.cc
  19. +115
    -0
      src/poly/schedule_pass/constrain_schedule.h
  20. +237
    -0
      src/poly/schedule_pass/constrain_schedule_checking.cc
  21. +1328
    -0
      src/poly/schedule_pass/scheduling_mind_trick.cc
  22. +303
    -0
      src/poly/schedule_pass/scheduling_mind_trick.h
  23. +342
    -0
      src/poly/schedule_pass/scheduling_mind_trick_matching.cc
  24. +6
    -0
      src/poly/schedule_pass_gpu/mapping_outer_band.cc
  25. +53
    -5
      src/poly/schedule_pass_gpu/shared_memory_manager.cc
  26. +5
    -1
      src/poly/schedule_pass_gpu/shared_memory_manager.h
  27. +14
    -0
      src/poly/schedule_pass_mgr.cc
  28. +27
    -0
      src/poly/scop_info.h
  29. +3
    -0
      tests/operators/gpu/.gitignore
  30. +58
    -52
      tests/operators/gpu/test_all.py
  31. +6
    -3
      tests/operators/gpu/test_fused_bn_update.py
  32. +6
    -3
      tests/operators/gpu/test_fused_l2loss_grad.py
  33. +5
    -2
      tests/operators/gpu/test_fused_relu_grad.py
  34. +14
    -0
      tests/st/composite/test_composite_json.py
  35. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_AddN_fusion_9584919353229493170.info
  36. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806.info
  37. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231.info
  38. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692.info
  39. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484.info
  40. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880.info
  41. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_GkDropout_2353362030752466006.info
  42. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Sub_Mul_Mul_split_9258187064108063363.info
  43. +1
    -0
      tests/st/ops/gpu/mind-trick_cases/operators/Fused_Transpose_split_18185609042134105765.info
  44. +20
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_AddN_fusion_9584919353229493170.json
  45. +22
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806.json
  46. +22
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231.json
  47. +57
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692.json
  48. +57
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484.json
  49. +57
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880.json
  50. +19
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_GkDropout_2353362030752466006.json
  51. +37
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Transpose_split_18185609042134105765.json
  52. +7
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/fused_bn_update.json
  53. +34
    -0
      tests/st/ops/gpu/mind-trick_cases/tricks/fused_relu_grad.json
  54. +9
    -0
      tests/st/ops/gpu/test_all.py
  55. +5
    -2
      tests/st/ops/gpu/test_fused_bn_update.py
  56. +5
    -2
      tests/st/ops/gpu/test_fused_relu_grad.py
  57. +159
    -0
      tests/st/ops/gpu/test_ms_mindtricks.py
  58. +2
    -2
      tests/test_env.sh
  59. +5
    -1
      third_party/incubator-tvm/include/tvm/ir.h
  60. +15
    -0
      third_party/incubator-tvm/python/tvm/contrib/nvcc.py
  61. +170
    -2
      third_party/incubator-tvm/src/codegen/codegen_cuda.cc
  62. +25
    -0
      third_party/incubator-tvm/src/codegen/codegen_cuda.h
  63. +3
    -0
      third_party/incubator-tvm/src/lang/ir.cc
  64. +573
    -0
      third_party/patch/isl/isl-influence.patch

+ 1
- 1
cmake/external_libs/isl.cmake View File

@@ -17,7 +17,7 @@ akg_add_pkg(isl
URL ${ISL_URL}
MD5 ${ISL_MD5}
CUSTOM_CMAKE ${AKG_SOURCE_DIR}/third_party/isl_wrap
PATCHES ${AKG_SOURCE_DIR}/third_party/patch/isl/isl.patch
PATCHES ${AKG_SOURCE_DIR}/third_party/patch/isl/isl.patch ${AKG_SOURCE_DIR}/third_party/patch/isl/isl-influence.patch
CMAKE_OPTION " ")
include_directories("${AKG_SOURCE_DIR}/third_party/isl_wrap/include")
include_directories("${isl_INC}/include")


+ 2
- 0
src/api/api_pass.cc View File

@@ -85,5 +85,7 @@ REGISTER_PASS(SubstituteDivVar);
REGISTER_PASS(UnrollNonConstantExtent)
REGISTER_PASS(ValueNumbering);
REGISTER_PASS(TensorAccessRewrite);
REGISTER_PASS(SwizzleGPU);

} // namespace ir
} // namespace akg

+ 7
- 0
src/codegen/build_module.cc View File

@@ -564,6 +564,13 @@ NodeRef LowerStmt(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeR
stmt = NEXT_PASS(InjectDoubleBuffer, stmt, config->double_buffer_split_loop,
global_attrs.GetBoolAttr(kEnableTransferBuffer, false));
stmt = NEXT_PASS(StorageRewrite, stmt);

if (target_platform->device_type == kDLGPU && polyhedral) {
if (global_attrs.GetBoolAttr(kEnableSwizzleGPU, true)) {
stmt = NEXT_PASS(SwizzleGPU, stmt, global_attrs);
}
}

stmt = NEXT_PASS(UnrollLoop, stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
config->auto_unroll_max_extent, config->unroll_explicit);



+ 1
- 0
src/codegen/util.h View File

@@ -102,6 +102,7 @@ constexpr auto kAllocBits = "alloc_bits";
constexpr auto kEnablePolySch = "enable_poly_sch";
constexpr auto kEnableFuseAxis = "enable_fuse_axis";
constexpr auto kEnableAtomicAdd = "enable_atomic_add";
constexpr auto kEnableSwizzleGPU = "enable_swizzle_gpu";

static std::unordered_map<std::string, int> help_tiling_level = {
{"None", 0},


+ 2
- 0
src/include/ir_pass.h View File

@@ -206,6 +206,8 @@ Stmt ValueNumbering(Stmt stmt);

Stmt TensorAccessRewrite(const Stmt stmt);

Stmt SwizzleGPU(const Stmt &stmt, const Map<std::string, NodeRef> &attrs);

} // namespace ir
} // namespace akg



+ 730
- 0
src/pass/swizzle_gpu.cc View File

@@ -0,0 +1,730 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithm>
#include <cstdlib>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm.h>

#include "pass/utils.h"
#include "src/common/util.h"

namespace akg {
namespace ir {

class SwizzleFinder : public IRVisitor {
public:
SwizzleFinder() = default;

~SwizzleFinder() override = default;

void Visit_(const AttrStmt *op) final {
if (op->attr_key == air::ir::attr::thread_extent) {
if (auto value = op->value.as<IntImm>()) {
std::string name = op->node.as<IterVarNode>()->var->name_hint;
LOG(DEBUG) << "Thread extent (" << name << ") : " << value->value;
thread_extent[name] = value->value;
}
IRVisitor::Visit_(op);
} else if (op->attr_key == air::ir::attr::realize_scope) {
LOG(WARNING) << "Realize storage scope not implemented in swizzle pass (may not work as expected) : "
<< op->value.as<StringImm>()->value;
Visit(op->body);
} else {
IRVisitor::Visit_(op);
}
}

void Visit_(const For *op) final {
swizzle_length = GetExtent(op);
if (swizzle_length == 2 || swizzle_length == 4) {
LOG(INFO) << "Swizzle for loop ? var : " << op->loop_var->name_hint;

loop_var = op->loop_var;
swizzlable = true;
swizzle_candidate = true;
loop_loads = {};
loop_stores = {};
Visit(op->body);
if (!swizzle_candidate) {
// loop contains another loop
LOG(DEBUG) << op->loop_var->name_hint << " not swizzlable : loop contains another loop";
swizzlable = false;
return;
}
swizzle_candidate = false;
// get all load and store from this loop and test if their range is 2, 4 or 0 (const)
for (auto l : loop_loads) {
int ext = load_indexes[l].second - load_indexes[l].first;
if (ext != (int)swizzle_length && ext != 0) {
swizzlable = false;
LOG(DEBUG) << op->loop_var->name_hint
<< " not swizzlable : range of load is not 0, 2 or 4 : " << ExprToString(load_indexes[l].first);
break;
} else if (ext == 0) {
// do not swizzle variables that are constant inside loop
if (swizzle_temp_vars.count(l->buffer_var->name_hint) > 0 &&
set_temp_vars.count(l->buffer_var->name_hint) == 0) {
swizzle_temp_vars.erase(l->buffer_var->name_hint);
temp_vars.insert(l->buffer_var->name_hint);
}
}
}
for (auto s : loop_stores) {
int ext = store_indexes[s].second - store_indexes[s].first;
if (ext != (int)swizzle_length && ext != 0) {
swizzlable = false;
LOG(DEBUG) << op->loop_var->name_hint
<< " not swizzlable : range of store is not 0, 2 or 4 : " << ExprToString(store_indexes[s].first);
break;
} else if (ext == 0) {
// cannot swizzle if variable is an argument with range 0 (disable loop swizzle for safety)
if (swizzle_temp_vars.count(s->buffer_var->name_hint) == 0) {
swizzlable = false;
LOG(DEBUG) << op->loop_var->name_hint << " not swizzlable : range of store is 0 on variable : "
<< ExprToString(store_indexes[s].first);
break;
}

// check if store value contains var
// Warning : this check does not verify loop variables that are not in a load
if (!store_loads[s].empty() && swizzle_stores.count(s) == 0) {
// Variable value changes at each loop iteration
swizzle_stores.insert(s);
} else {
// remove from swizzle_temp_vars
swizzle_temp_vars.erase(s->buffer_var->name_hint);
temp_vars.insert(s->buffer_var->name_hint);
}
}
}
if (swizzlable) {
swizzle_loops.push_back(op);
}

} else {
Visit(op->body);
}
}

// check potential temp variable to convert, making sure it only contains the loop var
template <typename T>
void checkSwizzleVar(const T *op, air::DataType t) {
if (swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) {
auto *v = op->index.template as<Variable>();
auto *i = op->index.template as<IntImm>();
if (((v == nullptr || v->name_hint != loop_var->name_hint) && (i == nullptr || i->value != 0)) ||
!((t.is_float() || t.is_int()) && t.bits() <= 32)) {
// this temp var does not look like a swizzle var, treat it as any regular temp var
LOG(DEBUG) << "Irregular potential swizzle var : " << op->buffer_var->name_hint;
temp_vars.insert(op->buffer_var->name_hint);
swizzle_temp_vars.erase(op->buffer_var->name_hint);
} else {
if (i && i->value == 0) {
if (var_size.find(op->buffer_var->name_hint) == var_size.end() ||
var_size[op->buffer_var->name_hint] < (int)swizzle_length) {
var_size[op->buffer_var->name_hint] = swizzle_length;
}
}
}
}
}

void Visit_(const Load *op) final {
if (swizzle_candidate && swizzlable) {
LOG(DEBUG) << "Load : " << op->buffer_var->name_hint << " index : " << ExprToString(op->index);
loop_loads.push_back(op);
compute_var = true;
current_min = 1;
current_max = 1;

checkSwizzleVar(op, op->type);

IRVisitor::Visit(op->index);
// add this load to array
compute_var = false;
if (current_min > current_max) std::swap(current_min, current_max);
load_indexes.insert(std::make_pair(op, std::make_pair(current_min, current_max)));
if (current_store != nullptr &&
((contains_iterator && (op->type.is_float() || op->type.is_int()) && op->type.bits() <= 32) ||
swizzle_temp_vars.count(op->buffer_var->name_hint) > 0)) {
LOG(DEBUG) << "Load contains iterator or swizzle temp var";
store_loads[current_store].insert(op);
}
contains_iterator = false;
LOG(DEBUG) << "End Load : " << op->buffer_var->name_hint << " range estimation : " << current_max - current_min;
IRVisitor::Visit(op->predicate);
}
IRVisitor::Visit_(op);
}

// visit each operator for each store and load
// compute value for each of its elements (a, b or array)
// add map<Store, pair> to evaluate extent inside each load/store
void Visit_(const Store *op) final {
if (swizzle_candidate && swizzlable) {
LOG(DEBUG) << "Store : " << op->buffer_var->name_hint << " index : " << ExprToString(op->index)
<< " value : " << ExprToString(op->value);
loop_stores.push_back(op);
compute_var = true;
contains_iterator = false;
current_store = op;
store_loads[op] = {};
current_min = 1;
current_max = 1;
checkSwizzleVar(op, op->value.type());
if (swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) {
set_temp_vars.insert(op->buffer_var->name_hint);
}

IRVisitor::Visit(op->index);
// check for single Load (allows inplace swizzle)
const Load *ld = op->value.as<Load>();
if (ld && ld->type == op->value.type() && (ld->type.is_float() || ld->type.is_int()) && ld->type.bits() <= 32) {
LOG(DEBUG) << "Single store Value : " << op->value << std::endl
<< "ld : " << ld->buffer_var << "[" << ld->index << "]";
single_stores.insert(current_store);
}
// add this store to array
compute_var = false;
if (current_min > current_max) std::swap(current_min, current_max);
store_indexes.insert(std::make_pair(op, std::make_pair(current_min, current_max)));
if (contains_iterator && (op->value.type().is_float() || op->value.type().is_int()) &&
op->value.type().bits() <= 32) {
swizzle_stores.insert(op);
}
contains_iterator = false;

LOG(DEBUG) << "End Store index : " << op->buffer_var->name_hint
<< " range estimation : " << current_max - current_min;
}
IRVisitor::Visit_(op);
current_store = nullptr;
LOG(DEBUG) << "End Store " << op->buffer_var->name_hint;
}

void Visit_(const Allocate *op) final {
int size = op->constant_allocation_size();
LOG(DEBUG) << "Allocate : " << op->buffer_var->name_hint << " size " << size;
if (size == 1 || size == 2 || size == 4) {
swizzle_temp_vars.insert(op->buffer_var->name_hint);
var_size[op->buffer_var->name_hint] = size;
} else {
temp_vars.insert(op->buffer_var->name_hint);
}
IRVisitor::Visit_(op);
}

void Visit_(const FloatImm *op) final {
if (compute_var) {
current_min = (int)op->value;
current_max = current_min;
}
IRVisitor::Visit_(op);
}

void Visit_(const IntImm *op) final {
if (compute_var) {
current_min = op->value;
current_max = current_min;
}
IRVisitor::Visit_(op);
}

void Visit_(const Variable *op) final {
if (swizzle_candidate && swizzlable && compute_var) {
if (op->name_hint == loop_var->name_hint) {
contains_iterator = true;
current_min = 0;
current_max = (int)swizzle_length;
} else {
auto it = thread_extent.find(op->name_hint);
if (it != thread_extent.end()) {
current_min = it->second;
current_max = current_min;
} else {
// unknown variable value, consider it constant
current_min = 1;
current_max = 1;
}
}
}
IRVisitor::Visit_(op);
}

void Visit_(const Mul *op) final {
if (swizzle_candidate && swizzlable) {
if (compute_var) {
Visit(op->a);
int tmp_min = current_min, tmp_max = current_max;

Visit(op->b);

int tmp = std::min(std::min(current_min * tmp_min, current_max * tmp_min),
std::min(current_min * tmp_max, current_max * tmp_max));
current_max = std::max(std::max(current_min * tmp_min, current_max * tmp_min),
std::max(current_min * tmp_max, current_max * tmp_max));
current_min = tmp;

return;
}
}
IRVisitor::Visit_(op);
}

void Visit_(const Add *op) final {
if (swizzle_candidate && swizzlable && compute_var) {
Visit(op->a);
int tmp_min = current_min, tmp_max = current_max;

Visit(op->b);
int tmp = std::min(std::min(current_min + tmp_min, current_max + tmp_min),
std::min(current_min + tmp_max, current_max + tmp_max));
current_max = std::max(std::max(current_min + tmp_min, current_max + tmp_min),
std::max(current_min + tmp_max, current_max + tmp_max));
current_min = tmp;
return;
}
IRVisitor::Visit_(op);
}

void Visit_(const Sub *op) final {
if (swizzle_candidate && swizzlable && compute_var) {
Visit(op->b);
int tmp_min = current_min, tmp_max = current_max;

Visit(op->a);
int tmp = std::min(std::min(current_min - tmp_min, current_max - tmp_min),
std::min(current_min - tmp_max, current_max - tmp_max));
current_max = std::max(std::max(current_min - tmp_min, current_max - tmp_min),
std::max(current_min - tmp_max, current_max - tmp_max));
current_min = tmp;
return;
}
IRVisitor::Visit_(op);
}

void Visit_(const Div *op) final {
if (swizzle_candidate && swizzlable) {
if (compute_var) {
Visit(op->a);
int tmp_min = current_min, tmp_max = current_max;

Visit(op->b);

if (current_min == 0 || current_max == 0) {
swizzlable = false;
LOG(WARNING) << "Possible division by zero detected in " << getDebugInfo(op);
}
int tmp = std::min(std::min(tmp_min / current_min, tmp_min / current_max),
std::min(tmp_max / current_min, tmp_max / current_max));
current_max = std::max(std::max(tmp_min / current_min, tmp_min / current_max),
std::max(tmp_max / current_min, tmp_max / current_max));
current_min = tmp;

return;
}
}
IRVisitor::Visit_(op);
}

void Visit_(const IfThenElse *op) final {
if (swizzle_candidate) {
swizzlable = false;
}
IRVisitor::Visit_(op);
}

// returns the extent of the loop if it's a constant integer, otherwise return -1
static int GetExtent(const For *op) {
// constant folding.
Expr extent = Simplify(op->extent);
const auto *v1 = extent.as<IntImm>();
const auto *v2 = extent.as<UIntImm>();
int value = -1;
if (v1 != nullptr) {
value = static_cast<int>(v1->value);
}
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
return value;
}

std::set<const Store *> single_stores{};
std::vector<const For *> swizzle_loops{};
std::set<std::basic_string<char>> temp_vars{};
std::set<std::basic_string<char>> swizzle_temp_vars{};
std::set<const Load *> input_arguments{};
std::set<const Store *> output_arguments{};
std::set<const Store *> swizzle_stores{};
std::set<const Load *> swizzle_loads{};
bool swizzlable{false};
std::map<const Store *, std::set<const Load *>> store_loads;
std::map<std::basic_string<char>, int> var_size{};

private:
// print a nice representation of what is happening inside the code
template <typename T>
std::string getDebugInfo(T *op) {
Expr a = (Expr)(op->a);
Expr b = (Expr)(op->b);
std::string a_str, b_str;
if (!a.defined()) {
a_str = a->GetTypeKey();
} else {
a_str = "(" + a->GetTypeKey() + ") " + ExprToString(a);
}
if (!b.defined()) {
b_str = b->GetTypeKey();
} else {
b_str = "(" + b->GetTypeKey() + ") " + ExprToString(b);
}
return a_str + " " + b_str;
}

const Store *current_store{};
bool compute_var = false;
int current_min = 0;
int current_max = 0;
bool swizzle_candidate{false};
bool contains_iterator{false};
unsigned int swizzle_length{0};
Var loop_var;
// temp vars that are set inside a loop
std::set<std::basic_string<char>> set_temp_vars{};
std::vector<const Load *> loop_loads;
std::vector<const Store *> loop_stores;
std::unordered_map<const Load *, std::pair<int, int>> load_indexes{};
std::unordered_map<const Store *, std::pair<int, int>> store_indexes{};
std::map<std::string, int64_t> thread_extent = {{"blockIdx.x", 0}, {"threadIdx.x", 0}};
};

/**
 * Mutator companion of SwizzleFinder.  Rewrites the loops selected by the
 * finder: the loop itself is re-emitted with ForType::Swizzled, stores with
 * a single compatible load become __ldg / reinterpret_cast intrinsic calls,
 * and loop-dependent stores are rewritten into vector (Broadcast/Ramp)
 * accesses through freshly declared "sw_" vector variables.
 * Non-selected constant-extent serial loops are marked Unrolled.
 */
class Swizzle : public IRMutator {
 public:
  explicit Swizzle(std::basic_string<char> name) : finder() { kernel_name = std::move(name); };

  ~Swizzle() override = default;

  // Runs the finder first; mutates only when at least one swizzlable loop
  // was found and the mutation actually changed the tree.
  Stmt VisitAndMutate(Stmt stmt) {
    LOG(DEBUG) << "Visit statement";
    finder.Visit(stmt);
    if (!finder.swizzle_loops.empty()) {
      auto ret = Mutate(stmt);
      if (!ret.same_as(stmt)) {
        LOG(INFO) << "Total swizzled loops for " << kernel_name << " : " << finder.swizzle_loops.size();
        return ret;
      }
    }
    LOG(INFO) << "Total swizzled loops for " << kernel_name << " : 0";
    return stmt;
  }

  // Selected loops are swizzled; other constant-extent serial loops are
  // re-emitted as Unrolled.
  Stmt Mutate_(const For *op, const Stmt &s) final {
    auto f = std::find(std::begin(finder.swizzle_loops), std::end(finder.swizzle_loops), op);
    if (f != std::end(finder.swizzle_loops)) {
      swizzling = true;
      LOG(DEBUG) << "Swizzle " << op->loop_var->name_hint;
      Stmt s2 = swizzle_loop(op, s);
      swizzling = false;
      return s2;
    }

    auto body = Mutate(op->body);
    // auto unroll loops
    ForType t = op->for_type;
    int ext = SwizzleFinder::GetExtent(op);
    if (ext > 0 && t == ForType::Serial) t = ForType::Unrolled;
    return For::make(op->loop_var, op->min, op->extent, t, op->device_api, body);
  }

  // Mutates the body of a selected loop and re-emits it as ForType::Swizzled.
  // Falls back to plain unrolling when the extent cannot be evaluated.
  Stmt swizzle_loop(const For *op, const Stmt &s) {
    auto min = op->min.as<IntImm>();
    loop_extent = SwizzleFinder::GetExtent(op);
    if (min && loop_extent > 0) {
      currentLoop = op;
      auto body = Mutate(op->body);
      // remove and unroll the loop
      return For::make(op->loop_var, op->min, op->extent, ForType::Swizzled, op->device_api, body);
    }

    // something wrong happened during extent evaluation, we do not mutate
    LOG(WARNING) << "Could not mutate loop (invalid loop extent)";
    ForType t = op->for_type;
    if (loop_extent > 0 && t == ForType::Serial) t = ForType::Unrolled;
    return For::make(op->loop_var, op->min, op->extent, t, op->device_api, op->body);
  }

  // Core of the rewrite.  Two cases while inside a swizzled loop:
  //  1. single-load stores ("in place"): replaced by one ldg /
  //     reinterpret_cast intrinsic store with the loop var substituted by 0;
  //  2. loop-dependent stores: the value is broadcast into a vector variable
  //     (Broadcast/Ramp store), every feeding load gets a LetStmt-bound
  //     "sw_" vector variable initialized via ldg, and (for non-temp
  //     destinations) a final reinterpret_cast store writes the vector back.
  // The LetStmt/AttrStmt nesting is built innermost-first; the order of the
  // statements below is load-bearing.
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (swizzling) {
      loop_extent = SwizzleFinder::GetExtent(currentLoop);
      if (std::find(finder.single_stores.begin(), finder.single_stores.end(), op) != finder.single_stores.end()) {
        // this store has only one load attached, we can swizzle in place
        LOG(DEBUG) << "Swizzle in place " << op->buffer_var->name_hint;

        // modify store value
        Expr value;
        auto buffer_var = op->buffer_var;
        auto v = air::ir::Substitute(op->value, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}});
        v = Simplify_cce(v);
        Array<Expr> value_args{make_const(Int(32), loop_extent), v};
        // check if variable is a temp var or an output var
        if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) {
          // (swizzle temp var) sw_var ... ldg
          // check temp variable is declared
          CHECK(replace_vars.find(op->buffer_var->name_hint) != replace_vars.end());
          buffer_var = replace_vars[op->buffer_var->name_hint];
          value = Call::make(op->value.type(), Call::ldg, value_args, Call::Intrinsic);
          Stmt s2 = Store::make(buffer_var, value, op->index, op->predicate);
          return AttrStmt::make(s, "simple_store", Expr(0), s2);
        } else if (finder.temp_vars.count(op->buffer_var->name_hint) > 0) {
          // (temp var) reinterpret cast ... ldg
          value = Call::make(op->value.type(), Call::ldg, value_args, Call::Intrinsic);
        } else {
          // (output var) reinterpret cast ... reinterpret cast
          value = Call::make(op->value.type(), Call::reinterpret_cast_op, value_args, Call::Intrinsic);
        }

        auto index = air::ir::Substitute(op->index, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}});
        index = Simplify_cce(index);
        Stmt s2 = Store::make(buffer_var, value, index, op->predicate);
        return AttrStmt::make(s, "reinterpret_store", Expr(loop_extent), s2);
      } else {
        // store with != 1 load
        Expr value = Mutate(op->value);
        LOG(DEBUG) << "check type: " << op->value.type() << " " << op->value;
        CHECK(op->value.type().is_float() || op->value.type().is_int());
        CHECK_LE(op->value.type().bits(), 32);
        if (std::find(finder.swizzle_stores.begin(), finder.swizzle_stores.end(), op) != finder.swizzle_stores.end()) {
          LOG(DEBUG) << "Mutate Store : index contains loop var " << currentLoop->loop_var->name_hint << std::endl
                     << "Value :" << op->value;

          Stmt new_stmt, end_stmt;
          Expr index;
          Array<Expr> value_args;

          // Reuse the vector replacement variable if this buffer already has
          // one; otherwise build the vector type (e.g. Float(16, 4) -> half4)
          // and a fresh "sw_" variable.
          air::DataType t;
          Var new_var;
          if (replace_vars.find(op->buffer_var->name_hint) != replace_vars.end()) {
            new_var = replace_vars[op->buffer_var->name_hint];
            t = new_var.type();
          } else {
            if (op->value.type().is_int()) {
              t = Int(op->value.type().bits(), loop_extent);
            } else {
              // Float(16, 4) -> half4
              // 1. Generate the right DataType -> half4
              t = Float(op->value.type().bits(), loop_extent);
            }
            new_var = Variable::make(t, "sw_" + op->buffer_var->name_hint);
          }

          LOG(DEBUG) << "(Vector store) replace previous buffer var : " << op->buffer_var
                     << " type : " << op->buffer_var->type << " with " << new_var << " type : " << new_var->type;
          Expr new_value = Broadcast::make(value, loop_extent);
          index = Ramp::make(0, 1, loop_extent);
          Expr predicate = Broadcast::make(Expr(1), loop_extent);

          new_stmt = Store::make(new_var, new_value, index, predicate);
          new_stmt = AttrStmt::make(s, "vec_store", Expr(currentLoop->loop_var), new_stmt);

          // do not reinterpret cast if var is in swizzle_temp_vars
          if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) == 0) {
            // last statement (store) : set value to initial value
            index = air::ir::Substitute(op->index, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}});
            index = Simplify_cce(index);
            value_args = {make_const(Int(32), 0), new_var};
            Expr reinterpret = Call::make(op->value.type(), Call::reinterpret_cast_op, value_args, Call::Intrinsic);
            // replace with new_var to remove second reinterpret (produces type error)
            end_stmt = Store::make(op->buffer_var, reinterpret, index, op->predicate);
            end_stmt = AttrStmt::make(new_stmt, "reinterpret_store", Expr(loop_extent), end_stmt);
          }

          // declare load variables
          LOG(DEBUG) << "Mutate Load : index contains loop var " << currentLoop->loop_var->name_hint
                     << ", Loop extent : " << loop_extent;
          for (auto ld : finder.store_loads[op]) {
            if (std::find(finder.temp_vars.begin(), finder.temp_vars.end(), ld->buffer_var->name_hint) ==
                finder.temp_vars.end()) {
              // declare swizzle variable

              Var new_var2;
              air::DataType t2;
              if (replace_vars.find(ld->buffer_var->name_hint) == replace_vars.end()) {
                if (ld->buffer_var->name_hint == op->buffer_var->name_hint) {
                  // this variable is both stored AND loaded for the first time
                  t2 = t;
                  new_var2 = new_var;
                } else {
                  if (ld->type.is_int()) {
                    t2 = Int(ld->type.bits(), loop_extent);
                  } else {
                    // Float(16, 4) -> half4
                    // 1. Generate the right DataType -> half4
                    t2 = Float(ld->type.bits(), loop_extent);
                  }
                  new_var2 = Variable::make(t2, "sw_" + ld->buffer_var->name_hint);
                }
                replace_vars[ld->buffer_var->name_hint] = new_var2;
                declared.insert(ld->buffer_var->name_hint);

                // 2. declaration half4 sw_T_where = __ldg( ... )
                index = air::ir::Substitute(ld->index, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}});
                index = Simplify_cce(index);
                Expr expr_ld = Load::make(ld->type, ld->buffer_var, index, ld->predicate);
                value_args = {make_const(Int(32), loop_extent), expr_ld};
                Expr ldg_input = Call::make(t2, Call::ldg, value_args, Call::Intrinsic);

                LOG(DEBUG) << "(Vector load) replace previous buffer var : " << ld->buffer_var
                           << " type : " << ld->buffer_var->type << " with " << new_var2
                           << " type : " << new_var2->type;

                new_stmt = LetStmt::make(new_var2, ldg_input, new_stmt);
                new_stmt = AttrStmt::make(s, "vec_load", ld->buffer_var, new_stmt);
              } else {
                new_stmt = AttrStmt::make(s, "vec_load", ld->buffer_var, new_stmt);
              }
            }
          }

          // declare store variable
          if (std::find(declared.begin(), declared.end(), op->buffer_var->name_hint) == declared.end()) {
            // declare swizzle variable
            declared.insert(op->buffer_var->name_hint);
            new_stmt = LetStmt::make(new_var, Broadcast::make(make_const(op->value.type(), 0), loop_extent), new_stmt);
            new_stmt = AttrStmt::make(s, "no_init_value", Expr(0), new_stmt);
          }

          if (replace_vars.find(op->buffer_var->name_hint) == replace_vars.end()) {
            replace_vars[op->buffer_var->name_hint] = new_var;
          }

          // For non-temp destinations append the write-back store built above.
          if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) == 0) {
            Stmt new_block = Block::make(new_stmt, end_stmt);
            return new_block;
          } else {
            return new_stmt;
          }
        }
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  // Allocations of swizzle temp vars keep their original allocation but get
  // a zero-initialized vector counterpart ("sw_" variable) bound via LetStmt
  // around the mutated body.
  Stmt Mutate_(const Allocate *op, const Stmt &s) final {
    if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) {
      // replace var with its swizzle counterpart
      int size = finder.var_size[op->buffer_var->name_hint];
      air::DataType t;
      if (op->type.is_int()) {
        t = Int(op->type.bits(), size);
      } else {
        t = Float(op->type.bits(), size);
      }
      Var new_var = Variable::make(t, "sw_" + op->buffer_var->name_hint);

      LOG(DEBUG) << "Allocate : replace previous buffer var : " << op->buffer_var << " type : " << op->type << " with "
                 << new_var << " type : " << new_var->type;

      replace_vars[op->buffer_var->name_hint] = new_var;
      declared.insert(op->buffer_var->name_hint);
      Stmt body = Mutate(op->body);
      Stmt new_stmt;
      Stmt let_stmt = LetStmt::make(new_var, Broadcast::make(make_const(op->type, 0), size), body);
      Stmt attr = AttrStmt::make(s, "no_init_value", Expr(0), let_stmt);
      Stmt new_allocate = Allocate::make(op->buffer_var, op->type, op->extents, const_false(), attr);

      return new_allocate;
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  SwizzleFinder finder;                                       // analysis results
  const For *currentLoop{};                                   // loop currently being swizzled
  bool swizzling = false;                                     // true inside swizzle_loop()
  int loop_extent{};                                          // extent of currentLoop
  std::set<std::basic_string<char>> declared{};               // buffers whose "sw_" var is already declared
  std::unordered_map<std::basic_string<char>, Var> replace_vars{};  // buffer name -> vector replacement
  std::basic_string<char> kernel_name;                        // for log messages only
};

static void ParseStringAttr(const Map<std::string, NodeRef> &attrs, const std::string &attr_name,
std::string *attr_to_set) {
CHECK(attr_to_set != nullptr);
if (attrs.count(attr_name) == 0) return;
const NodeRef &e = attrs.at(attr_name);
if (auto val = e.as<StringImm>()) {
*attr_to_set = val->value;
} else {
LOG(FATAL) << "Failed to parse attribute: " << attr_name << " = " << e << " as string";
}
}

static void ParseIntAttr(const Map<std::string, NodeRef> &attrs, const std::string &attr_name, int *attr_to_set) {
CHECK(attr_to_set != nullptr);
if (attrs.count(attr_name) == 0) return;
const NodeRef &e = attrs.at(attr_name);
if (auto i = e.as<IntImm>()) {
*attr_to_set = static_cast<int>(i->value);
} else if (auto ui = e.as<UIntImm>()) {
*attr_to_set = static_cast<int>(ui->value);
} else {
LOG(FATAL) << "Failed to parse attribute: " << attr_name << " = " << e << " as integer";
}
}

static void ParseBoolAttr(const Map<std::string, NodeRef> &attrs, const std::string &attr_name, bool *attr_to_set) {
const int invalid_value = -1;
int attr = invalid_value;
ParseIntAttr(attrs, attr_name, &attr);
if (attr != invalid_value) {
CHECK(attr == 0 || attr == 1) << "Bool attribute " << attr_name << " must be 0 or 1, but found "
<< attrs.at(attr_name);
*attr_to_set = static_cast<bool>(attr);
}
}

/**
 * Entry point of the SwizzleGPU pass.
 *
 * Runs the Swizzle mutator over `stmt` unless the pass is disabled, either
 * by the "disable_swizzle" attribute in `attrs` or by setting the
 * MS_AKG_DISABLE_SWIZZLE environment variable to "1".
 *
 * \param stmt  statement tree to transform
 * \param attrs build attributes; "disable_swizzle" (0/1) and "kernel_name"
 *              (string, used for logging) are consulted
 * \return the (possibly) swizzled statement tree
 */
Stmt SwizzleGPU(const Stmt &stmt, const Map<std::string, NodeRef> &attrs) {
  bool disable_swizzle = false;
  ParseBoolAttr(attrs, "disable_swizzle", &disable_swizzle);
  // Environment override: MS_AKG_DISABLE_SWIZZLE=1 force-disables the pass.
  // BUGFIX: compare via std::string instead of strcmp(); the original relied
  // on <cstring> being included transitively, which is fragile.
  if (const char *env_p = std::getenv("MS_AKG_DISABLE_SWIZZLE")) {
    if (std::string(env_p) == "1") {
      disable_swizzle = true;
    }
  }
  if (disable_swizzle) {
    LOG(INFO) << "SwizzleGPU pass disabled";
    return stmt;
  }
  std::string kernel_name_;
  ParseStringAttr(attrs, "kernel_name", &kernel_name_);
  if (kernel_name_.empty()) {
    LOG(WARNING) << "Kernel name not found !";
  } else {
    LOG(INFO) << "BEGIN_PASS SwizzleGPU on " << kernel_name_;
  }
  auto sw = Swizzle(kernel_name_);
  Stmt s = sw.VisitAndMutate(stmt);

  LOG(INFO) << "END_PASS";
  return s;
}

} // namespace ir
} // namespace akg

+ 1
- 0
src/poly/dsa_mgr_strategy.cc View File

@@ -51,6 +51,7 @@ void DsaMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared<Mem
void DsaMgrStrategy::RegisterPasses() {
passes_.clear();
RegisterNormalizationPasses();
RegisterConstrainedScheduling();
if (!scop_info_.user_config_.GetDisableGroup()) {
RegisterPass(std::make_shared<GroupStatements>(pass_info_));
}


+ 6
- 1
src/poly/dump_log.cc View File

@@ -36,7 +36,7 @@ void DumpSchTreeToFile(std::FILE *fp, const isl::schedule &sch) {

CHECK(sch.get());

printer = isl_printer_to_file(isl_schedule_ctx(sch.get()), fp);
printer = isl_printer_to_file(isl_schedule_get_ctx(sch.get()), fp);
printer = isl_printer_set_yaml_style(printer, ISL_YAML_STYLE_BLOCK);
printer = isl_printer_print_schedule(printer, sch.get());

@@ -438,6 +438,11 @@ void ScopInfo::DumpScopDataAdvanced(std::ofstream &of) {
}

void UserConfig::DumpScopDataScheduleAttrs(std::ofstream &of) {
if (constrained_scheduling_output_ != "") {
PrintHeader(of, "constrained scheduling");
of << constrained_scheduling_output_ << std::endl;
}

PrintHeader(of, "schedule attrs");
of << "dump_poly_dir : " << GetDumpPolyDir() << std::endl;



+ 1
- 0
src/poly/gpu_mgr_strategy.cc View File

@@ -36,6 +36,7 @@ void GPUMgrStrategy::RegisterMemPromPasses() {
void GPUMgrStrategy::RegisterPasses() {
passes_.clear();
RegisterNormalizationPasses();
RegisterConstrainedScheduling();
RegisterSchedulingPasses();
RegisterPass(std::make_shared<GpuDmaAnalysis>(scop_info_));
RegisterTilingPasses();


+ 891
- 0
src/poly/isl_influence.cc View File

@@ -0,0 +1,891 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/isl_influence.h"

// ISL
#include <isl/aff.h>
#include <isl/mat.h>
#include <isl/hash.h>
#include <isl/schedule.h>
#include <isl/schedule_node.h>
#include <isl/constraint.h>
#include <isl/map_to_basic_set.h>

// ISL private headers
#include <isl_int.h>
#include <isl_tab.h>
extern "C" {
#include <isl_aff_private.h>
#include <isl_mat_private.h>
#include <isl_schedule_constraints.h>
}

#include "poly/log_util.h"

// Mirror of isl's private scheduling-graph structure (isl_scheduler.c),
// extended at the end with the AKG influence lists.
// NOTE(review): the layout up to `max_weight` must stay identical to the
// struct in the patched isl build, since graph objects cross the isl/AKG
// boundary — do not reorder or remove fields.
struct isl_sched_graph {
  isl_map_to_basic_set *intra_hmap;
  isl_map_to_basic_set *intra_hmap_param;
  isl_map_to_basic_set *inter_hmap;

  struct isl_sched_node *node;  // array of `n` statement nodes
  int n;
  int maxvar;
  int max_row;
  int n_row;  // number of schedule rows computed so far

  int *sorted;

  int n_total_row;
  int band_start;

  struct isl_sched_graph *root;

  struct isl_sched_edge *edge;  // array of `n_edge` dependence edges
  int n_edge;
  int max_edge[isl_edge_last + 1];
  struct isl_hash_table *edge_table[isl_edge_last + 1];

  struct isl_hash_table *node_table;
  struct isl_trivial_region *region;

  isl_basic_set *lp;  // the LP problem the influence hooks constrain

  int src_scc;
  int dst_scc;

  int scc;
  int weak;

  int max_weight;
  /* AKG isl_influence patch - start */
  akg::ir::poly::isl_influence_list *inf_list;              // hard coefficient constraints
  akg::ir::poly::isl_influence_equal_list *inf_equal_list;  // soft equality constraints
  akg::ir::poly::isl_influence_sol_list *inf_sol_list;      // recorded per-dimension solutions
  /* AKG isl_influence patch - end */
};

///////////////////////////////////////////////////////////////////////////
// CODE
///////////////////////////////////////////////////////////////////////////

namespace akg {
namespace ir {
namespace poly {

// Install the AKG influence callbacks into the patched isl scheduler's
// global hooks so the LP solver consults them while computing a schedule.
static inline void akg_isl_influence_set_function_pointers(void) {
  isl_influence_set_coef = akg_isl_influence_set_coef;
  isl_influence_set_equal = akg_isl_influence_set_equal;
  isl_influence_maxvar = akg_isl_influence_maxvar;
  isl_influence_check_coincident = akg_isl_influence_check_coincident;
  isl_influence_sol_list_free = akg_isl_influence_sol_list_free;
  isl_influence_sol_add_elem = akg_isl_influence_sol_add_elem;
  isl_influence_sol_get_elem = akg_isl_influence_sol_get_elem;
}

// Detach the AKG influence callbacks from the patched isl scheduler's
// global hooks (counterpart of akg_isl_influence_set_function_pointers).
// Uses nullptr rather than the integer literal 0: clearer intent for
// function-pointer assignment in C++.
static inline void akg_isl_influence_unset_function_pointers(void) {
  isl_influence_set_coef = nullptr;
  isl_influence_set_equal = nullptr;
  isl_influence_maxvar = nullptr;
  isl_influence_check_coincident = nullptr;
  isl_influence_sol_list_free = nullptr;
  isl_influence_sol_add_elem = nullptr;
  isl_influence_sol_get_elem = nullptr;
}

// Turn on constraint injection: enables the AKG debug/influence isl options
// and installs the callback hooks for the next schedule computation.
void akg_isl_influence_enable(isl_ctx *ctx) {
  isl_options_set_akg_print_debug(ctx, 1);
  isl_options_set_akg_influence_scheduler(ctx, 1);
  akg_isl_influence_set_function_pointers();
}

// Turn off constraint injection: clears the AKG isl options and removes the
// callback hooks so subsequent scheduling runs use stock isl behavior.
void akg_isl_influence_disable(isl_ctx *ctx) {
  isl_options_set_akg_print_debug(ctx, 0);
  isl_options_set_akg_influence_scheduler(ctx, 0);
  akg_isl_influence_unset_function_pointers();
}

// Human-readable names for isl_influence_coeff_type values, indexed by enum
// value. NOTE(review): C-style designated array initializers are a GNU
// extension in C++; this relies on the project being built with GCC/Clang.
static const char *inf_types_str[] = {
  [isl_cst] = "isl_cst",
  [isl_param] = "isl_param",
  [isl_var] = "isl_var",
};

struct isl_sched_graph *akg_isl_influence_sol_list_free(isl_sched_graph *graph) {
log::Info(log::Verbosity::high, "Entering isl_influence_sol_list_free");

isl_influence_sol_list *inf = graph->inf_sol_list;
if (inf) {
log::Info(log::Verbosity::high, "inf-mem: " + std::to_string(inf->mem));
for (int i = 0; i < inf->mem; i++) {
isl_influence_sol *p = &inf->data[i];
p->mem = 0;
p->size = 0;
free(p->data);
p->data = NULL;
}
free(inf->data);
inf->data = NULL;
inf->size = 0;
inf->mem = 0;
free(inf);
}
graph->inf_sol_list = NULL;
log::Info(log::Verbosity::high, "Leaving isl_influence_sol_list_free");
return graph;
}
// Record one scheduler LP solution vector (one per computed schedule
// dimension) into graph->inf_sol_list, converting each isl_val to int, so
// later dimensions can reuse coefficients (see akg_isl_influence_sol_get_elem).
//
// Fixes over the previous version:
//  - a freshly allocated list is published to the graph immediately, so it
//    can no longer leak when a later allocation fails;
//  - realloc no longer overwrites the only pointer to the buffer (the old
//    buffer was leaked on realloc failure);
//  - slots grown by realloc are zero-initialized — the free routine walks
//    all `mem` slots and would otherwise free uninitialized pointers;
//  - isl_printer_get_str was called twice per element, leaking one
//    malloc'ed string per coefficient.
// On allocation failure the list is left in a consistent state and `graph`
// is returned unchanged (errors are only logged, matching the old policy).
struct isl_sched_graph *akg_isl_influence_sol_add_elem(isl_vec *sol, struct isl_sched_graph *graph) {
  log::Info(log::Verbosity::high, "Entering isl_influence_sol_add_elem");

  if (!sol) {
    return graph;
  }

  isl_influence_sol_list *inf = graph->inf_sol_list;
  if (!inf) {
    inf = (isl_influence_sol_list *)calloc(1, sizeof(isl_influence_sol_list));
    if (!inf) {
      log::Info(log::Verbosity::high,
                "MEM ERROR: isl_influence_sol_add_elem could not allocate memory for isl_influence_sol_list");
      return graph;
    }
    // Publish the (empty) list right away so it cannot leak on error paths.
    graph->inf_sol_list = inf;
  }

  if (inf->mem == inf->size) {
    const int new_mem = inf->mem + INF_BLOCK_SIZE;
    // realloc(NULL, n) behaves like malloc(n), so this also covers the very
    // first allocation. Keep the old buffer reachable if realloc fails.
    isl_influence_sol *new_data = (isl_influence_sol *)realloc(inf->data, new_mem * sizeof(isl_influence_sol));
    if (new_data == NULL) {
      log::Info(log::Verbosity::high,
                "MEM ERROR: isl_influence_sol_add_elem could not reallocate memory isl_influence_sol");
      return graph;
    }
    // Zero the new tail: akg_isl_influence_sol_list_free walks all `mem` slots.
    for (int k = inf->mem; k < new_mem; ++k) {
      new_data[k].size = 0;
      new_data[k].mem = 0;
      new_data[k].data = NULL;
    }
    inf->data = new_data;
    inf->mem = new_mem;
  }

  isl_influence_sol *p = &inf->data[inf->size];
  const int sol_size = isl_inf_vec_get_size(sol);
  p->data = (int *)calloc((size_t)sol_size, sizeof(int));
  if (p->data == NULL) {
    log::Info(log::Verbosity::high, "MEM ERROR: isl_influence_sol_add_elem could not allocate memory");
    return graph;
  }
  p->mem = sol_size;
  p->size = 0;

  isl_ctx *ctx = isl_vec_get_ctx(sol);
  for (int i = 0; i < p->mem; i++) {
    // Convert each coefficient by printing the isl_val and parsing the text
    // (same strategy as before, without the second leaked get_str call).
    isl_printer *printer = isl_printer_to_str(ctx);
    isl_val *const val = isl_influence_vec_get_elem(sol, i);
    printer = isl_printer_print_val(printer, val);

    char *const str = isl_printer_get_str(printer);
    p->data[i] = atoi(str);
    p->size++;

    free(str);
    isl_printer_free(printer);
    isl_val_free(val);
  }

  inf->size++;

  log::Info(log::Verbosity::high, "inf->mem: " + std::to_string(inf->mem));
  log::Info(log::Verbosity::high, "inf->size: " + std::to_string(inf->size));

  log::Info(log::Verbosity::high, "Leaving isl_influence_sol_add_elem");
  graph->inf_sol_list = inf;
  return graph;
}

// Fetch coefficient `pos` of the solution recorded for schedule dimension
// `dim`. Returns -1 when the list is missing or an index is out of range.
//
// Fixes over the previous version: when `dim` was out of range, the code
// logged an error but then fell through and dereferenced the NULL
// isl_influence_sol pointer; the second out-of-range message was built but
// never logged (and contained a typo).
int akg_isl_influence_sol_get_elem(int dim, int pos, struct isl_sched_graph *graph) {
  isl_influence_sol_list *inf = graph->inf_sol_list;

  if (NULL == inf) {
    return -1;
  }

  if (dim >= inf->size) {
    std::string message = "ERROR: isl_influence_sol_get_elem index out of range isl_influence_sol_list size : ";
    message += std::to_string(inf->size) + " < dim: " + std::to_string(dim);
    log::Info(log::Verbosity::high, message);
    return -1;
  }

  isl_influence_sol *p = &inf->data[dim];
  // NOTE(review): entries are read at offset 1 + pos — the first stored
  // value appears to be skipped (presumably the isl solution denominator);
  // confirm against the recording side in akg_isl_influence_sol_add_elem.
  if (1 + pos >= p->size) {
    std::string message = "ERROR: isl_influence_sol_get_elem index out of range isl_influence_sol size: ";
    message += std::to_string(p->size) + " < pos: " + std::to_string(pos);
    log::Info(log::Verbosity::high, message);
    return -1;
  }

  return p->data[1 + pos];
}

// Dump `set` to the high-verbosity log, preceded by the label `str`.
// A no-op unless the akg_print_debug isl option is enabled on the context.
void print_basic_set(isl_basic_set *set, const char *str) {
  isl_ctx *const ctx = isl_basic_set_get_ctx(set);
  if (!isl_options_get_akg_print_debug(ctx)) {
    return;
  }

  isl_printer *printer = isl_printer_to_str(ctx);
  printer = isl_printer_print_basic_set(printer, set);
  char *const text = isl_printer_get_str(printer);

  log::Info(log::Verbosity::high, str);
  log::Info(log::Verbosity::high, text);

  isl_printer_free(printer);
  free(text);
}

// Free a soft-constraint list together with the statement-name strings it
// owns. Always returns NULL so callers can reset their pointer in one line.
// Fix: the previous version dereferenced `inf_equal_list` before any NULL
// check; it is now safe to call with NULL (free(NULL) is already a no-op).
void *isl_influence_equal_list_free(isl_influence_equal_list *inf_equal_list) {
  if (inf_equal_list == NULL) {
    return NULL;
  }
  if (inf_equal_list->data) {
    for (int i = 0; i < inf_equal_list->size; ++i) {
      free(inf_equal_list->data[i].statement1);
      free(inf_equal_list->data[i].statement2);
    }
    free(inf_equal_list->data);
    inf_equal_list->data = NULL;
    inf_equal_list->mem = 0;
    inf_equal_list->size = 0;
  }
  free(inf_equal_list);
  return NULL;
}

// Free a hard-constraint list together with the statement-name strings it
// owns. Always returns NULL so callers can reset their pointer in one line.
// Fix: the previous version dereferenced `inf_list` before any NULL check;
// it is now safe to call with NULL (free(NULL) is already a no-op).
void *isl_influence_list_free(isl_influence_list *inf_list) {
  if (inf_list == NULL) {
    return NULL;
  }
  if (inf_list->data) {
    for (int i = 0; i < inf_list->size; ++i) {
      free(inf_list->data[i].statement_name);
    }
    free(inf_list->data);
    inf_list->data = NULL;
    inf_list->mem = 0;
    inf_list->size = 0;
  }
  free(inf_list);
  return NULL;
}

// Clamp coefficient `pos` of the LP coefficient set `coef` into [lb, ub]
// (callers pass lb == ub to pin the coefficient to an exact value).
// Only active when the akg_influence_scheduler option is set; positions
// outside the set's dimension count are silently ignored.
// NOTE(review): `msg` is currently unused in this function; it is kept for
// signature symmetry with create_constraint — confirm before removing.
isl_basic_set *hack_coefficients(isl_basic_set *coef, const char *msg, int pos, int lb, int ub) {
  isl_ctx *const ctx = isl_basic_set_get_ctx(coef);

  isl_val *const v_ub = isl_val_int_from_si(ctx, ub);
  isl_val *const v_lb = isl_val_int_from_si(ctx, lb);

  isl_size dim = isl_basic_set_n_dim(coef);
  log::Info(log::Verbosity::high, "pos: " + std::to_string(pos));

  if (pos < (int)dim) {
    const bool influence_schedule = isl_options_get_akg_influence_scheduler(ctx);
    if (influence_schedule) {
      // The *_bound_val calls consume their isl_val argument, hence the
      // copies; the originals are released below.
      coef = isl_basic_set_upper_bound_val(coef, isl_dim_set, pos, isl_val_copy(v_ub));
      coef = isl_basic_set_lower_bound_val(coef, isl_dim_set, pos, isl_val_copy(v_lb));

      std::string message = " -> i" + std::to_string(pos);
      message += " = (" + std::to_string(lb);
      message += ", " + std::to_string(ub) + ")";
      log::Info(log::Verbosity::high, message);
    }
  }

  isl_val_free(v_ub);
  isl_val_free(v_lb);

  return coef;
}

// Add the equality  i[pos1] - i[pos2] == 0  to the LP coefficient set,
// i.e. force the two schedule coefficients to take the same value.
// Only active when the akg_influence_scheduler option is set; `msg` labels
// the debug dump of the resulting set.
isl_basic_set *create_constraint(isl_basic_set *coef, const char *msg, int pos1, int pos2) {
  isl_ctx *ctx = isl_basic_set_get_ctx(coef);
  const bool influence_schedule = isl_options_get_akg_influence_scheduler(ctx);
  if (influence_schedule) {
    // isl_constraint_alloc_equality consumes the local space, hence the copy.
    isl_local_space *ls = isl_basic_set_get_local_space(coef);
    isl_constraint *c = isl_constraint_alloc_equality(isl_local_space_copy(ls));
    c = isl_constraint_set_coefficient_si(c, isl_dim_set, pos1, 1);
    c = isl_constraint_set_coefficient_si(c, isl_dim_set, pos2, -1);
    coef = isl_basic_set_add_constraint(coef, c);
    print_basic_set(coef, msg);

    log::Info(log::Verbosity::high, "--> i" + std::to_string(pos1) + " = i" + std::to_string(pos2));
    isl_local_space_free(ls);
  }

  return coef;
}

// Look up the scheduling-graph node whose schedule tuple is named `name`.
// On success returns graph->lp (the shared LP problem) and stores the node
// index in *node_index; returns NULL when no node matches, leaving
// *node_index untouched.
// Fix: isl_map_get_tuple_name may return NULL for unnamed tuples; the
// previous version passed that straight to strcmp (undefined behavior).
isl_basic_set *graph_find_basic_set_by_statement_name(struct isl_sched_graph *graph, const char *name,
                                                      int *node_index) {
  isl_basic_set *bset = NULL;

  for (int i = 0; i < graph->n && bset == NULL; ++i) {
    struct isl_sched_node *node = isl_sched_graph_get_node(graph, i);
    isl_map *ma = isl_sched_node_extract_schedule(node);
    const char *strstat = isl_map_get_tuple_name(ma, isl_dim_in);

    if (strstat != NULL && strcmp(strstat, name) == 0) {
      bset = graph->lp;
      *node_index = i;
    }

    isl_map_free(ma);
  }
  return bset;
}

// Translate (constraint type, coefficient index) into a column position in
// the scheduler's LP problem for `node`. Returns -1 for unknown types.
// NOTE(review): the parameter/variable offset formulas mirror the column
// layout of the patched isl LP (cst | params | var +/- pairs) — confirm
// against the isl patch if the LP layout ever changes.
// NOTE(review): `bset` is currently unused; kept for call-site symmetry.
int get_pos_from_bset(isl_basic_set *bset, struct isl_sched_node *node, isl_influence_equal *inf, int coef_dim) {
  int pos = -1;

  switch (inf->type) {
    case isl_cst:
      // Constant-term column.
      pos = isl_sched_node_cst_coef_offset(node);
      break;
    case isl_param:
      // Parameter coefficient columns sit just before the constant term.
      pos = isl_sched_node_par_coef_offset(node) - coef_dim;
      break;
    case isl_var:
      // Iteration variables use (plus, minus) column pairs, hence the * 2.
      pos = isl_sched_node_par_coef_offset(node) - isl_sched_node_get_nparam(node) - coef_dim * 2 - 1;
      break;
    default:
      break;
  }

  log::Info(log::Verbosity::high,
            "coefficient position for coef_dim=" + std::to_string(coef_dim) + ": " + std::to_string(pos));

  return pos;
}

// Log every hard (coef) and soft (equal) influence constraint at high
// verbosity. Either list may be NULL, in which case it is treated as empty.
void print_influence(isl_influence_equal_list *equal, isl_influence_list *coef) {
  const int n_coef = coef ? coef->size : 0;
  const int n_equal = equal ? equal->size : 0;

  log::Info(log::Verbosity::high, "start printing isl_influence hard constraints");
  for (int idx = 0; idx < n_coef; idx++) {
    const isl_influence &inf = coef->data[idx];
    log::Info(log::Verbosity::high,
              "hard constraint\t\t[+" + std::to_string(idx + 1) + ":" + std::to_string(coef->size) + "]");
    log::Info(log::Verbosity::high, "statement:\t\t" + std::string(inf.statement_name));
    log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf.type]));
    log::Info(log::Verbosity::high, "sched_dim:\t\t" + std::to_string(inf.sched_dim));
    log::Info(log::Verbosity::high, "coef_dim:\t\t" + std::to_string(inf.coef_dim));
    log::Info(log::Verbosity::high, "val:\t\t\t" + std::to_string(inf.val));
  }
  log::Info(log::Verbosity::high, "End printing isl_influence hard constraints");

  log::Info(log::Verbosity::high, "Start printing isl_influence soft constraints");
  for (int idx = 0; idx < n_equal; idx++) {
    const isl_influence_equal &inf_equal = equal->data[idx];
    log::Info(log::Verbosity::high,
              "soft constraint\t\t[" + std::to_string(idx + 1) + ":" + std::to_string(equal->size) + "]");
    log::Info(log::Verbosity::high, "statement1:\t\t" + std::string(inf_equal.statement1));
    log::Info(log::Verbosity::high, "statement2:\t\t" + std::string(inf_equal.statement2));
    log::Info(log::Verbosity::high, "sched_dim1:\t\t" + std::to_string(inf_equal.sched_dim1));
    log::Info(log::Verbosity::high, "sched_dim2:\t\t" + std::to_string(inf_equal.sched_dim2));
    log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf_equal.type]));
    log::Info(log::Verbosity::high, "coef_dim1:\t\t" + std::to_string(inf_equal.coef_dim1));
    log::Info(log::Verbosity::high, "coef_dim2:\t\t" + std::to_string(inf_equal.coef_dim2));
  }
  log::Info(log::Verbosity::high, "End printing isl_influence soft constraints");
}

// Verify that every influence constraint was consumed by the scheduler
// (scanned == 1) and that a solution was recorded for all `maxvar` schedule
// dimensions. Returns 1 on full success, 0 otherwise; problems are logged.
//
// Fixes over the previous version: NULL `coef`/`equal` lists no longer
// crash (print_influence already guarded NULL, this function did not), and
// the four summary messages were missing the space between the count and
// the text ("3influence ..." -> "3 influence ...").
int report_influence(isl_influence_equal_list *equal, isl_influence_list *coef, isl_influence_sol_list *sol,
                     int maxvar) {
  int bad_equal = 0;
  int bad_coef = 0;
  int result = 1;
  const int n_coef = coef ? coef->size : 0;
  const int n_equal = equal ? equal->size : 0;

  for (int i = 0; i < n_coef; i++) {
    isl_influence *inf = &coef->data[i];
    if (inf->scanned != 1) {
      log::Info(log::Verbosity::high,
                "warning: influence hard constraint [" + std::to_string(i) + "] was not processed");
      log::Info(log::Verbosity::high, "statement:\t\t" + std::string(inf->statement_name));
      log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf->type]));
      log::Info(log::Verbosity::high, "sched_dim:\t\t" + std::to_string(inf->sched_dim));
      log::Info(log::Verbosity::high, "coef_dim:\t\t" + std::to_string(inf->coef_dim));
      log::Info(log::Verbosity::high, "val:\t\t\t" + std::to_string(inf->val));
      bad_coef++;
      result = 0;
    }
  }

  for (int i = 0; i < n_equal; i++) {
    isl_influence_equal *inf_equal = &equal->data[i];
    if (inf_equal->scanned != 1) {
      log::Info(log::Verbosity::high,
                "warning: influence soft constraint [" + std::to_string(i) + "] was not processed");
      log::Info(log::Verbosity::high, "statement1:\t\t" + std::string(inf_equal->statement1));
      log::Info(log::Verbosity::high, "statement2:\t\t" + std::string(inf_equal->statement2));
      log::Info(log::Verbosity::high, "sched_dim1:\t\t" + std::to_string(inf_equal->sched_dim1));
      log::Info(log::Verbosity::high, "sched_dim2:\t\t" + std::to_string(inf_equal->sched_dim2));
      log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf_equal->type]));
      log::Info(log::Verbosity::high, "coef_dim1:\t\t" + std::to_string(inf_equal->coef_dim1));
      log::Info(log::Verbosity::high, "coef_dim2:\t\t" + std::to_string(inf_equal->coef_dim2));
      bad_equal++;
      result = 0;
    }
  }

  if (bad_coef == 0)
    log::Info(log::Verbosity::high, std::to_string(n_coef) + " influence hard coef constraints processed correctly");
  else
    log::Info(log::Verbosity::high, std::to_string(bad_coef) + " influence hard coef constraints were not processed");

  if (bad_equal == 0)
    log::Info(log::Verbosity::high, std::to_string(n_equal) + " influence equal constraints processed correctly");
  else
    log::Info(log::Verbosity::high, std::to_string(bad_equal) + " influence equal constraints were not processed");

  if (!sol || sol->size != maxvar) {
    log::Info(log::Verbosity::high, "isl influence could not find solution for all dimensions");
    result = 0;
  }

  return result;
}

// For soft constraint `inf`, decide the direction of coefficient
// propagation when the scheduler is computing dimension `actual_dim`:
// coefficients can only be copied *from* a dimension that has already been
// solved *to* the one currently being solved.
// On success fills *sched_from/*sched_to/*coef_from/*coef_to and returns 1;
// returns 0 (outputs untouched) when the constraint does not apply yet or
// the destination coefficient would be unknown.
int set_params(isl_influence_equal *inf, int *sched_from, int *sched_to, int *coef_from, int *coef_to, int actual_dim) {
  int retval = 0;

  // At dimension 0 nothing is solved yet, so only same-dimension
  // constraints can be honored.
  if (actual_dim == 0 && inf->sched_dim1 != inf->sched_dim2) {
    return retval;
  } else if (actual_dim == inf->sched_dim1) {
    if (inf->sched_dim1 >= inf->sched_dim2) {
      // dim2 was solved earlier (or equals dim1): copy dim2 -> dim1.
      *sched_from = inf->sched_dim2;
      *sched_to = inf->sched_dim1;
      *coef_from = inf->coef_dim2;
      *coef_to = inf->coef_dim1;
      retval = 1;

    } else {
      log::Info(log::Verbosity::high,
                "cannot set future coef for dimension: " + std::to_string(actual_dim) + " and inf_equal:");
      log::Info(log::Verbosity::high, "inf->sched_dim1: " + std::to_string(inf->sched_dim1));
      log::Info(log::Verbosity::high, "inf->sched_dim2: " + std::to_string(inf->sched_dim2));
    }

  } else if (actual_dim == inf->sched_dim2) {
    if (inf->sched_dim2 >= inf->sched_dim1) {
      // Symmetric case: copy dim1 -> dim2.
      *sched_from = inf->sched_dim1;
      *sched_to = inf->sched_dim2;
      *coef_from = inf->coef_dim1;
      *coef_to = inf->coef_dim2;
      retval = 1;
    } else {
      log::Info(log::Verbosity::high,
                "cannot set future coef for dimension: " + std::to_string(actual_dim) + " and inf_equal:");
      log::Info(log::Verbosity::high, "inf->sched_dim1: " + std::to_string(inf->sched_dim1));
      log::Info(log::Verbosity::high, "inf->sched_dim2: " + std::to_string(inf->sched_dim2));
    }
  }

  return retval;
}
// isl hook: apply the soft (equality) constraints to the LP problem `bset`
// for the schedule dimension currently being computed (graph->n_row).
// Same-dimension constraints become LP equalities (create_constraint);
// cross-dimension constraints pin the destination coefficient to the value
// recorded for the already-solved source dimension (hack_coefficients).
isl_basic_set *akg_isl_influence_set_equal(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) {
  // loop over iinfluence list equal
  log::Info(log::Verbosity::high, "Enter isl_influence_set_equal for dimension --> " + std::to_string(graph->n_row));
  isl_influence_equal_list *inf_list = graph->inf_equal_list;
  for (int i = 0; i < inf_list->size; i++) {
    isl_influence_equal *inf_equal = &inf_list->data[i];

    // Forward references cannot be satisfied: the destination dimension has
    // not been solved yet. Logged here, filtered out by set_params below.
    if (inf_equal->sched_dim1 > inf_equal->sched_dim2) {
      std::string message0 = "ERROR: isl cannot compute soft constraint";
      message0 += "[" + std::to_string(i) + " from dimension " + std::to_string(inf_equal->sched_dim2) +
                  " to dimension " + std::to_string(inf_equal->sched_dim1);
      std::string message1 = "Reason: destination coefficient unknown";
      log::Info(log::Verbosity::high, message0);
      log::Info(log::Verbosity::high, message1);
    }
    int sched_from;
    int sched_to;
    int coef_from;
    int coef_to;
    if (!set_params(inf_equal, &sched_from, &sched_to, &coef_from, &coef_to, graph->n_row)) continue;

    isl_basic_set *bset_from;
    isl_basic_set *bset_to;

    int node_index_from = -1;
    int node_index_to = -1;

    bset_from = graph_find_basic_set_by_statement_name(graph, inf_equal->statement1, &node_index_from);
    bset_to = graph_find_basic_set_by_statement_name(graph, inf_equal->statement2, &node_index_to);

    if (bset_from != NULL && bset_to != NULL && node_index_from != -1 && node_index_to != -1) {
      int pos_from;
      int pos_to;
      log::Info(log::Verbosity::high, "scanning equal constraint for influence equal constraint[" +
                                        std::to_string(i + 1) + ":" + std::to_string(inf_list->size) + "]");
      log::Info(log::Verbosity::high, "statement from:\t" + std::string(inf_equal->statement2));
      log::Info(log::Verbosity::high, "statement to:\t" + std::string(inf_equal->statement1));
      log::Info(log::Verbosity::high, "sched_dim from:\t" + std::to_string(sched_from));
      log::Info(log::Verbosity::high, "sched_dim to:\t" + std::to_string(sched_to));
      log::Info(log::Verbosity::high, "coef_dim from\t" + std::to_string(coef_from));
      log::Info(log::Verbosity::high, "coef_dim to\t" + std::to_string(coef_to));
      log::Info(log::Verbosity::high, "type:\t\t" + std::string(inf_types_str[(int)(inf_equal->type)]));
      log::Info(log::Verbosity::high, "isl_influence_set_equal: copying coef from " +
                                        std::string(inf_equal->statement2) + " to " +
                                        std::string(inf_equal->statement1));
      print_basic_set(bset_from, "bset from:");
      print_basic_set(bset_to, "bset to:");
      log::Info(log::Verbosity::high, "node from:\t" + std::to_string(node_index_from));
      log::Info(log::Verbosity::high, "node to:\t" + std::to_string(node_index_to));
      pos_from = get_pos_from_bset(bset_from, isl_sched_graph_get_node(graph, node_index_from), inf_equal, coef_from);
      pos_to = get_pos_from_bset(bset_to, isl_sched_graph_get_node(graph, node_index_to), inf_equal, coef_to);

      if (inf_equal->sched_dim2 == inf_equal->sched_dim1) {
        // Same dimension: add an LP equality between the two coefficients.
        // isl_var coefficients come as (plus, minus) column pairs, so the
        // companion column at pos - 1 is constrained as well.
        bset = create_constraint(bset, "constraint created", pos_from, pos_to);
        if (inf_equal->type == isl_var) bset = create_constraint(bset, "constraint created", pos_from - 1, pos_to - 1);
      } else {
        // Cross dimension: pin the destination coefficient to the value the
        // source dimension's solution recorded earlier.
        int val = isl_influence_sol_get_elem(sched_from, pos_from, graph);
        log::Info(log::Verbosity::high, "val=" + std::to_string(val));
        bset = hack_coefficients(bset, "isl_equal", pos_to, val, val);
        if (inf_equal->type == isl_var) {
          val = isl_influence_sol_get_elem(sched_from, pos_from - 1, graph);
          bset = hack_coefficients(bset, "is_equal", pos_to - 1, val, val);
        }
      }
      inf_equal->scanned = 1;
    }
  }
  log::Info(log::Verbosity::high, "Leave isl_influence_set_equal");
  return bset;
}

// Influence-aware replacement for isl_schedule_constraints_compute_schedule:
// attaches the constraint lists to the scheduling graph before invoking the
// (patched) isl scheduler, then verifies all constraints were honored.
// Returns NULL when any constraint was not processed, so the caller can
// fall back to the stock isl scheduler. Consumes `sc`; the caller keeps
// ownership of `inf_coef` and `inf_equal`.
__isl_give isl_schedule *akg_isl_schedule_constraints_compute_schedule_influence(
  __isl_take isl_schedule_constraints *sc, isl_influence_list *inf_coef, isl_influence_equal_list *inf_equal) {
  isl_ctx *ctx = isl_schedule_constraints_get_ctx(sc);
  struct isl_sched_graph graph = {0};
  isl_schedule *sched;
  isl_schedule_node *node;

  isl_union_set *domain;
  isl_size n;

  log::Info(log::Verbosity::high, "isl_schedule_constraints_compute_schedule : start printing constraints");
  isl_printer *p;
  p = isl_printer_to_str(ctx);
  p = isl_printer_set_yaml_style(p, ISL_YAML_STYLE_BLOCK);
  p = isl_printer_print_schedule_constraints(p, sc);
  char *log_string = isl_printer_get_str(p);

  log::Info(log::Verbosity::high, std::string(log_string));

  isl_printer_free(p);
  free(log_string);

  log::Info(log::Verbosity::high, "isl_schedule_constraints_compute_schedule : end printing constraints");

  print_influence(inf_equal, inf_coef);

  // Attach the influence lists before graph initialization so the hooks
  // installed by akg_isl_influence_enable can find them.
  graph.inf_list = inf_coef;
  graph.inf_equal_list = inf_equal;

  sc = isl_schedule_constraints_align_params(sc);

  domain = isl_schedule_constraints_get_domain(sc);
  n = isl_union_set_n_set(domain);
  if (n == 0) {
    // Empty domain: trivial schedule, nothing to influence.
    isl_schedule_constraints_free(sc);
    return isl_schedule_from_domain(domain);
  }

  // On error, freeing the domain makes the following isl calls propagate
  // NULL, which is isl's standard error convention.
  if (n < 0 || isl_sched_graph_init(&graph, sc) < 0) {
    domain = isl_union_set_free(domain);
  }

  node = isl_schedule_node_from_domain(domain);
  node = isl_schedule_node_child(node, 0);

  if (graph.n > 0) {
    node = isl_schedule_node_compute_schedule(node, &graph);
  }
  sched = isl_schedule_node_get_schedule(node);
  int result = report_influence(inf_equal, inf_coef, graph.inf_sol_list, graph.maxvar);

  isl_schedule_node_free(node);
  isl_sched_graph_free(ctx, &graph);
  isl_schedule_constraints_free(sc);

  if (!result) {
    log::Info(log::Verbosity::high, "isl_influence failed, will fallback to default isl");
    isl_schedule_free(sched);
    sched = NULL;
  }
  return sched;
}

// isl hook: apply the hard coefficient constraints to the LP problem `bset`
// for the schedule dimension currently being computed (graph->n_row).
// For each (statement, dimension) match, the relevant LP column is pinned
// to the requested value via hack_coefficients.
// NOTE(review): the column offset formulas mirror the patched isl LP layout
// (cst | params | (plus, minus) variable pairs); hack_coefficients is called
// as (pos, ub, lb) against a (pos, lb, ub) signature — harmless while all
// call sites use lb == ub, but worth confirming before changing either side.
isl_basic_set *akg_isl_influence_set_coef(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) {
  log::Info(log::Verbosity::high, "Enter isl_influence_set_coef for dimension --> " + std::to_string(graph->n_row));

  isl_influence_list *inf_list = graph->inf_list;
  int dimension = graph->n_row;

  for (int i = 0; i < graph->n; ++i) {
    struct isl_sched_node *node = isl_sched_graph_get_node(graph, i);
    isl_map *ma = isl_sched_node_extract_schedule(node);
    log::Info(log::Verbosity::high, "statement:");
    if (ma != NULL) {
      isl_printer *p;
      p = isl_printer_to_str(ctx);
      p = isl_printer_print_map(p, ma);

      char *log_str = isl_printer_get_str(p);
      log::Info(log::Verbosity::high, std::string(log_str));

      isl_printer_free(p);
      free(log_str);
    }
    log::Info(log::Verbosity::high, "end statement");
    // NOTE(review): strstat is read after isl_map_free below; isl interns
    // identifiers in the ctx so the pointer appears to stay valid, but this
    // ordering should be confirmed against the isl version in use.
    const char *strstat = isl_map_get_tuple_name(ma, isl_dim_in);
    isl_map_free(ma);

    for (int j = 0; j < inf_list->size; j++) {
      isl_influence *inf = &inf_list->data[j];

      if (inf->sched_dim == dimension && strcmp(strstat, inf->statement_name) == 0) {
        int pos;
        int ub = inf->val;
        int lb = inf->val;
        log::Info(log::Verbosity::high, "scanning isl coefficients for influence[" + std::to_string(j + 1) + ":" +
                                          std::to_string(inf_list->size) + "]:");
        // S_0,S_1,...,S_n-1,S_n
        log::Info(log::Verbosity::high, "statement_name:\t\t" + std::string(inf->statement_name));
        // i,j,k.... (only apply for type=isl_var_plus | isl_var_minus
        log::Info(log::Verbosity::high, "sched dim:\t\t" + std::to_string(inf->sched_dim));
        // coefficient index
        log::Info(log::Verbosity::high, "rank variable:\t\t" + std::to_string(inf->coef_dim));
        int nparam = isl_sched_node_get_nparam(node);
        int nvar = isl_sched_node_get_nvar(node);
        log::Info(log::Verbosity::high, "statement variables:\t" + std::to_string(nvar));
        // coefficient value
        log::Info(log::Verbosity::high, "coefficient value:\t" + std::to_string(inf->val));
        log::Info(log::Verbosity::high, "type:\t\t" + std::string(inf_types_str[(int)(inf->type)]));
        print_basic_set(bset, "lp problem to influence:");
        switch (inf->type) {
          case isl_cst:
            // Constant term: pin its column directly.
            pos = isl_sched_node_cst_coef_offset(node);
            bset = hack_coefficients(bset, "isl_cst", pos, ub, lb);
            inf->scanned = 1;
            break;
          case isl_param:
            log::Info(log::Verbosity::high, "Statement param (node->nparam): " + std::to_string(nparam));
            pos = isl_sched_node_cst_coef_offset(node) - (nparam - inf->coef_dim);
            bset = hack_coefficients(bset, "isl_param", pos, ub, lb);
            inf->scanned = 1;
            break;
          case isl_var:
            // dim_+ coefficient ;
            // Iteration variables are encoded as a (plus, minus) column
            // pair; the sign of inf->val decides which side holds |val|.
            if (inf->coef_dim <= nvar) {
              pos = isl_sched_node_cst_coef_offset(node) - nparam - 1 - inf->coef_dim * 2;
              if (inf->val > 0) {
                bset = hack_coefficients(bset, "isl_var_+", pos, ub, lb);
                // dim_- coeffiecient
                pos--;
                bset = hack_coefficients(bset, "isl_var_-", pos, 0, 0);
                inf->scanned = 1;
              } else if (inf->val == 0) {
                // dim_+ coefficient
                bset = hack_coefficients(bset, "isl_var_+", pos, 0, 0);
                // dim_- coeffiecient
                pos--;
                bset = hack_coefficients(bset, "isl_var_-", pos, 0, 0);
                inf->scanned = 1;

              } else if (inf->val < 0) {
                bset = hack_coefficients(bset, "isl_var_+", pos, 0, 0);
                // dim_- coeffiecient
                pos--;
                bset = hack_coefficients(bset, "isl_var_-", pos, -ub, -lb);
                inf->scanned = 1;
              } else {
                log::Info(log::Verbosity::high, "invalid inf->val: " + std::to_string(inf->val));
              }
            } else {
              log::Info(log::Verbosity::high,
                        "Warning: dimension overflow --> dimension required: " + std::to_string(inf->coef_dim) +
                          " max dimensions: " + std::to_string(nvar));
            }
            break;
          default:
            log::Info(log::Verbosity::high, "unknown influence coef type");
            break;
        }
        print_basic_set(bset, "lp influenced problem:");
      }
    }
  }
  log::Info(log::Verbosity::high, "Leave isl_influence_set_coef");
  return bset;
}

// isl hook: number of schedule dimensions the influence constraints refer
// to, i.e. 1 + the largest sched_dim mentioned in either list.
// NULL lists are treated as empty; returns 0 when nothing is constrained.
// Fix: log message said "akg_isl_influence_maxar" (typo) on entry.
int akg_isl_influence_maxvar(struct isl_sched_graph *graph) {
  log::Info(log::Verbosity::high, "Entering akg_isl_influence_maxvar");
  int maxvar = 0;
  // NOTE(review): `previous` is always 0 here; it is only logged below for
  // trace compatibility with the original implementation.
  int previous = maxvar;
  int var;
  isl_influence_list *inf_list = graph->inf_list;
  isl_influence_equal_list *inf_equal_list = graph->inf_equal_list;

  for (int i = 0; NULL != inf_list && i < inf_list->size; i++) {
    isl_influence *inf = &inf_list->data[i];
    var = inf->sched_dim + 1;
    if (maxvar < var) {
      maxvar = var;
    }
  }
  for (int i = 0; NULL != inf_equal_list && i < inf_equal_list->size; i++) {
    isl_influence_equal *inf_equal = &inf_equal_list->data[i];
    var = inf_equal->sched_dim1 + 1;

    if (maxvar < var) {
      maxvar = var;
    }

    var = inf_equal->sched_dim2 + 1;

    if (maxvar < var) {
      maxvar = var;
    }
  }

  log::Info(log::Verbosity::high, "isl_influence_maxvar : " + std::to_string(maxvar) +
                                    " (previous maxvar: " + std::to_string(previous) + ")");
  log::Info(log::Verbosity::high, "Leaving akg_isl_influence_maxvar");
  return maxvar;
}

// isl hook: decide whether the LP solution `sol` is coincident, i.e. all
// statement nodes received identical constant, parameter and variable
// coefficients for the current dimension. Returns 1 when coincident.
// NOTE(review): a single-node graph returns 0 (graph->n > 1 guard) — confirm
// this is the intended convention for the patched isl caller.
int akg_isl_influence_check_coincident(struct isl_sched_graph *graph, isl_vec *sol) {
  int coincident = graph->n > 1 ? 1 : 0;
  int pos = 0;

  log::Info(log::Verbosity::high, "Entering isl_influene_check_coincidence for dimension " +
                                    std::to_string(graph->n_row) +
                                    ", number of nodes (statements): " + std::to_string(graph->n));
  // Per-node scratch table: 3 ints per node (cst_offset, nparam, nvar).
  // NOTE(review): calloc result is not checked; a failed allocation would
  // dereference NULL below.
  int *vec = (int *)calloc(graph->n * 3, sizeof(int));

  for (int i = 0; i < graph->n; ++i) {
    isl_sched_node *node = isl_sched_graph_get_node(graph, i);
    int cst_offset = isl_sched_node_cst_coef_offset(node);
    int nparam = isl_sched_node_get_nparam(node);
    int nvar = isl_sched_node_get_nvar(node);

    log::Info(log::Verbosity::high, "graph node:\t" + std::to_string(i));
    log::Info(log::Verbosity::high, "cst_offset:\t" + std::to_string(cst_offset));
    log::Info(log::Verbosity::high, "nparam:\t\t" + std::to_string(nparam));
    log::Info(log::Verbosity::high, "var:\t\t" + std::to_string(nvar));

    vec[pos++] = cst_offset;
    vec[pos++] = nparam;
    vec[pos++] = nvar;
  }

  int is_equal;

  // Pairwise comparison of all distinct node pairs (i, j). The `break`
  // statements only leave the innermost loop; once `coincident` is 0 it is
  // never reset, so the remaining iterations cannot change the result.
  // Solution entries are read at offset 1 + column throughout; the first
  // vector entry appears to be skipped (presumably the denominator).
  for (int i = 0; i < graph->n; ++i)
    for (int j = i; j < graph->n; ++j) {
      if (j == i) continue;

      log::Info(log::Verbosity::high,
                "calculating coincidence for statement " + std::to_string(i) + " and " + std::to_string(j));

      int cst_offset_S0 = vec[i * 3];
      int cst_offset_S1 = vec[j * 3];

      is_equal = isl_influence_int_eq(sol, 1 + cst_offset_S0, 1 + cst_offset_S1);

      log::Info(log::Verbosity::high, "S_" + std::to_string(i) + "_i" + std::to_string(cst_offset_S0) + ", S_" +
                                        std::to_string(j) + "_i" + std::to_string(cst_offset_S1) +
                                        ", cst coefficient equal: " + std::to_string(is_equal));

      if (is_equal == 0) {
        coincident = 0;
        break;
      }

      log::Info(log::Verbosity::high, "param coefficient(s) equal:");

      int nparam_S0 = vec[i * 3 + 1];
      int nparam_S1 = vec[j * 3 + 1];

      // Parameter columns are only comparable when both statements have the
      // same (non-zero) number of parameters.
      if (nparam_S0 != 0 && nparam_S1 != 0 && nparam_S0 == nparam_S1) {
        int nparam_pos_0 = cst_offset_S0 - nparam_S0;
        int nparam_pos_1 = cst_offset_S1 - nparam_S1;

        for (int k = nparam_pos_0, l = nparam_pos_1; k < nparam_pos_0 + nparam_S0; k++, l++) {
          is_equal = isl_influence_int_eq(sol, 1 + k, 1 + l);
          log::Info(log::Verbosity::high, "S_" + std::to_string(i) + "_i" + std::to_string(k) + ", S_" +
                                            std::to_string(j) + "_i" + std::to_string(l) +
                                            "equal: " + std::to_string(is_equal));
          if (is_equal == 0) {
            coincident = 0;
            break;
          }
        }
      } else {
        log::Info(log::Verbosity::high, " no parameteres found or number of parameters distinct for each statement.");
      }

      log::Info(log::Verbosity::high, "variable coefficients equal:");

      int nvar_S0 = vec[i * 3 + 2];
      int nvar_S1 = vec[j * 3 + 2];

      if (nvar_S0 != nvar_S1) {
        coincident = 0;
        break;
      }

      // Variables occupy (plus, minus) column pairs, hence 2 * nvar columns.
      int nvar_pos_0 = cst_offset_S0 - nparam_S0 - 2 * nvar_S0;
      int nvar_pos_1 = cst_offset_S1 - nparam_S1 - 2 * nvar_S1;

      for (int k = nvar_pos_0, l = nvar_pos_1; k < nvar_pos_0 + 2 * nvar_S0; k++, l++) {
        is_equal = isl_influence_int_eq(sol, 1 + k, 1 + l);
        log::Info(log::Verbosity::high, "S_" + std::to_string(i) + "_i" + std::to_string(k) + ", S_" +
                                          std::to_string(j) + "_i" + std::to_string(l) +
                                          " equal: " + std::to_string(is_equal));
        if (is_equal == 0) {
          coincident = 0;
          break;
        }
      }
    }

  free(vec);

  log::Info(log::Verbosity::high, "Leaving isl_check_coindicent result: " + std::to_string(coincident));

  return coincident;
}

} // namespace poly
} // namespace ir
} // namespace akg

+ 110
- 0
src/poly/isl_influence.h View File

@@ -0,0 +1,110 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_ISL_INFLUENCE_H_
#define POLY_ISL_INFLUENCE_H_

// STL
#include <tuple>
#include <vector>

// ISL
#include <isl/cpp.h>

namespace akg {
namespace ir {
namespace poly {

///////////////////////////////////////////////////////////////////////////
// Typedefs for soft constraints
///////////////////////////////////////////////////////////////////////////

// Size increment used when growing the dynamically allocated *_list arrays.
#define INF_BLOCK_SIZE 8

// Kind of schedule coefficient an influence targets: the constant term,
// a parameter coefficient, or an iteration-variable coefficient.
enum isl_influence_coeff_type { isl_cst, isl_param, isl_var };

// Soft constraint requesting that two schedule coefficients be equal,
// each identified by a statement name, a schedule dimension and a
// coefficient dimension.
struct isl_influence_equal {
char *statement1;
char *statement2;
int sched_dim;
// NOTE(review): the relation between sched_dim and sched_dim1/sched_dim2
// is not visible from this header — confirm against isl_influence.cc.
int sched_dim1;
int sched_dim2;
isl_influence_coeff_type type;
int coef_dim1;
int coef_dim2;
// Presumably nonzero once the constraint has been consumed by the
// scheduler callbacks — verify in isl_influence.cc.
int scanned;
};
typedef struct isl_influence_equal isl_influence_equal;

// Growable array of isl_influence_equal ('size' used, 'mem' allocated).
struct isl_influence_equal_list {
int size;
int mem;
struct isl_influence_equal *data;
};
typedef struct isl_influence_equal_list isl_influence_equal_list;

// Soft constraint forcing the value 'val' on one schedule coefficient
// of the named statement.
struct isl_influence {
char *statement_name;
isl_influence_coeff_type type;
int sched_dim;
int coef_dim;
int val;
int scanned;
};
typedef struct isl_influence isl_influence;

// Growable array of isl_influence ('size' used, 'mem' allocated).
struct isl_influence_list {
int size;
int mem;
struct isl_influence *data;
};
typedef struct isl_influence_list isl_influence_list;

// Growable array of ints; presumably holds one scheduler solution
// (isl_vec contents) — see akg_isl_influence_sol_add_elem().
struct isl_influence_sol {
int size;
int mem;
int *data;
};
typedef struct isl_influence_sol isl_influence_sol;

// Growable array of isl_influence_sol, one entry per computed solution.
struct isl_influence_sol_list {
int size;
int mem;
isl_influence_sol *data;
};
typedef struct isl_influence_sol_list isl_influence_sol_list;

// Enable/disable the influence mechanism on the given isl context.
void akg_isl_influence_enable(isl_ctx *ctx);
void akg_isl_influence_disable(isl_ctx *ctx);

int akg_isl_influence_maxvar(struct isl_sched_graph *graph);

// Inject coefficient-value and coefficient-equality constraints into the
// scheduler's LP basic set.
isl_basic_set *akg_isl_influence_set_coef(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset);
isl_basic_set *akg_isl_influence_set_equal(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset);

// Decide whether the dimension described by solution 'sol' can be marked
// coincident (see isl_influence.cc for the exact criterion).
int akg_isl_influence_check_coincident(struct isl_sched_graph *graph, isl_vec *sol);

// Like isl_schedule_constraints_compute_schedule(), but additionally honors
// the given influence (soft constraint) lists. Takes ownership of 'sc'.
__isl_give isl_schedule *akg_isl_schedule_constraints_compute_schedule_influence(
__isl_take isl_schedule_constraints *sc, isl_influence_list *inf_coef, isl_influence_equal_list *inf_equal);

// Memory management and access helpers for the lists above.
void *isl_influence_list_free(isl_influence_list *inf_list);
void *isl_influence_equal_list_free(isl_influence_equal_list *inf_equal_list);
struct isl_sched_graph *akg_isl_influence_sol_list_free(struct isl_sched_graph *graph);
struct isl_sched_graph *akg_isl_influence_sol_add_elem(isl_vec *sol, struct isl_sched_graph *graph);
int akg_isl_influence_sol_get_elem(int sched, int pos, struct isl_sched_graph *graph);

} // namespace poly
} // namespace ir
} // namespace akg

#endif // POLY_ISL_INFLUENCE_H_

+ 1019
- 0
src/poly/isl_util.cc
File diff suppressed because it is too large
View File


+ 184
- 0
src/poly/isl_util.h View File

@@ -0,0 +1,184 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_ISL_UTIL_H_
#define POLY_ISL_UTIL_H_

#include <vector>

#include "isl/cpp.h"

#include "poly/pass_info.h"
#include "poly/scop_info.h"

// Hardcore isl functions

namespace akg {
namespace ir {
namespace poly {

////////////////////////////////////////////////////////////////////////////////
// Misc.
////////////////////////////////////////////////////////////////////////////////

// Plain (no computation) extraction of the constant bound of dimension 'dim'.
long isl_set_plain_get_num_si(const isl::set &s, int dim);
long isl_set_plain_get_num_si(__isl_keep isl_set *const set, int dim);

// Upper bounds of the requested dimensions of 's', as plain integers.
std::vector<int> extract_upper_bounds(const isl::set &s, const std::vector<int> &dimensions);
std::vector<int> extract_upper_bounds(__isl_keep isl_set *const set, const std::vector<int> &dimensions);

// Combining isl_spaces
isl::space isl_space_set_cat(const isl::space &left, const isl::space &right);
__isl_give isl_space *isl_space_set_cat(__isl_take isl_space *left, __isl_take isl_space *right);

// Utilities for isl_multi_union_pw_aff*
// Concatenate 'right' after 'left'.
isl::multi_union_pw_aff isl_multi_union_pw_aff_cat(const isl::multi_union_pw_aff &left,
const isl::multi_union_pw_aff &right);
__isl_give isl_multi_union_pw_aff *isl_multi_union_pw_aff_cat(__isl_take isl_multi_union_pw_aff *left,
__isl_take isl_multi_union_pw_aff *right);

// Insert 'el' at position 'pos' of 'aff'.
isl::multi_union_pw_aff isl_multi_union_pw_aff_insert(const isl::multi_union_pw_aff &aff, unsigned pos,
const isl::union_pw_aff &el);
__isl_give isl_multi_union_pw_aff *isl_multi_union_pw_aff_insert(__isl_take isl_multi_union_pw_aff *aff, unsigned pos,
__isl_take isl_union_pw_aff *el);

////////////////////////////////////////////////////////////////////////////////
// isl_schedule_node utilities
////////////////////////////////////////////////////////////////////////////////

// Node-type and shape predicates (C++ wrappers first, raw-pointer variants below).
bool isl_schedule_node_is_band(const isl::schedule_node &node);
bool isl_schedule_node_is_sequence(const isl::schedule_node &node);
bool isl_schedule_node_has_single_child(const isl::schedule_node &node);
bool isl_schedule_node_band_can_unsplit(const isl::schedule_node &node);

isl_bool isl_schedule_node_is_band(__isl_keep isl_schedule_node *const node);
isl_bool isl_schedule_node_is_sequence(__isl_keep isl_schedule_node *const node);
isl_bool isl_schedule_node_has_single_child(__isl_keep isl_schedule_node *const node);
isl_bool isl_schedule_node_band_can_unsplit(__isl_keep isl_schedule_node *const band);

////////////////////////////////////////////////////////////////////////////////
// isl_schedule_node_band utilities
////////////////////////////////////////////////////////////////////////////////

// Get/set the per-dimension coincidence flags of a band
// (the returned array must be freed by the caller).
__isl_give isl_bool *isl_schedule_node_band_get_coincidence(__isl_keep isl_schedule_node *const band);
__isl_give isl_schedule_node *isl_schedule_node_band_set_coincidence(__isl_take isl_schedule_node *band,
__isl_take isl_bool *const coincidence);
// Copy permutable/coincidence properties from 'original' onto 'band'.
__isl_give isl_schedule_node *isl_schedule_node_band_copy_properties(__isl_take isl_schedule_node *band,
__isl_keep isl_schedule_node *const original);
// Swap the partial schedule of a band, optionally keeping its properties.
__isl_give isl_schedule_node *isl_schedule_node_band_replace_partial_schedule(
__isl_take isl_schedule_node *band, __isl_take isl_multi_union_pw_aff *schedule, isl_bool keep_properties);
__isl_give isl_set *isl_schedule_node_band_lexmax(__isl_keep isl_schedule_node *const band);

////////////////////////////////////////////////////////////////////////////////
// isl_schedule_node_band transformations
////////////////////////////////////////////////////////////////////////////////

// These functions only preserve permutable/coincidence.
// AST options are not preserved.

// interchange dimensions 'first' and 'second'
isl::schedule_node isl_schedule_node_band_interchange(const isl::schedule_node &band, int first, int second);
__isl_give isl_schedule_node *isl_schedule_node_band_interchange(__isl_take isl_schedule_node *band, int first,
int second);

// strip mine only dimension 'dim'
isl::schedule_node isl_schedule_node_band_stripmine(const isl::schedule_node &band, int dim, int value);
isl::schedule_node isl_schedule_node_band_stripmine(const isl::schedule_node &band, int dim, const isl::val &value);

__isl_give isl_schedule_node *isl_schedule_node_band_stripmine(__isl_take isl_schedule_node *band, int dim, int value);
__isl_give isl_schedule_node *isl_schedule_node_band_stripmine(__isl_take isl_schedule_node *band, int dim,
__isl_take isl_val *value);

// collapse dimensions 'dim' and 'dim + 1'
isl::schedule_node isl_schedule_node_band_collapse(const isl::schedule_node &band, int dim);
__isl_give isl_schedule_node *isl_schedule_node_band_collapse(__isl_take isl_schedule_node *band, int dim);

// modulo/scale/scale down a given dimension of a target statement
isl::schedule_node isl_schedule_node_band_fine_mod(const isl::schedule_node &band, const std::string &name,
int dimension, int value);
isl::schedule_node isl_schedule_node_band_fine_scale(const isl::schedule_node &band, const std::string &name,
int dimension, int value);
isl::schedule_node isl_schedule_node_band_fine_scale_down(const isl::schedule_node &band, const std::string &name,
int dimension, int value);
isl::schedule_node isl_schedule_node_band_fine_mod(const isl::schedule_node &band, const std::string &name,
int dimension, const isl::val &value);
isl::schedule_node isl_schedule_node_band_fine_scale(const isl::schedule_node &band, const std::string &name,
int dimension, const isl::val &value);
isl::schedule_node isl_schedule_node_band_fine_scale_down(const isl::schedule_node &band, const std::string &name,
int dimension, const isl::val &value);

__isl_give isl_schedule_node *isl_schedule_node_band_fine_mod(__isl_take isl_schedule_node *band, const char *name,
int dimension, int value);
__isl_give isl_schedule_node *isl_schedule_node_band_fine_scale(__isl_take isl_schedule_node *band, const char *name,
int dimension, int value);
__isl_give isl_schedule_node *isl_schedule_node_band_fine_scale_down(__isl_take isl_schedule_node *band,
const char *name, int dimension, int value);
__isl_give isl_schedule_node *isl_schedule_node_band_fine_mod(__isl_take isl_schedule_node *band, const char *name,
int dimension, __isl_take isl_val *value);
__isl_give isl_schedule_node *isl_schedule_node_band_fine_scale(__isl_take isl_schedule_node *band, const char *name,
int dimension, __isl_take isl_val *value);
__isl_give isl_schedule_node *isl_schedule_node_band_fine_scale_down(__isl_take isl_schedule_node *band,
const char *name, int dimension,
__isl_take isl_val *value);

////////////////////////////////////////////////////////////////////////////////
// schedule tree transformations (on a schedule_node)
////////////////////////////////////////////////////////////////////////////////

// Merge two nested isl_schedule_node_band.
// Assuming the input schedule_node_band has only one child and this child is an isl_schedule_node_band.
isl::schedule_node isl_schedule_node_band_unsplit(const isl::schedule_node &band);
__isl_give isl_schedule_node *isl_schedule_node_band_unsplit(__isl_take isl_schedule_node *band);

// Call isl_schedule_node_band_unsplit() until it is not possible
isl::schedule_node isl_schedule_node_band_fully_unsplit(const isl::schedule_node &band);
__isl_give isl_schedule_node *isl_schedule_node_band_fully_unsplit(__isl_take isl_schedule_node *band);

// Call isl_schedule_node_band_split() until it is not possible
isl::schedule_node isl_schedule_node_band_fully_split(const isl::schedule_node &band);
__isl_give isl_schedule_node *isl_schedule_node_band_fully_split(__isl_take isl_schedule_node *band);

////////////////////////////////////////////////////////////////////////////////
// isl_schedule transformations
////////////////////////////////////////////////////////////////////////////////

isl::schedule isl_schedule_collapse(const isl::schedule &schedule, int first, int last);
__isl_give isl_schedule *isl_schedule_collapse(__isl_take isl_schedule *schedule, int first, int last);

////////////////////////////////////////////////////////////////////////////////
// "Readable" strings conversions
////////////////////////////////////////////////////////////////////////////////

// Simple C code
std::string to_c_code_string(const isl::schedule &sch);
std::string to_c_code_string(__isl_keep isl_schedule *const schedule);
std::string to_c_code_string(const isl::schedule_constraints &c);
std::string to_c_code_string(__isl_keep isl_schedule_constraints *const constraints);

// isl_*_to_str functions return isl formatted strings!
std::string to_block_string(const isl::schedule &s);
std::string to_block_string(__isl_keep isl_schedule *const schedule);

std::string to_block_string(const isl::union_map &map);
std::string to_block_string(__isl_keep isl_union_map *const map);

std::string to_block_string(const isl::schedule_constraints &constraints);
std::string to_block_string(__isl_keep isl_schedule_constraints *const constraints);

} // namespace poly
} // namespace ir
} // namespace akg

#endif // POLY_ISL_UTIL_H_

+ 90
- 0
src/poly/log_util.cc View File

@@ -0,0 +1,90 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/log_util.h"

#include <dmlc/logging.h>

namespace akg {
namespace ir {
namespace poly {
namespace log {

////////////////////////////////////////////////////////////////////////////////
// Verbosity levels
////////////////////////////////////////////////////////////////////////////////

// Module-local current verbosity; gated logging compares against this value.
static Verbosity akg_poly_verbosity_level = Verbosity::silent;

Verbosity GetVerbosityLevel(void) {
  return akg_poly_verbosity_level;
}

void SetVerbosityLevel(Verbosity level) {
  akg_poly_verbosity_level = level;
}

////////////////////////////////////////////////////////////////////////////////
// Logging functions
////////////////////////////////////////////////////////////////////////////////

// Unconditional logging primitives. Errors are deliberately reported through
// LOG(WARNING): scheduling issues in this module must never be fatal.
void Warn(const std::string &message) { LOG(WARNING) << text_yellow << message << text_reset; }

void Error(const std::string &message) { LOG(WARNING) << text_red << message << text_reset; }

void Info(const std::string &message) { LOG(INFO) << message << text_reset; }

// std::stringstream convenience overloads: extract the text and delegate.
void Warn(const std::stringstream &stream) { Warn(stream.str()); }

void Error(const std::stringstream &stream) { Error(stream.str()); }

void Info(const std::stringstream &stream) { Info(stream.str()); }

// clang-format off
// Generate the verbosity-gated overloads for one logging function: the
// message is only emitted when the current verbosity is at least the
// requested level. Plain-int overloads forward to the Verbosity-typed ones.
#define _define_logging_wrappers(func)                                  \
  void func(const Verbosity level, const std::string &message) {        \
    if (akg_poly_verbosity_level >= level) {                            \
      func(message);                                                    \
    }                                                                   \
  }                                                                     \
  void func(const Verbosity level, const std::stringstream &stream) {   \
    if (akg_poly_verbosity_level >= level) {                            \
      func(stream);                                                     \
    }                                                                   \
  }                                                                     \
  void func(const int level, const std::string &message) {              \
    func(static_cast<Verbosity>(level), message);                       \
  }                                                                     \
  void func(const int level, const std::stringstream &stream) {         \
    func(static_cast<Verbosity>(level), stream);                        \
  }

_define_logging_wrappers(Info)
_define_logging_wrappers(Warn)
_define_logging_wrappers(Error)

// Fix: the original #undef targeted a stale name
// (_declare_logging_functions_int_level), so _define_logging_wrappers
// leaked past this point. Undefine the macro that was actually defined.
#undef _define_logging_wrappers
// clang-format on

} // namespace log
} // namespace poly
} // namespace ir
} // namespace akg

+ 135
- 0
src/poly/log_util.h View File

@@ -0,0 +1,135 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_LOG_UTIL_H_
#define POLY_LOG_UTIL_H_

#include <string>
#include <sstream>

namespace akg {
namespace ir {
namespace poly {
namespace log {

////////////////////////////////////////////////////////////////////////////////
// Verbosity levels
////////////////////////////////////////////////////////////////////////////////

// Verbosity levels for poly logging, ordered so they can be compared
// with relational operators (silent < veryLow < ... < veryHigh).
enum class Verbosity {
silent = 0,
veryLow,
low,
medium,
high,
veryHigh,
};

// Read/write the module-wide verbosity level (state lives in log_util.cc).
Verbosity GetVerbosityLevel(void);
void SetVerbosityLevel(Verbosity level);

////////////////////////////////////////////////////////////////////////////////
// Log colors
////////////////////////////////////////////////////////////////////////////////

#ifdef AKG_POLY_LOG_WITH_COLORS
#define text_reset "\033[0m"
#define text_bold "\033[1m"
#define text_dim "\033[2m"
#define text_italic "\033[3m"
#define text_underline "\033[4m"
#define text_blink "\033[5m"
#define text_rapid_blink "\033[6m"
#define text_reverse "\033[7m"
#define text_conceal "\033[8m"

#define text_black "\033[30m"
#define text_red "\033[31m"
#define text_green "\033[32m"
#define text_yellow "\033[33m"
#define text_blue "\033[34m"
#define text_magenta "\033[35m"
#define text_cyan "\033[36m"
#define text_white "\033[37m"

#define text_bright_black "\033[90m"
#define text_bright_red "\033[91m"
#define text_bright_green "\033[92m"
#define text_bright_yellow "\033[93m"
#define text_bright_blue "\033[94m"
#define text_bright_magenta "\033[95m"
#define text_bright_cyan "\033[96m"
#define text_bright_white "\033[97m"
#else
#define text_reset ""
#define text_bold ""
#define text_dim ""
#define text_italic ""
#define text_underline ""
#define text_blink ""
#define text_rapid_blink ""
#define text_reverse ""
#define text_conceal ""

#define text_black ""
#define text_red ""
#define text_green ""
#define text_yellow ""
#define text_blue ""
#define text_magenta ""
#define text_cyan ""
#define text_white ""

#define text_bright_black ""
#define text_bright_red ""
#define text_bright_green ""
#define text_bright_yellow ""
#define text_bright_blue ""
#define text_bright_magenta ""
#define text_bright_cyan ""
#define text_bright_white ""
#endif

////////////////////////////////////////////////////////////////////////////////
// Local logging functions
////////////////////////////////////////////////////////////////////////////////

// Informational messages. Overloads taking a level (Verbosity or plain int)
// only log when the current verbosity is at least that level.
void Info(const std::string &message);
void Info(const std::stringstream &stream);
void Info(const int level, const std::string &message);
void Info(const int level, const std::stringstream &stream);
void Info(const Verbosity level, const std::string &message);
void Info(const Verbosity level, const std::stringstream &stream);

// Warnings (yellow when colors are enabled).
void Warn(const std::string &message);
void Warn(const std::stringstream &stream);
void Warn(const int level, const std::string &message);
void Warn(const int level, const std::stringstream &stream);
void Warn(const Verbosity level, const std::string &message);
void Warn(const Verbosity level, const std::stringstream &stream);

// Errors (red); reported as warnings internally so they are never fatal.
void Error(const std::string &message);
void Error(const std::stringstream &stream);
void Error(const int level, const std::string &message);
void Error(const int level, const std::stringstream &stream);
void Error(const Verbosity level, const std::string &message);
void Error(const Verbosity level, const std::stringstream &stream);

} // namespace log
} // namespace poly
} // namespace ir
} // namespace akg

#endif // POLY_LOG_UTIL_H_

+ 6
- 0
src/poly/pass_mgr_strategy.h View File

@@ -22,6 +22,7 @@

#include "poly/schedule_pass/init_schedule.h"
#include "poly/schedule_pass/compute_schedule.h"
#include "poly/schedule_pass/constrain_schedule.h"

namespace akg {
namespace ir {
@@ -37,6 +38,11 @@ class PassMgrStrategy {
}
void RegisterNormalizationPasses() { RegisterPass(std::make_shared<InitSchedule>(pass_info_, scop_info_)); }
void RegisterSchedulingPasses() { RegisterPass(std::make_shared<ComputeSchedule>(pass_info_, scop_info_)); }
// Register the ConstrainSchedule (mind trick) pass, but only when mind
// tricks are enabled in the user configuration.
void RegisterConstrainedScheduling() {
if (scop_info_.user_config_.GetEnableMindTrick()) {
RegisterPass(std::make_shared<ConstrainSchedule>(pass_info_, scop_info_));
}
}
virtual void RegisterTilingPasses() = 0; // each backend has different achievement
virtual void RegisterMemPromPasses() = 0; // each backend has different achievement
virtual void RegisterPasses() = 0;


+ 4
- 0
src/poly/schedule_pass.h View File

@@ -16,6 +16,8 @@
#ifndef POLY_PASS_H_
#define POLY_PASS_H_

#include <set>
#include <ostream>
#include "poly/isl.h"
#include "poly/scop_info.h"
#include "poly/pass_info.h"
@@ -34,6 +36,8 @@ class SchedulePass {
std::string GetPassName() { return pass_name_; }
std::string pass_name_;
bool restart_{false}; // triggers restart during runtime

std::set<std::string> disabled_passes_;
};

bool LoadScheduleTreeFromFile(const std::string &filename, isl::schedule &schedule);


+ 454
- 0
src/poly/schedule_pass/constrain_schedule.cc View File

@@ -0,0 +1,454 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/schedule_pass/constrain_schedule.h"

#include <unistd.h>
#include <cstdio>
#include <cstdlib>

#include <isl/ctx.h>
#include <isl/schedule.h>

// opendir(), readdir, closedir()
#include <sys/types.h>
#include <dirent.h>

// TVM
#include <tvm/node/node.h>
#include <tvm/node/container.h>

// Local headers
#include "poly/schedule_pass/scheduling_mind_trick.h"
#include "poly/isl_util.h"
#include "poly/log_util.h"

namespace akg {
namespace ir {
namespace poly {

////////////////////////////////////////////////////////////////////////////////
// Constructors
////////////////////////////////////////////////////////////////////////////////

// Construct the pass: determine the verbosity level, then load all available
// mind tricks (from the CLI attribute and from on-disk directories).
ConstrainSchedule::ConstrainSchedule(PassInfo &pass_info, ScopInfo &scop_info)
: pass_info_(pass_info), scop_info_(scop_info) {
pass_name_ = __FUNCTION__;

InitVerbosityLevel();
// Temporarily switch the global log verbosity to this pass's own level
// while loading mind tricks, then restore the caller's setting.
const log::Verbosity saved_verbosity = log::GetVerbosityLevel();
log::SetVerbosityLevel(static_cast<log::Verbosity>(verbosity_));
LoadMindTricks();
log::SetVerbosityLevel(saved_verbosity);
}

////////////////////////////////////////////////////////////////////////////////
// Setters
////////////////////////////////////////////////////////////////////////////////

// Queue one more mind trick for evaluation in Run().
void ConstrainSchedule::AddMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
  mind_tricks_.emplace_back(mind_trick);
}

// Build the list of directories to search for mind trick files.
// Existence is not checked here; callers handle unreadable directories.
std::vector<std::string> ConstrainSchedule::MindTricksDirectories(void) {
  std::vector<std::string> result;

  // Currently the only candidate comes from the environment variable.
  const char *const from_env = std::getenv(env_string_mind_tricks_dir_);
  if (from_env != nullptr) {
    result.emplace_back(from_env);
  }
  // Add other directories here if necessary...

  return result;
}

// Load a single mind trick from a JSON file; discard it (with a warning)
// when the parsed trick is not valid.
void ConstrainSchedule::LoadMindTrickFromFile(const std::string &filename) {
  auto trick = std::make_shared<SchedulingMindTrick>(pass_info_, scop_info_, verbosity_);
  trick->Load(filename);

  if (!*trick) {
    Warn("something was wrong with mind_trick " + filename);
    return;
  }

  AddMindTrick(trick);
}

// Gather mind tricks from two sources: the user's CLI-provided trick and
// all *.json files found in the configured directories.
void ConstrainSchedule::LoadMindTricks(void) {
  Info(log::Verbosity::low,
       text_reverse text_bright_blue " ConstrainSchedule " text_reset text_bright_blue " LoadMindTricks()" text_reset);

  // Try to load the user's mind_trick from CLI
  const std::string user_mind_trick = scop_info_.user_config_.GetMindTrick();
  if (user_mind_trick != "") {
    auto mind_trick = std::make_shared<SchedulingMindTrick>(pass_info_, scop_info_, verbosity_);
    mind_trick->Parse(user_mind_trick);
    if (*mind_trick) {
      AddMindTrick(mind_trick);
      Info(log::Verbosity::medium, text_bright_magenta "User's mind_trick:\n" + mind_trick->str());
    } else {
      Warn(log::Verbosity::veryLow, "something was wrong with user's mind_trick");
    }
  }

  // Look for mind_tricks in several directories.
  std::vector<std::string> directories = MindTricksDirectories();
  for (const std::string &directory_str : directories) {
    log::Info(log::Verbosity::medium, "looking for mind tricks in " + directory_str);
    DIR *const directory = opendir(directory_str.c_str());
    if (directory) {
      // Collect the names first: readdir() gives no guarantee on the order
      // of the files, and we want a deterministic load order.
      std::vector<std::string> files;
      for (struct dirent *entry = readdir(directory); entry; entry = readdir(directory)) {
        const std::string &filename = std::string(entry->d_name);
        if (filename.length() > 5 && filename.compare(filename.length() - 5, 5, ".json") == 0) {
          files.push_back(filename);
        }
      }
      if (!files.empty()) {
        std::sort(files.begin(), files.end());
        for (const std::string &filename : files) {
          const std::string &path = directory_str + "/" + filename;
          LoadMindTrickFromFile(path);
        }
      }
      closedir(directory);
    } else {
      log::Error(log::Verbosity::medium, "could not access directory " + directory_str);
    }
  }

  // Fix: the original message was missing the space before "up" (it printed
  // "1 trickup its sleeve") and used the singular for a count of 0.
  std::stringstream summary;
  summary << text_cyan << pass_name_ << " has " << mind_tricks_.size();
  summary << (mind_tricks_.size() == 1 ? " trick" : " tricks");
  summary << " up its sleeve";
  Info(log::Verbosity::low, summary);
}

////////////////////////////////////////////////////////////////////////////////
// Other
////////////////////////////////////////////////////////////////////////////////

// Log a stage banner ("input"/"output"), the kernel name, and the schedule
// both as an isl block string and — at higher verbosity — as C-like code.
static inline void RunInfo(const std::string &stage, const std::string &kernel_name, const isl::schedule &schedule) {
log::Info(log::Verbosity::low,
text_reverse text_bright_blue " ConstrainSchedule " text_reset text_bright_blue " Run() " + stage);
log::Info(log::Verbosity::low, text_bright_blue "name: " + kernel_name);
log::Info(log::Verbosity::low, text_bright_blue "schedule:\n" + to_block_string(schedule));
// The loop-nest view is more verbose: only shown at medium verbosity and up.
log::Info(log::Verbosity::medium, text_bright_blue "schedule (loop nest):\n" + to_c_code_string(schedule));
}

// Decide whether this kernel may be processed by ConstrainSchedule.
// Returns false only when a blacklist file is configured, readable, and
// contains the current kernel name; every failure mode defaults to eligible.
// (The 'sch' parameter is currently unused.)
bool ConstrainSchedule::KernelIsEligible(const isl::schedule &sch) const {
const std::string &kernel_name = scop_info_.user_config_.GetKernelName();

// For now, the only eligibility criterion is whether a blacklist file
// exists and contains the current kernel name.
const char *const blacklist_path = std::getenv(env_string_mind_tricks_operator_blacklist_);
if (!blacklist_path) {
Info(log::Verbosity::high, env_string_mind_tricks_operator_blacklist_ + std::string(" not set"));
return true;
}

std::fstream file(blacklist_path, std::ios::in);
if (!file.is_open()) {
Warn(log::Verbosity::high, "could not open operator blacklist: " + std::string(blacklist_path));
return true;
}

// Very basic analysis of the lines (no support for extra spaces, etc.)
std::string line;
while (getline(file, line)) {
// Support for "comments" in the blacklist file
// (line[0] on an empty line is well-defined since C++11: returns '\0').
if (line[0] == '#') continue;

if (kernel_name == line) return false;
}

return true;
}

// Whether the pass should run at all. PassMgrStrategy already checks the
// user config before registering the pass, but the value may have changed
// at runtime, so it is checked again; the environment can also veto it.
bool ConstrainSchedule::IsEnabled(void) {
  if (!scop_info_.user_config_.GetEnableMindTrick()) {
    Info("ConstrainSchedule was disabled via akg::ir::poly::UserConfig");
    return false;
  }

  const char *const env_value = std::getenv(env_string_mind_tricks_enable_);
  const bool vetoed = env_value != nullptr && std::string(env_value) == "false";
  if (vetoed) {
    Info("ConstrainSchedule was disabled via environment variable " + std::string(env_string_mind_tricks_enable_));
    return false;
  }

  return true;
}

// Main entry point: try each loaded mind trick against the input schedule
// and return the first applicable (and, unless waived, validated) result.
// Falls back to the unchanged input schedule when no trick applies.
isl::schedule ConstrainSchedule::Run(isl::schedule sch) {
  if (!IsEnabled()) return sch;

  const std::string &target = scop_info_.user_config_.GetTarget();
  const std::string &kernel_name = scop_info_.user_config_.GetKernelName();

  // Temporarily switch the global verbosity to this pass's own level;
  // it must be restored on every exit path.
  const log::Verbosity saved_verbosity = log::GetVerbosityLevel();
  log::SetVerbosityLevel(static_cast<log::Verbosity>(verbosity_));

  isl::schedule result = sch;

  // Make sure the constraints are available...
  // We expect this pass to be right after InitSchedule: most information is
  // initialized or computed in InitSchedule. However the constraints are
  // usually computed in ComputeSchedule (which we may disable afterwards!).
  pass_info_.constraints_ = MakeScheduleConstraints(sch, pass_info_);

  // Check whether we want to use ConstrainSchedule on this kernel
  const bool eligible = KernelIsEligible(sch);
  if (!eligible) {
    Warn("ConstrainSchedule will ignore this operator...");
    // Fix: the original early return leaked the temporary verbosity level.
    log::SetVerbosityLevel(saved_verbosity);
    return sch;
  }

  const std::size_t total = mind_tricks_.size();
  RunInfo("input", kernel_name, sch);
  std::stringstream summary;
  summary << pass_name_ << " has " << total << " tricks up its sleeve";
  Info(log::Verbosity::low, summary);

  CreateMindTrickTemplate(sch);
  if (target == TARGET_CUDA) {
    GpuCompilerFlagsTempfileRemove();
  }

  size_t current = 0;
  for (std::shared_ptr<SchedulingMindTrick> &mind_trick : mind_tricks_) {
    const std::string name = mind_trick->GetName();
    current++;

    if (static_cast<log::Verbosity>(verbosity_) >= log::Verbosity::low) {
      std::stringstream stream;
      stream << text_reverse text_magenta " SchedulingMindTrick " text_reset text_magenta " ";
      stream << "[" << current << "/" << total << "] ";
      stream << name;

      const std::string &str = stream.str();
      Info(str, false);
    }

    // Skip tricks whose pattern does not match the current schedule.
    const bool matches = mind_trick->Matches(sch);
    if (!matches) {
      Info(log::Verbosity::veryHigh, text_dim text_yellow + mind_trick->str());
      continue;
    }

    const bool needs_check = mind_trick->NeedsScheduleCheck();
    if (!needs_check) Info(log::Verbosity::veryLow, text_bright_yellow "MindTrick requests no schedule check!");

    const isl::schedule &candidate = mind_trick->Apply(sch);
    const bool has_schedule = mind_trick->HasSchedule();
    if (!has_schedule) {
      Warn(log::Verbosity::low, "'" + name + "': no schedule available");
      continue;
    }

    // Accept the first candidate that is valid (or explicitly unchecked).
    const bool valid = !needs_check || CheckSchedule(candidate);
    if (valid) {
      if (needs_check) Info(log::Verbosity::low, text_green "schedule is valid!");
      result = candidate;
      ExtractMindTrickInfo(mind_trick);
      LogMindTrick(mind_trick);
      if (target == TARGET_CUDA) {
        GpuCompilerFlagsTempfileCreate(mind_trick);
      }
      break;
    } else {
      Info(log::Verbosity::high, text_dim text_yellow + mind_trick->str());
    }
  }

  RunInfo("output", kernel_name, result);
  log::SetVerbosityLevel(saved_verbosity);

  return result;
}

// Propagate the applied trick's settings into the scop configuration.
void ConstrainSchedule::ExtractMindTrickInfo(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
  // GPU-specific settings only make sense for the CUDA target.
  if (scop_info_.user_config_.GetTarget() == TARGET_CUDA) {
    ExtractGpuConfig(mind_trick);
  }

  ExtractDisabledPasses(mind_trick);
  ExtractAttrs(mind_trick);
}

// Report which mind trick was applied and record its textual form in the
// user configuration for later inspection.
void ConstrainSchedule::LogMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
  // Fix: dropped the unused 'kernel_name' local and reuse 'output' instead
  // of calling mind_trick->str() a second time.
  const std::string mind_trick_name = mind_trick->GetName();
  const std::string &output = mind_trick->str();

  Info(log::Verbosity::veryLow, text_reverse text_bright_blue " ConstrainSchedule ", false);
  Info(log::Verbosity::veryLow, text_bright_blue "using schedule from \'" + mind_trick_name + "\'");
  Info(log::Verbosity::medium, text_dim text_green + output);

  scop_info_.user_config_.SetConstrainedSchedulingOutput(output);
}

// Apply the trick's GPU mapping, but only when it specifies BOTH the block
// and the thread configuration.
void ConstrainSchedule::ExtractGpuConfig(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
  const std::string &blocks = mind_trick->GetGpuBlocks();
  const std::string &threads = mind_trick->GetGpuThreads();
  if (!blocks.empty() && !threads.empty()) {
    scop_info_.user_config_.SetBlockConfig(blocks);
    scop_info_.user_config_.SetThreadConfig(threads);
  }
}

// Record the passes to skip once this trick's schedule has been adopted.
void ConstrainSchedule::ExtractDisabledPasses(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
  // ComputeSchedule must always be disabled when a mind_trick is in use:
  // it would otherwise override the constrained schedule.
  disabled_passes_.insert("ComputeSchedule");
  // The trick itself may request additional passes to be disabled.
  for (const std::string &pass : mind_trick->GetDisabledPasses()) {
    disabled_passes_.insert(pass);
  }
}

// When explicitly enabled through the environment, dump a JSON mind trick
// template for the current kernel into the working directory.
void ConstrainSchedule::CreateMindTrickTemplate(const isl::schedule &sch) {
  const char *const env_templates = std::getenv(env_string_mind_tricks_templates_);
  const bool enabled = env_templates != nullptr && std::string(env_templates) == "true";
  if (!enabled) {
    return;
  }

  const std::string &filename = "mindtrick-template_" + scop_info_.user_config_.GetKernelName() + ".json";
  std::ofstream output(filename);
  output << SchedulingMindTrick::TemplateString(scop_info_, sch);
  output.close();
}

void ConstrainSchedule::ExtractAttrs(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
const air::Map<std::string, air::NodeRef> &attrs = mind_trick->GetAttrs();
scop_info_.user_config_.SetAttrs(attrs);
}

void ConstrainSchedule::InitVerbosityLevel(void) {
  // Verbosity sources, each overriding the previous one when non-negative:
  // 1. compile-time default, 2. environment variable, 3. operator attributes.
#ifdef AKG_CONSTRAIN_SCHEDULE_VERBOSITY
  {
    constexpr int preprocessor_verbosity = AKG_CONSTRAIN_SCHEDULE_VERBOSITY;
    if (preprocessor_verbosity >= 0) verbosity_ = preprocessor_verbosity;
  }
#endif
  {
    const char *const env_verbosity_string = std::getenv(env_string_mind_tricks_verbosity_);
    if (env_verbosity_string) {
      // std::stoi would throw on non-numeric input; strtol fails gracefully
      // (a malformed environment variable is simply ignored).
      char *end = nullptr;
      const long env_verbosity = std::strtol(env_verbosity_string, &end, 10);
      if (end != env_verbosity_string && env_verbosity >= 0) {
        verbosity_ = static_cast<int>(env_verbosity);
      }
    }
  }
  {
    const int attrs_verbosity = scop_info_.user_config_.GetConstrainScheduleVerbosity();
    if (attrs_verbosity >= 0) verbosity_ = attrs_verbosity;
  }
}
////////////////////////////////////////////////////////////////////////////////
// GPU Compiler flags
////////////////////////////////////////////////////////////////////////////////

std::string ConstrainSchedule::GpuCompilerFlagsTempfileName(void) const {
  // One tempfile per process: suffix the fixed prefix with the current pid.
  return ".akg_gpu_compiler_flags_" + std::to_string(getpid());
}

void ConstrainSchedule::GpuCompilerFlagsTempfileRemove(void) {
  // Best effort: std::remove fails silently if the file does not exist.
  const std::string &name = GpuCompilerFlagsTempfileName();
  std::remove(name.c_str());
}

void ConstrainSchedule::GpuCompilerFlagsTempfileCreate(const std::shared_ptr<SchedulingMindTrick> &mind_trick) {
const std::vector<std::string> &flags = mind_trick->GetGpuCompilerFlags();
if (flags.empty()) {
return;
}

const std::string &filename = GpuCompilerFlagsTempfileName();
std::ofstream tempfile(filename);
for (const std::string &flag : flags) {
tempfile << flag << std::endl;
}
tempfile.close();
}

////////////////////////////////////////////////////////////////////////////////
// Logging
////////////////////////////////////////////////////////////////////////////////

std::string ConstrainSchedule::LogPrefixText(const bool prefix) const {
  // Messages are optionally prefixed with the kernel name (quoted) so they
  // can be filtered apart in interleaved logs.
  if (prefix) {
    return "'" + scop_info_.user_config_.GetKernelName() + "': ";
  }
  return "";
}

// clang-format off
// Generates the ConstrainSchedule::Info/Warn/Error member wrappers around the
// corresponding free functions in akg::ir::poly::log. Each wrapper prepends
// the kernel-name prefix (see LogPrefixText) and accepts either a std::string
// or a std::stringstream, with an optional verbosity level given either as a
// plain int or as a log::Verbosity value.
#define define_constrain_schedule_log_wrappers(func) \
  void ConstrainSchedule::func(const std::string &message, const bool prefix) const { \
    const std::string &prefix_text = LogPrefixText(prefix); \
    log::func(prefix_text + message); \
  } \
  \
  void ConstrainSchedule::func(const std::stringstream &stream, const bool prefix) const { \
    const std::string &message = stream.str(); \
    func(message, prefix); \
  } \
  \
  void ConstrainSchedule::func(const int level, const std::string &message, const bool prefix) const { \
    const std::string &prefix_text = LogPrefixText(prefix); \
    log::func(level, prefix_text + message); \
  } \
  void ConstrainSchedule::func(const int level, const std::stringstream &stream, const bool prefix) const { \
    const std::string &message = stream.str(); \
    func(level, message, prefix); \
  } \
  void ConstrainSchedule::func(const log::Verbosity level, const std::string &message, const bool prefix) const { \
    const std::string &prefix_text = LogPrefixText(prefix); \
    log::func(level, prefix_text + message); \
  } \
  void ConstrainSchedule::func(const log::Verbosity level, const std::stringstream &stream, const bool prefix) const { \
    const std::string &message = stream.str(); \
    func(level, message, prefix); \
  }

define_constrain_schedule_log_wrappers(Info)
define_constrain_schedule_log_wrappers(Warn)
define_constrain_schedule_log_wrappers(Error)

#undef define_constrain_schedule_log_wrappers
// clang-format on

} // namespace poly
} // namespace ir
} // namespace akg

+ 115
- 0
src/poly/schedule_pass/constrain_schedule.h View File

@@ -0,0 +1,115 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_CONSTRAIN_SCHEDULE_H_
#define POLY_CONSTRAIN_SCHEDULE_H_

// STL headers
#include <vector>
#include <memory>

// Other libraries
#include "isl/cpp.h"

// AKG headers
#include "poly/log_util.h"
#include "poly/schedule_pass.h"
#include "poly/schedule_pass/scheduling_mind_trick.h"

namespace akg {
namespace ir {
namespace poly {

// Schedule pass that injects user-provided scheduling "mind tricks": it loads
// trick descriptions, matches them against the current kernel and, when one
// applies, uses it to constrain or replace the computed schedule.
class ConstrainSchedule : public SchedulePass {
 public:
  ConstrainSchedule(PassInfo &pass_info, ScopInfo &scop_info);
  ~ConstrainSchedule() {}

  virtual isl::schedule Run(isl::schedule sch);
  // Whether the pass should run at all.
  bool IsEnabled(void);

 private:
  // Verify a candidate schedule against the pass's scheduling constraints
  // (implemented in constrain_schedule_checking.cc).
  bool CheckSchedule(const isl::schedule &) const;

  bool KernelIsEligible(const isl::schedule &sch) const;

  // Mind trick loading.
  void LoadMindTricks(void);
  void LoadMindTrickFromFile(const std::string &filename);

  void AddMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick);
  // Propagate a trick's settings into scop_info_ (GPU config, disabled
  // passes, operator attributes).
  void ExtractMindTrickInfo(const std::shared_ptr<SchedulingMindTrick> &mind_trick);
  void LogMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick);
  void ExtractGpuConfig(const std::shared_ptr<SchedulingMindTrick> &mind_trick);
  void ExtractDisabledPasses(const std::shared_ptr<SchedulingMindTrick> &mind_trick);
  void ExtractAttrs(const std::shared_ptr<SchedulingMindTrick> &mind_trick);

  // Dump a JSON mind trick template for the current kernel when the
  // MS_AKG_MIND_TRICKS_TEMPLATES environment variable is "true".
  void CreateMindTrickTemplate(const isl::schedule &sch);

  // Initialize verbosity_ from the build flag, the environment variable and
  // the operator attrs (each overriding the previous one).
  void InitVerbosityLevel(void);

  // Per-process tempfile used to hand extra flags to the GPU compiler.
  std::string GpuCompilerFlagsTempfileName(void) const;
  void GpuCompilerFlagsTempfileRemove(void);
  void GpuCompilerFlagsTempfileCreate(const std::shared_ptr<SchedulingMindTrick> &mind_trick);

  PassInfo &pass_info_;
  ScopInfo &scop_info_;
  // Tricks loaded so far, tried in order during Run().
  std::vector<std::shared_ptr<SchedulingMindTrick>> mind_tricks_;

  ///////////////////////////////////////////////////////////////////////////
  // MindTrick paths
  ///////////////////////////////////////////////////////////////////////////

  static std::vector<std::string> MindTricksDirectories(void);

  ///////////////////////////////////////////////////////////////////////////
  // Supported environment variables
  ///////////////////////////////////////////////////////////////////////////

  static constexpr const char *const env_string_mind_tricks_enable_ = "MS_AKG_MIND_TRICKS";
  static constexpr const char *const env_string_mind_tricks_dir_ = "MS_AKG_MIND_TRICKS_DIR";
  static constexpr const char *const env_string_mind_tricks_verbosity_ = "MS_AKG_MIND_TRICKS_VERBOSITY";
  static constexpr const char *const env_string_mind_tricks_templates_ = "MS_AKG_MIND_TRICKS_TEMPLATES";
  static constexpr const char *const env_string_mind_tricks_operator_blacklist_ =
    "MS_AKG_MIND_TRICKS_OPERATOR_BLACKLIST";

  ///////////////////////////////////////////////////////////////////////////
  // Logging
  ///////////////////////////////////////////////////////////////////////////

  // 0 is the quietest level; see InitVerbosityLevel() for how it is set.
  int verbosity_{0};

  std::string LogPrefixText(const bool prefix = true) const;

  // Declares the Info/Warn/Error wrapper overloads; definitions are generated
  // by the matching macro in constrain_schedule.cc.
  // clang-format off
#define declare_constrain_schedule_log_wrappers(func) \
  void func(const std::string &message, const bool prefix = true) const; \
  void func(const std::stringstream &stream, const bool prefix = true) const; \
  void func(const int level, const std::string &message, const bool prefix = true) const; \
  void func(const int level, const std::stringstream &stream, const bool prefix = true) const; \
  void func(const akg::ir::poly::log::Verbosity level, const std::string &message, const bool prefix = true) const; \
  void func(const akg::ir::poly::log::Verbosity level, const std::stringstream &stream, const bool prefix = true) const;

  declare_constrain_schedule_log_wrappers(Info) declare_constrain_schedule_log_wrappers(Warn)
  declare_constrain_schedule_log_wrappers(Error)

#undef declare_constrain_schedule_log_wrappers
  // clang-format on
};

} // namespace poly
} // namespace ir
} // namespace akg

#endif // POLY_CONSTRAIN_SCHEDULE_H_

+ 237
- 0
src/poly/schedule_pass/constrain_schedule_checking.cc View File

@@ -0,0 +1,237 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/schedule_pass/constrain_schedule.h"

#include "poly/isl_util.h"
#include "poly/log_util.h"

namespace akg {
namespace ir {
namespace poly {

////////////////////////////////////////////////////////////////////////////////
// local declarations
////////////////////////////////////////////////////////////////////////////////

#ifndef _AKG_DUMMY_DATE_STRING
#define _AKG_DUMMY_DATE_STRING "_akg_dummy_date"
#endif

static const char *const _akg_dummy_date_string = _AKG_DUMMY_DATE_STRING;

// This should be changed later.
static int _verbosity = 2;

__isl_give isl_schedule *isl_schedule_constraints_silently_compute_schedule(
__isl_take isl_schedule_constraints *constraints);
__isl_give isl_stat map_statement_to_dummy_date(__isl_take isl_map *map, void *user);
__isl_give isl_union_map *isl_map_domain_to_dummy_date(__isl_take isl_union_set *domain);
__isl_give isl_union_map *isl_schedule_get_temporal_accesses(__isl_keep isl_schedule *schedule);
__isl_give isl_union_map *isl_schedule_get_temporal_dependences(__isl_keep isl_schedule *schedule);
__isl_give isl_schedule *isl_schedule_compute_verifying_schedule(
__isl_keep isl_schedule *const schedule, __isl_keep isl_schedule_constraints *const initial_constraints);
__isl_give isl_stat isl_schedule_check(__isl_keep isl_schedule *const schedule,
__isl_keep isl_schedule_constraints *const initial_constraints);

////////////////////////////////////////////////////////////////////////////////
// Schedule checking functions in ConstrainSchedule public API
////////////////////////////////////////////////////////////////////////////////

bool ConstrainSchedule::CheckSchedule(const isl::schedule &sch) const {
  // Forward the pass's verbosity to the file-local logging level, then
  // delegate the actual verification to isl_schedule_check().
  _verbosity = verbosity_;

  isl_schedule_constraints *const constraints = pass_info_.constraints_.get();
  const isl_stat status = isl_schedule_check(sch.get(), constraints);
  return status == isl_stat_ok;
}

////////////////////////////////////////////////////////////////////////////////
// local definitions
////////////////////////////////////////////////////////////////////////////////

/**
 * \brief Silent wrapper for isl_schedule_constraints_compute_schedule().
 *
 * Temporarily switches the isl context to "continue on error" so a failed
 * scheduling attempt returns NULL instead of aborting, then restores the
 * previous error policy.
 */
__isl_give isl_schedule *isl_schedule_constraints_silently_compute_schedule(
  __isl_take isl_schedule_constraints *const constraints) {
  akg::ir::poly::log::Info(log::Verbosity::high, "silent constraints:\n" + to_block_string(constraints));

  isl_ctx *const ctx = isl_schedule_constraints_get_ctx(constraints);
  const int saved_on_error = isl_options_get_on_error(ctx);
  isl_options_set_on_error(ctx, ISL_ON_ERROR_CONTINUE);
  isl_schedule *const schedule = isl_schedule_constraints_compute_schedule(constraints);
  isl_options_set_on_error(ctx, saved_on_error);

  return schedule;
}

/**
 * \brief Callback for isl_map_list_foreach(); sets the out dim to _dummy_date.
 *
 * 'user' is an isl_union_map** accumulator; maps are unioned into it.
 */
__isl_give isl_stat map_statement_to_dummy_date(__isl_take isl_map *map, void *const user) {
  // The output tuple of the incoming map is empty. Give it one dimension
  // fixed to 0 and name it so every statement maps to '_akg_dummy_date[0]'.
  map = isl_map_insert_dims(map, isl_dim_out, 0, 1);
  map = isl_map_fix_si(map, isl_dim_out, 0, 0);
  map = isl_map_set_tuple_name(map, isl_dim_out, _akg_dummy_date_string);

  isl_union_map **const accumulator = (isl_union_map **)user;
  *accumulator = (*accumulator) ? isl_union_map_add_map(*accumulator, map) : isl_union_map_from_map(map);

  return isl_stat_ok;
}

/**
 * \brief Map statements to dummy dates
 *
 * Takes ownership of 'domain' and returns a union map in which every
 * statement instance maps to '_akg_dummy_date[0]'.
 */
__isl_give isl_union_map *isl_map_domain_to_dummy_date(__isl_take isl_union_set *const domain) {
  // Seed an (empty-range) union map from the domain, then rewrite each of
  // its maps via the callback, accumulating into 'result'.
  isl_union_map *const seed = isl_union_map_from_domain(domain);
  isl_map_list *const maps = isl_union_map_get_map_list(seed);

  isl_union_map *result = NULL;
  isl_map_list_foreach(maps, map_statement_to_dummy_date, &result);

  isl_map_list_free(maps);
  isl_union_map_free(seed);

  return result;
}

__isl_give isl_union_map *isl_schedule_get_temporal_accesses(__isl_keep isl_schedule *const schedule) {
  // Treat every statement instance of the schedule's domain as an access to
  // the same dummy location.
  return isl_map_domain_to_dummy_date(isl_schedule_get_domain(schedule));
}

/**
 * \brief Compute, under the given schedule, the "temporal" dependences.
 *
 * Every statement instance accesses the same dummy location, so the resulting
 * may-dependences encode the execution order imposed by the schedule.
 * The caller owns the returned union map.
 */
__isl_give isl_union_map *isl_schedule_get_temporal_dependences(__isl_keep isl_schedule *const schedule) {
  /* Prepare inputs for the union_access_info. */
  isl_union_map *const sink = isl_schedule_get_temporal_accesses(schedule);
  isl_union_map *const source = isl_union_map_copy(sink);
  isl_schedule *const schedule_copy = isl_schedule_copy(schedule);

  /* Build the union_access_info. */
  isl_union_access_info *accesses = isl_union_access_info_from_sink(sink);
  accesses = isl_union_access_info_set_schedule(accesses, schedule_copy);
  accesses = isl_union_access_info_set_must_source(accesses, isl_union_map_copy(source));
  /* Kills limit the dependences to the closest preceding access. */
  accesses = isl_union_access_info_set_kill(accesses, source);

  /* Compute the flow and extract the may dependences. */
  isl_union_flow *const flow = isl_union_access_info_compute_flow(accesses);
  isl_union_map *const dependences = isl_union_flow_get_may_dependence(flow);
  isl_union_flow_free(flow);

  return dependences;
}

/**
 * \brief Check a new schedule against previous constraints
 *
 * Bonus: can also be used to add permutable/coincident metadata and "reshape"
 * a schedule tree (initial_constraints may be a null pointer).
 *
 * \returns a schedule computed from the combined constraints (evidence the
 *          input schedule is valid) or NULL if none could be computed.
 */
__isl_give isl_schedule *isl_schedule_compute_verifying_schedule(
  __isl_keep isl_schedule *const schedule, __isl_keep isl_schedule_constraints *const initial_constraints) {
  /*
   * Principle:
   * 1. extract "date" constraints from the proposed schedule
   * 2. combine the date constraints with the initial constraints
   * 3. attempt to schedule:
   *    - a schedule can be computed from the combined constraints:
   *      the proposed schedule is valid
   *    - no schedule can be computed from the combined constraints:
   *      the proposed schedule violates the initial constraints
   */
  isl_union_map *const dates = isl_schedule_get_temporal_dependences(schedule);

  /* Some logging. */
  if (initial_constraints)
    akg::ir::poly::log::Info(log::Verbosity::high, "initial constraints:\n" + to_block_string(initial_constraints));
  else
    akg::ir::poly::log::Info(log::Verbosity::high, "initial constraints: none");
  akg::ir::poly::log::Info(log::Verbosity::high, "dates:\n" + to_block_string(dates));

  isl_schedule_constraints *constraints = NULL;
  if (!initial_constraints) {
    /*
     * Hack/Easter egg: the function will be used to "reshape" the schedule tree
     */

    /* Build new schedule constraints: the dates become the only validity
     * constraints (set_validity takes ownership of 'dates'). */
    isl_union_set *const domain = isl_schedule_get_domain(schedule);
    constraints = isl_schedule_constraints_on_domain(domain);
    constraints = isl_schedule_constraints_set_validity(constraints, dates);
  } else {
    /* Combine the schedule constraints: union the initial validity
     * constraints (params aligned to the dates' space) with the dates. */
    constraints = isl_schedule_constraints_copy(initial_constraints);
    isl_union_map *const validity = isl_schedule_constraints_get_validity(constraints);
    isl_space *const space = isl_union_map_get_space(dates);
    isl_union_map *const aligned = isl_union_map_align_params(validity, space);

    isl_union_map *const restricted = isl_union_map_union(aligned, dates);
    constraints = isl_schedule_constraints_set_validity(constraints, restricted);
  }

  /* Finally attempt to schedule ('constraints' is consumed). */
  isl_schedule *const result = isl_schedule_constraints_silently_compute_schedule(constraints);

  return result;
}

/**
 * \brief Validate a schedule: it is valid iff a verifying schedule exists.
 */
__isl_give isl_stat isl_schedule_check(__isl_keep isl_schedule *const schedule,
                                       __isl_keep isl_schedule_constraints *const initial_constraints) {
  isl_schedule *const verifying = isl_schedule_compute_verifying_schedule(schedule, initial_constraints);
  if (!verifying) {
    akg::ir::poly::log::Warn(log::Verbosity::veryLow, "schedule is invalid");
    return isl_stat_error;
  }

  akg::ir::poly::log::Info(log::Verbosity::high, text_blue "schedule seems valid\n" + to_block_string(verifying));
  isl_schedule_free(verifying);
  return isl_stat_ok;
}

} // namespace poly
} // namespace ir
} // namespace akg

+ 1328
- 0
src/poly/schedule_pass/scheduling_mind_trick.cc
File diff suppressed because it is too large
View File


+ 303
- 0
src/poly/schedule_pass/scheduling_mind_trick.h View File

@@ -0,0 +1,303 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef POLY_SCHEDULING_MIND_TRICK_H_
#define POLY_SCHEDULING_MIND_TRICK_H_

// STL includes
#include <ostream>
#include <fstream>
#include <regex>

// External libraries
#include <picojson.h>
#include <isl/cpp.h>

// TVM
#include <tvm/node/container.h>
#include <tvm/node/node.h>

// Internal headers
#include "poly/isl.h"
#include "poly/log_util.h"
#include "poly/schedule_pass.h"
#include "poly/pass_info.h"
#include "poly/scop_info.h"
#include "poly/isl_influence.h"

namespace akg {
namespace ir {
namespace poly {

// Implementation notes
//
// 1. The class is non copyable.
// We use raw pointers to isl objects because the C++ wrapped objects are
// not practical in cases where the isl object is optional and std::optional
// curretly is not an option.
// Hence, we directly manage the isl objects. To avoid further complications
// with copy-constructors, move-constructors, isl_*_free(), isl_*_copy(),
// etc., the class is non copyable.

// data_value: tuple of <scheduling dim, coeff dim, coeff type, list of stmt name>
// using data_value = std::tuple<int, int, isl_influence_coeff_type, std::vector<std::string>>;

// single_data: tuple of <stmt name, scheduling dim, coeff dim, coeff type, value>
using single_data = std::tuple<std::string, int, int, isl_influence_coeff_type, int>;

// div_mod_data
// * if division: tuple of <stmt name, scheduling dim, divisor>
// * if modulo: tuple of <stmt name, scheduling dim, modulo value>
using div_mod_data = std::tuple<std::string, int, int>;

// Plain data holder for a mind trick's GPU mapping information.
class GpuConfig {
 public:
  // Sizes of the GPU mapping (presumably one entry per mapped dimension —
  // TODO(review): confirm against GuessGpuConfig()/PrepareMappingOuterBand).
  std::vector<int> block_sizes_;
  std::vector<int> thread_sizes_;
  // Schedule dimensions mapped to blocks/threads, respectively (assumed from
  // the names — verify against users of this class).
  std::vector<int> block_dimensions_;
  std::vector<int> thread_dimensions_;
  // Extra flags forwarded to the GPU compiler.
  std::vector<std::string> compiler_flags_;
};

// Kind of a mind trick; see to_string()/MindTrickTypeFromString() for the
// textual representation used in JSON descriptions.
enum class MindTrickType {
  none = 0,
  manual,
};

std::string to_string(MindTrickType t);
MindTrickType MindTrickTypeFromString(const std::string &str);

// A single scheduling "mind trick": a user-provided (JSON) description of a
// schedule, soft constraints and related metadata that can be matched against
// a kernel and applied in place of the regular scheduler output.
class SchedulingMindTrick {
 public:
  ///////////////////////////////////////////////////////////////////////////
  // Constructors and similar methods
  ///////////////////////////////////////////////////////////////////////////

  // Highly recommended to call the SchedulingMindTrick(PassInfo&, ScopInfo&)
  // constructor in other constructors and in derived classes constructors!
  SchedulingMindTrick(PassInfo &pass_info, ScopInfo &scop_info, int verbosity = -1);
  ~SchedulingMindTrick();

  // Read and parse a mind trick from the given JSON file.
  void Load(const std::string &filename);

  // Parse JSON representation
  void Parse(const picojson::value &json);
  void Parse(const std::string &serialized_json);
  std::istream &Parse(std::istream &streamed_json);

  // Non copyable
  SchedulingMindTrick(const SchedulingMindTrick &) = delete;
  SchedulingMindTrick &operator=(const SchedulingMindTrick &) = delete;

  ///////////////////////////////////////////////////////////////////////////
  // MindTrick state
  ///////////////////////////////////////////////////////////////////////////

  // True when the trick was correctly parsed and is usable.
  explicit operator bool() const;

  ///////////////////////////////////////////////////////////////////////////
  // MindTrick use
  ///////////////////////////////////////////////////////////////////////////

  // Whether this trick applies to the given schedule (by pattern or by
  // operator name; see scheduling_mind_trick_matching.cc).
  bool Matches(const isl::schedule &sch) const;
  isl::schedule Apply(const isl::schedule &sch);

  ///////////////////////////////////////////////////////////////////////////
  // I/O
  ///////////////////////////////////////////////////////////////////////////

  std::string str(void) const;
  std::ostream &Output(std::ostream &stream) const;
  std::istream &Input(std::istream &stream);

  friend std::ostream &operator<<(std::ostream &stream, const SchedulingMindTrick &mind_trick);
  friend std::istream &operator>>(std::istream &stream, SchedulingMindTrick &mind_trick);
  friend std::string to_string(const SchedulingMindTrick &mind_trick);

  // Serialized JSON template for the given kernel, used as a starting point
  // for hand-written tricks.
  static std::string TemplateString(ScopInfo &scop_info, const isl::schedule &schedule,
                                    MindTrickType type = MindTrickType::none);

  ///////////////////////////////////////////////////////////////////////////
  // MindTrick metadata
  ///////////////////////////////////////////////////////////////////////////

  void SetName(const std::string &name);
  std::string GetName(void) const;
  std::string GetTarget(void) const;

  ///////////////////////////////////////////////////////////////////////////
  // Misc. attributes
  ///////////////////////////////////////////////////////////////////////////

  bool HasSchedule(void) const;
  bool NeedsScheduleCheck(void) const;
  const air::Map<std::string, air::NodeRef> GetAttrs(void) const;

  ///////////////////////////////////////////////////////////////////////////
  // GPU Mapping
  ///////////////////////////////////////////////////////////////////////////

  void GuessGpuConfig(void);
  std::string GetGpuBlocks(void) const;
  std::string GetGpuThreads(void) const;
  std::vector<std::string> GetGpuCompilerFlags(void) const;

  ///////////////////////////////////////////////////////////////////////////
  // Pass toggling
  ///////////////////////////////////////////////////////////////////////////

  // Passes that must be skipped when this trick is applied.
  const std::set<std::string> &GetDisabledPasses(void) const;

  ///////////////////////////////////////////////////////////////////////////
  // Mind Trick Type
  ///////////////////////////////////////////////////////////////////////////

  MindTrickType GetType(void) const;
  void SetType(MindTrickType type);

 protected:
  ///////////////////////////////////////////////////////////////////////////
  // JSON parsing
  ///////////////////////////////////////////////////////////////////////////

  // Safe lookup: returns the value at 'key' (or an empty value if absent).
  static picojson::value maybe(const picojson::value &node, const std::string &key);
  static picojson::value maybe(const picojson::object &node, const std::string &key);

  static std::vector<int> to_int_vector(const picojson::value &node);
  static std::vector<int> to_int_vector(const picojson::array &node);
  static std::vector<std::string> to_string_vector(const picojson::value &node);
  static std::vector<std::string> to_string_vector(const picojson::array &node);

  void ParseName(const picojson::value &node);
  void ParseDisabledPasses(const picojson::value &node);
  void ParsePattern(const picojson::value &node);
  void ParseOperatorName(const picojson::value &node);
  void ParseDomain(const picojson::value &node);
  void ParseSchedule(const picojson::value &node);
  void ParseGpuInfo(const picojson::value &node);
  void ParseExplicitTree(const picojson::value &node);
  void ParseCheckSchedule(const picojson::value &node);
  void ParseAttrs(const picojson::value &node);
  void ParseVerbosity(const picojson::value &node);
  void ParseSoftConstraints(const picojson::value &node);

  ///////////////////////////////////////////////////////////////////////////
  // Schedule
  ///////////////////////////////////////////////////////////////////////////

  isl::schedule GetSchedule(void) const;

  ///////////////////////////////////////////////////////////////////////////
  // Schedule suggestion
  ///////////////////////////////////////////////////////////////////////////

  void BuildSuggestedSchedule(void);

  // Various helpers to build the suggested schedule
  __isl_give isl_schedule *ComputeScheduleSuggestion(void);
  __isl_give isl_schedule *PrepareMappingOuterBand(__isl_take isl_schedule *schedule, GpuConfig &info);

  ///////////////////////////////////////////////////////////////////////////
  // Soft constraints
  ///////////////////////////////////////////////////////////////////////////

  void CollectSoftConstraintsData(std::string stmt_name, unsigned int sched_dim, int coeff_dim,
                                  isl_influence_coeff_type coeff_type, std::string coeff_vec_i);
  void BuildSoftConstraints();
  void BuildInfluenceList(std::vector<single_data> singles);
  void BuildInfluenceEqualList(std::map<std::string, std::vector<single_data>> linked);

  void BuildInfluencedSchedule(void);
  void IslInfluenceToggle(bool toggle);

  __isl_give isl_schedule *AdjustSchedule(__isl_take isl_schedule *schedule, const std::vector<div_mod_data> &modulos,
                                          const std::vector<div_mod_data> &divisions);

  // Misc helpers
  std::vector<std::string> split_string(std::string str, std::string delim);

  ///////////////////////////////////////////////////////////////////////////
  // Internal data
  ///////////////////////////////////////////////////////////////////////////

  // AKG info
  PassInfo &pass_info_;
  ScopInfo &scop_info_;

  // Set by the Parse() family on success.
  bool correctly_parsed_{false};

  // Matching criteria: an operator (kernel) name and/or a schedule pattern.
  std::string operator_{""};
  std::string pattern_{""};
  // Raw isl pointers: optional objects, managed manually (see note above).
  isl_schedule *explicit_tree_{nullptr};

  isl_union_set *domain_{nullptr};
  isl_schedule *suggested_schedule_{nullptr};
  std::string suggested_schedule_string_{""};
  std::vector<std::string> suggested_schedule_vector_;

  std::vector<std::tuple<std::string, std::vector<int>>> post_transformations_;

  // Soft constraints data collected during parsing.
  std::vector<single_data> singles_;
  std::map<std::string, std::vector<single_data>> linked_;
  std::vector<div_mod_data> modulos_;
  std::vector<div_mod_data> divisions_;
  isl_influence_list *influence_list_{nullptr};
  isl_influence_equal_list *influence_equal_list_{nullptr};
  std::string parse_soft_constraints_log_str_{""};
  isl_schedule *influenced_schedule_{nullptr};

  bool check_schedule_{true};

  std::string target_{""};
  std::string name_{"unnamed mind_trick"};

  GpuConfig gpu_info_;

  air::Map<std::string, air::NodeRef> attrs_;

  std::set<std::string> disabled_passes_;

  MindTrickType type_{MindTrickType::manual};

  int verbosity_{0};

  std::string LogPrefixText(const bool prefix = true) const;

  // Declares Info/Warn/Error wrapper overloads; definitions are generated by
  // the matching macro in scheduling_mind_trick.cc.
  // clang-format off
#define declare_scheduling_mind_trick_log_wrappers(func) \
  void func(const std::string &message, const bool prefix = true) const; \
  void func(const std::stringstream &stream, const bool prefix = true) const; \
  void func(const int level, const std::string &message, const bool prefix = true) const; \
  void func(const int level, const std::stringstream &stream, const bool prefix = true) const; \
  void func(const akg::ir::poly::log::Verbosity level, const std::string &message, const bool prefix = true) const; \
  void func(const akg::ir::poly::log::Verbosity level, const std::stringstream &stream, const bool prefix = true) const;

  declare_scheduling_mind_trick_log_wrappers(Info)
  declare_scheduling_mind_trick_log_wrappers(Warn)
  declare_scheduling_mind_trick_log_wrappers(Error)

#undef declare_scheduling_mind_trick_log_wrappers

 private :
  // Not default-constructible: a mind trick needs pass_info_ and scop_info_.
  SchedulingMindTrick();
  // clang-format on
};

} // namespace poly
} // namespace ir
} // namespace akg

#endif // POLY_SCHEDULING_MIND_TRICK_H_

+ 342
- 0
src/poly/schedule_pass/scheduling_mind_trick_matching.cc View File

@@ -0,0 +1,342 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "poly/schedule_pass/scheduling_mind_trick.h"

#include "poly/log_util.h"

namespace akg {
namespace ir {
namespace poly {

////////////////////////////////////////////////////////////////////////////////
// Local declarations
////////////////////////////////////////////////////////////////////////////////

static inline isl_schedule_node_type GetScheduleNodeType(const isl::schedule_node &node);
static inline std::vector<isl::schedule_node> CollectScheduleNodes(const isl::schedule &sch, const int verbosity);
static inline bool DomainMatches(const isl::schedule &pattern, const isl::schedule &candidate, const std::string &name_,
const int verbosity);
static inline bool NodesMatch(const isl::schedule_node &pattern_node, const isl::schedule_node &candidate_node,
const std::string &name_, const int verbosity);
static inline bool PatternMatches(const isl::schedule &sch, const std::string &pattern_, const std::string &name_,
const int verbosity);

////////////////////////////////////////////////////////////////////////////////
// Patter-matching related in SchedulingMindTrick public API
////////////////////////////////////////////////////////////////////////////////

bool SchedulingMindTrick::Matches(const isl::schedule &sch) const {
  if (operator_ != "" && pattern_ != "") {
    Info(log::Verbosity::medium, "note that this mind trick contains both a pattern and an operator name...");
  }

  // Matching criteria, in order of precedence: schedule pattern, operator
  // name, then the degenerate "no criteria at all" case (matches anything).
  if (PatternMatches(sch, pattern_, name_, verbosity_)) {
    Info(log::Verbosity::low, text_green "schedule pattern matches!");
    return true;
  }
  if (operator_ == scop_info_.user_config_.GetKernelName()) {
    Info(log::Verbosity::low, text_green "operator name matches!");
    return true;
  }
  if (pattern_ == "" && operator_ == "") {
    Warn(log::Verbosity::veryLow, "pattern-free and operator-free mind_trick: matches! (are you sure?)");
    return true;
  }

  Info(log::Verbosity::medium, "schedule does not match!");
  return false;
}

////////////////////////////////////////////////////////////////////////////////
// Local definitions
////////////////////////////////////////////////////////////////////////////////

// Return the isl node type of `node`, validating that it is one of the
// expected schedule-tree enumerators.
//
// The original per-case returns duplicated each enumerator; the cases now
// fall through to a single `return type`, which is behaviorally identical.
isl_schedule_node_type GetScheduleNodeType(const isl::schedule_node &node) {
  const isl_schedule_node_type type = isl_schedule_node_get_type(node.get());
  switch (type) {
    case isl_schedule_node_band:
    case isl_schedule_node_context:
    case isl_schedule_node_domain:
    case isl_schedule_node_extension:
    case isl_schedule_node_filter:
    case isl_schedule_node_guard:
    case isl_schedule_node_mark:
    case isl_schedule_node_leaf:
    case isl_schedule_node_sequence:
    case isl_schedule_node_set:
      return type;
    default:
      // Unexpected value (e.g. isl_schedule_node_error). NOTE: assert() is
      // compiled out in NDEBUG builds, in which case we silently fall back to
      // a leaf node type so callers still receive a valid enumerator.
      assert(false && "cannot convert the node type");
      return isl_schedule_node_leaf;
  }
}

// Collect every node of the schedule tree in top-down order.
// Returns an empty vector when the root's only child is a leaf.
std::vector<isl::schedule_node> CollectScheduleNodes(const isl::schedule &sch, const int verbosity) {
  (void)verbosity;  // kept for interface symmetry with the other local helpers

  std::vector<isl::schedule_node> result;

  const isl::schedule_node &root = sch.get_root();
  // A root whose single child is a leaf carries no structure worth collecting.
  if (GetScheduleNodeType(root.child(0)) == isl_schedule_node_leaf) {
    log::Info(log::Verbosity::low, "Root node has no children; returning empty vector of schedule nodes");
    return result;
  }

  root.foreach_descendant_top_down([&result](isl::schedule_node node) {
    result.push_back(node);
    return true;
  });
  return result;
}

// Check whether the domain of `candidate` matches the domain of `pattern`.
//
// Matching rules:
//  - An empty pattern domain is a wildcard: any candidate matches.
//  - Otherwise both domains must contain the same number of statements, and
//    every pattern statement must have a candidate statement with the same
//    tuple name and space.
//  - For matched statements, per-dimension bounds must be equal, unless the
//    pattern dimension is named "inf" or leaves its bounds unspecified
//    (detected via an isl exception), in which case any bound matches.
//
// Fixes vs. the original: removed a redundant inner redeclaration of
// `space_i` that shadowed the outer one, and moved `can_j`/`name_j`/`space_j`
// into the loop that uses them. `verbosity` is unused (logging uses the
// global log::Verbosity levels) and kept only for interface consistency.
bool DomainMatches(const isl::schedule &pattern, const isl::schedule &candidate, const std::string &name_,
                   const int verbosity) {
  (void)verbosity;
  log::Info(log::Verbosity::high, name_ + ": checking if candidate schedule domain matches...");

  if (pattern.get_domain().is_empty()) {
    log::Warn(log::Verbosity::high,
              name_ + ": pattern's domain is completely unspecified therefore any domain matches");
    return true;
  }
  isl::set_list dom_patt_list = pattern.get_domain().get_set_list();
  isl::set_list dom_can_list = candidate.get_domain().get_set_list();

  if (dom_patt_list.size() != dom_can_list.size()) {
    log::Warn(log::Verbosity::high,
              name_ + ": candidate does not have the same number of statement as pattern. Does not match");
    return false;
  }

  for (unsigned i = 0; i < dom_patt_list.size(); ++i) {
    isl::set patt_i = dom_patt_list.get_at(i);
    isl::id name_i = patt_i.get_tuple_id();
    isl::space space_i = patt_i.get_space();

    bool exists = false;

    for (unsigned j = 0; j < dom_can_list.size(); ++j) {
      isl::set can_j = dom_can_list.get_at(j);
      isl::id name_j = can_j.get_tuple_id();
      isl::space space_j = can_j.get_space();

      if ((name_j.get_name() == name_i.get_name()) && (space_i.is_equal(space_j))) {
        exists = true;

        log::Info(log::Verbosity::high,
                  name_ + ": statement " + space_i.to_str() + " exists; proceeding with further checks...");

        const unsigned dim_num_i = space_i.dim(isl_dim_set);

        for (unsigned k = 0; k < dim_num_i; ++k) {
          // There may be cases where underspecified statements appear in the candidate tree
          // ex: domain : "{S_0[]}" (variable initialization or reductions)
          // We must distinguish such cases from that of underspecified patterns (for which we bypass
          // isl_exceptions in order to still capture their specification.

          // The previous tests on spaces should ensure that, at this point of the code,
          // only underspecified patterns can be encountered.
          // But to be very sure about that, the following check eliminates any possibility of attempting
          // pattern matching on a candidate that contains any S[]
          isl_bool dim_id_check = isl_set_has_dim_id(can_j.get(), static_cast<enum isl_dim_type>(isl_dim_set), k);

          if (!dim_id_check) {
            log::Warn(log::Verbosity::high,
                      name_ + ": statement " + name_i.get_name() +
                        " has no dim id. Cannot perform pattern matching on an underspecified candidate");
            return false;
          }

          isl::id dim_name_j =
            isl::manage(isl_set_get_dim_id(can_j.get(), static_cast<enum isl_dim_type>(isl_dim_set), k));

          isl::pw_aff min_j = can_j.dim_min(k);
          isl::pw_aff max_j = can_j.dim_max(k);

          std::stringstream log_prefix_stream;
          log_prefix_stream << name_ << ": candidate bounds for dimension " << k << " of " << name_i.get_name();
          const std::string &log_prefix = log_prefix_stream.str();
          try {
            isl::id dim_name_i =
              isl::manage(isl_set_get_dim_id(patt_i.get(), static_cast<enum isl_dim_type>(isl_dim_set), k));

            isl::pw_aff min_i = patt_i.dim_min(k);
            isl::pw_aff max_i = patt_i.dim_max(k);

            // A pattern dimension named "inf" explicitly accepts any bounds.
            if (dim_name_i.to_str() != "inf") {
              if (!(min_i.is_equal(min_j) && max_i.is_equal(max_j))) {
                log::Warn(log::Verbosity::high,
                          log_prefix + " are not equal to pattern bounds; domain does not match\n");
                return false;
              }
            }
            log::Info(log::Verbosity::high, log_prefix + " match");
          } catch (isl::exception &) {
            // The pattern leaves this dimension unspecified: any bound matches.
            log::Info(log::Verbosity::high, log_prefix + " are undefined; any bound matches");
          }
        }
      }
    }
    if (!exists) {
      log::Warn(log::Verbosity::high,
                name_ + ": statement " + space_i.to_str() + " does not exist; domain does not match");
      return false;
    }
  }
  log::Info(log::Verbosity::high, name_ + ": candidate domain matches\n");
  return true;
}

// Check whether two schedule-tree nodes match.
//
// Nodes match when they have the same type and, for band and filter nodes,
// an equal specification:
//  - band nodes: partial schedules must be plainly equal;
//  - filter nodes: every pattern tuple id must appear among the candidate's
//    tuple ids (pattern sets without a tuple id act as wildcards -- detected
//    via the isl::exception_invalid thrown by get_tuple_id()).
// `verbosity` is unused here; logging relies on the global log levels.
bool NodesMatch(const isl::schedule_node &pattern_node, const isl::schedule_node &candidate_node,
                const std::string &name_, const int verbosity) {
  isl_schedule_node_type pnode_type = GetScheduleNodeType(pattern_node);
  isl_schedule_node_type cnode_type = GetScheduleNodeType(candidate_node);

  // At the first mismatch, return false
  if (pnode_type != cnode_type) {
    return false;
  }

  // If encountering schedule node bands or filters,
  // ensure that their specification is equal
  if (pnode_type == isl_schedule_node_band) {
    isl::multi_union_pw_aff pnode_bsched = isl::manage(isl_schedule_node_band_get_partial_schedule(pattern_node.get()));
    isl::multi_union_pw_aff cnode_bsched =
      isl::manage(isl_schedule_node_band_get_partial_schedule(candidate_node.get()));

    if (!cnode_bsched.plain_is_equal(pnode_bsched)) {
      log::Warn(log::Verbosity::high, name_ + "\' : schedule bands are not equal. Schedule tree does not match");
      return false;
    }
  }

  if (pnode_type == isl_schedule_node_filter) {
    isl::union_set pfilter_info = isl::manage(isl_schedule_node_filter_get_filter(pattern_node.get()));
    isl::union_set cfilter_info = isl::manage(isl_schedule_node_filter_get_filter(candidate_node.get()));

    // An empty pattern filter imposes no constraint on the candidate.
    if (!pfilter_info.is_empty()) {
      isl::set_list plist = pfilter_info.get_set_list();
      isl::set_list clist = cfilter_info.get_set_list();

      for (unsigned i = 0; i < plist.size(); ++i) {
        try {
          // get_tuple_id() throws when the pattern set has no tuple id;
          // the catch below treats that case as a wildcard.
          isl::id pid = plist.get_at(i).get_tuple_id();

          if (!pid.is_null()) {
            bool exists = false;

            // Look for a candidate filter set with the same tuple name.
            for (unsigned j = 0; j < clist.size(); ++j) {
              isl::id cid = clist.get_at(j).get_tuple_id();

              if (!cid.is_null() && (pid.get_name() == cid.get_name())) {
                exists = true;
              }
            }
            if (!exists) {
              log::Warn(
                log::Verbosity::high,
                "\'" + name_ + "\' candidate filter tuple id does not match with pattern (" + pid.get_name() + ")");
              return false;
            }
          }
        } catch (isl::exception_invalid &) {
          log::Info(log::Verbosity::high, "\'" + name_ + "\' : no tuple id specified, therefore any tuple id matches");
        }
      }
    }
  }

  // Return true if none of the previous tests failed
  return true;
}

// Check whether the candidate schedule `sch` matches the textual pattern
// `pattern_` (an isl schedule string). The pattern's domain is checked first,
// then its node sequence is searched for, top-down, inside the candidate tree
// starting from any candidate node that matches the pattern's first node.
//
// Fixes vs. the original:
//  - `candidate_nodes[offset + i]` is now bounds-checked: a pattern anchored
//    near the end of the candidate tree previously read past the vector (UB).
//  - a single-node pattern now matches when an entry node is found (the old
//    `i == pattern_nodes.size() - 1` test could never hold for size() == 1).
//  - the loop over entry points stops at the first successful match, as the
//    original comment already promised.
bool PatternMatches(const isl::schedule &sch, const std::string &pattern_, const std::string &name_,
                    const int verbosity) {
  // An empty pattern string can never match.
  if (pattern_ == "") {
    return false;
  }

  isl::schedule pattern = isl::schedule(sch.ctx(), pattern_);

  if (!DomainMatches(pattern, sch, name_, verbosity)) {
    return false;
  }

  log::Info(log::Verbosity::high, " \'" + name_ + "\' : collecting pattern nodes and candidate nodes...");

  std::vector<isl::schedule_node> pattern_nodes = CollectScheduleNodes(pattern, verbosity);
  std::vector<isl::schedule_node> candidate_nodes = CollectScheduleNodes(sch, verbosity);

  // Iterate according to the number of nodes in the pattern to ensure that
  // the pattern matches at least part of the candidate schedule tree
  if (pattern_nodes.size() == 0) {
    log::Info(
      log::Verbosity::high,
      " \'" + name_ + "\' : no schedule tree pattern is specified, therefore any candidate schedule tree matches");
    return true;
  }

  // The pattern may not be found at the root of the candidate tree.
  // We therefore first search for any nodes index that matches the first node
  // of the pattern to capture starting points for checking if there is a match
  std::vector<unsigned> entry_indexes;
  for (unsigned i = 0; i < candidate_nodes.size(); ++i) {
    if (NodesMatch(pattern_nodes[0], candidate_nodes[i], name_, verbosity)) {
      entry_indexes.push_back(i);
    }
  }
  if (entry_indexes.empty()) {
    return false;
  }

  // A single-node pattern is fully matched by any entry point found above.
  if (pattern_nodes.size() == 1) {
    return true;
  }

  // With the way we expect patterns to be specified, for now, there can be
  // only one possible match if any, so we stop at the first matching entry.
  // Note: the last pattern node (typically a leaf) is deliberately excluded
  // from the comparison, as in the original specification style.
  bool match = false;
  for (unsigned k = 0; k < entry_indexes.size() && !match; ++k) {
    const unsigned offset = entry_indexes[k];
    unsigned i;
    for (i = 1; i < pattern_nodes.size() - 1; ++i) {
      // The pattern must not run past the end of the candidate tree.
      if (offset + i >= candidate_nodes.size()) {
        break;
      }
      if (!NodesMatch(pattern_nodes[i], candidate_nodes[offset + i], name_, verbosity)) {
        break;
      }
    }
    if (i == pattern_nodes.size() - 1) {
      match = true;
    }
  }
  if (!match) {
    log::Warn(log::Verbosity::high, " \'" + name_ + "\' : the candidate schedule tree does not match");
    return false;
  }

  // If the code made it this far, then schedules match
  return true;
}

} // namespace poly
} // namespace ir
} // namespace akg

+ 6
- 0
src/poly/schedule_pass_gpu/mapping_outer_band.cc View File

@@ -202,6 +202,12 @@ isl::schedule MappingOuterBand::DoThreadMapping(const isl::schedule &sch) {
return node;
}

if (node.has_parent() && node.parent().isa<isl::schedule_node_mark>()) {
const std::string &marker = node.parent().as<isl::schedule_node_mark>().get_id().get_name();
if (marker == "mind_trick_swizzle_marker")
return node;
}

size_t num_mapped_desc = NumMappedDescendant(thread_record, node);

if (CanBeMappedToThread(node, thread_record)) {


+ 53
- 5
src/poly/schedule_pass_gpu/shared_memory_manager.cc View File

@@ -32,6 +32,13 @@ isl::schedule SharedMemoryManager::Run(isl::schedule sch) {
}
schedule_ = sch;
auto root = sch.get_root();

// Update the variable/tensor to share
if (!scop_info_.user_config_.GetSharedTensors().empty()) {
configed_tensors_ = Split(scop_info_.user_config_.GetSharedTensors(), " ");
}

// Compute the depth where the shared memory have to be generated
UpdateDepth(root);
if (scop_info_.user_config_.GetSharedDepth() >= 0) {
depth_ = scop_info_.user_config_.GetSharedDepth();
@@ -39,6 +46,8 @@ isl::schedule SharedMemoryManager::Run(isl::schedule sch) {
}
CHECK_GE(depth_, 0) << "shared depth should be greater than or equal with zero!";
bank_conflict_ = scop_info_.user_config_.GetEnableBankConflict();
shared_inversed_thread_map_ = scop_info_.user_config_.GetSharedInversedThreadMap();
shared_vector_align_ = scop_info_.user_config_.GetSharedVectorAlign();

// collect all bands at the given depth in the schedule tree
size_t remain_memory = share_memory_size_;
@@ -171,6 +180,31 @@ isl::schedule_node SharedMemoryManager::MapCopiesToThreads(isl::schedule_node &r
has_split = true;
}

if (shared_inversed_thread_map_) {
// Pretille - To make a vectorize loop more apparent with only the information of the mapping
const auto &domain = band_node.as<isl::schedule_node_band>().get_partial_schedule().domain();
const isl::id &current_computing_id_shared = domain.unwrap().range().set_list().get_at(0).get_tuple_id();

std::vector<size_t> tensor_size;
for (BufferDefInfo &buffer_info : scop_info_.analysis_result_.buffer_def_infos_) {
if (current_computing_id_shared == buffer_info.dst_tensor_id) {
tensor_size = buffer_info.sizes;
}
}
// Reverse because thread is innermost map
std::reverse(tensor_size.begin(),tensor_size.end());

auto ctx = band_node.ctx();
const auto &space = band_node.as<isl::schedule_node_band>().get_space();
const auto n_member = band_node.as<isl::schedule_node_band>().n_member();
isl::multi_val tile_size = isl::multi_val::zero(space);
for (size_t i = 0; i < n_member; ++i) {
const size_t size = tensor_size[i]/thread_cfg->GetAt(i).second;
tile_size = tile_size.set_val(n_member-1 - i, isl::val(ctx, size!=0?size:1));
}
band_node = TileBand(band_node, tile_size);
}

auto mapping_cfg = thread_cfg;
if (scop_info_.user_config_.GetVectorLoadType() || scop_info_.user_config_.GetEnableTensorCoreUsePoly()) {
scop_info_.user_config_.SetEnableOneDimThread(true);
@@ -495,9 +529,7 @@ void SharedMemoryManager::GatherBufferFootprintDefInfo(const isl::schedule_node
sizes.back() += 16;
}

if (bank_conflict_) {
sizes = OptimizeBankConflict(sizes);
}
sizes = OptimizeSharedDimension(sizes);

isl::id tensor_id = tensor_info.tensor_id;
isl::id cluster_id = tensor_info.dst_tensor_id;
@@ -552,7 +584,7 @@ isl::schedule_node SharedMemoryManager::HoistClusters(const isl::schedule_node &
LOG(FATAL) << "Can not manage a scalar tensor";
}

box_sizes = OptimizeBankConflict(box_sizes);
box_sizes = OptimizeSharedDimension(box_sizes);

auto approximation_size = std::accumulate(box_sizes.begin(), box_sizes.end(), 1, std::multiplies<size_t>());
size_t byte = Bytes(id);
@@ -596,7 +628,7 @@ isl::schedule_node SharedMemoryManager::HoistToBlockThreadMemory(isl::schedule_n
isl::id dst_tensor_id = GpuDstId(type, tensor_id);
auto sizes = cluster.GetFixedBoxSizes();
if (force_last_extension_odd) {
sizes = OptimizeBankConflict(sizes);
sizes = OptimizeSharedDimension(sizes);
}

auto res_node = PlaceOuterDataCopyBelow(scop_info_, tree, cluster, tensor_id, dst_tensor_id, out_schedule,
@@ -734,6 +766,13 @@ size_t SharedMemoryManager::Bytes(const isl::id tensor_id) {
return static_cast<size_t>(type.bytes());
}

std::vector<size_t> SharedMemoryManager::OptimizeSharedDimension(std::vector<size_t> sizes) {
std::vector<size_t> res = sizes;
res = OptimizeBankConflict(res);
res = OptimizeVectorAlign(res);
return res;
}

std::vector<size_t> SharedMemoryManager::OptimizeBankConflict(std::vector<size_t> sizes) {
std::vector<size_t> res = sizes;
if (res.back() % 2 == 0) {
@@ -748,6 +787,15 @@ std::vector<size_t> SharedMemoryManager::OptimizeBankConflict(std::vector<size_t
return res;
}

std::vector<size_t> SharedMemoryManager::OptimizeVectorAlign(std::vector<size_t> sizes) {
std::vector<size_t> res = sizes;
if (shared_vector_align_ != 0) {
int padsize = res.back() % shared_vector_align_;
res.back() += padsize?(shared_vector_align_-padsize):0;
}
return res;
}

} // namespace poly
} // namespace ir
} // namespace akg

+ 5
- 1
src/poly/schedule_pass_gpu/shared_memory_manager.h View File

@@ -72,7 +72,9 @@ class SharedMemoryManager : public SchedulePass {

void UpdateDepth(const isl::schedule_node &root);

std::vector<size_t> OptimizeSharedDimension(std::vector<size_t> sizes);
std::vector<size_t> OptimizeBankConflict(std::vector<size_t> sizes);
std::vector<size_t> OptimizeVectorAlign(std::vector<size_t> sizes);
bool UnderThreadMarker(size_t depth);

std::string InAtomicTensors(isl::schedule_node &node);
@@ -102,10 +104,12 @@ class SharedMemoryManager : public SchedulePass {
std::string tensor_c_;
std::string tensor_a_;
std::string tensor_b_;
bool shared_inversed_thread_map_{false};
int shared_vector_align_{0};
};

} // namespace poly
} // namespace ir
} // namespace akg

#endif
#endif

+ 14
- 0
src/poly/schedule_pass_mgr.cc View File

@@ -44,7 +44,17 @@ isl::schedule SchedulePassMgr::Run(const isl::schedule &sch, const std::vector<s
auto replace_sch = sch;
need_restart_ = false;

std::set<std::string> disabled;
for (auto &pass : passes) {
const std::string &name = pass->GetPassName();
const bool disable = disabled.find(name) != disabled.end();
if (disable) {
LOG(INFO) << "Disabling poly pass " << name;
continue;
} else {
LOG(INFO) << "Running poly pass " << name;
}

if (LoadScheduleTreeFromFile(scop_info_.AddDumpDir(pass->GetPassName() + ".txt"), replace_sch)) {
if (!replace_sch.plain_is_equal(final_sch)) {
final_sch = replace_sch;
@@ -67,6 +77,10 @@ isl::schedule SchedulePassMgr::Run(const isl::schedule &sch, const std::vector<s
need_restart_ = true;
break;
}
if (!pass->disabled_passes_.empty()) {
disabled.insert(pass->disabled_passes_.begin(), pass->disabled_passes_.end());
LOG(INFO) << name << " requests to disable some subsequent passes";
}
}
return final_sch;
}


+ 27
- 0
src/poly/scop_info.h View File

@@ -179,6 +179,11 @@ class UserConfig {
ParseBoolAttr(attrs, "pragma_tilesize_is_var", &tile_size_is_var_);
ParseBoolAttr(attrs, "pragma_outerband_need_split", &outer_band_need_split_);

// Mind-trick pass
ParseIntAttr(attrs, "constrain_schedule_verbosity", &constrain_schedule_verbosity_);
ParseBoolAttr(attrs, "enable_mind_trick", &enable_mind_trick_);
ParseStringAttr(attrs, "mind_trick", &mind_trick_);

ParseStringAttr(attrs, "dim", &b_dim_);
ParseStringAttr(attrs, "bind_elem_per_thread", &elem_per_thread_);
ParseMappingCfgAttr(attrs, "bind_block", &block_cfg_);
@@ -241,6 +246,8 @@ class UserConfig {
ParseBoolAttr(attrs, "use_shared_memory", &use_shared_memory_);
ParseBoolAttr(attrs, "enable_bank_conflict_opt", &enable_bank_conflict_);
ParseBoolAttr(attrs, "enable_one_dim_thread", &enable_one_dim_thread_);
ParseBoolAttr(attrs, "shared_inversed_thread_map", &shared_inversed_thread_map_);
ParseIntAttr(attrs, "shared_vector_align", &shared_vector_align_);
ParseIntAttr(attrs, "register_memory_depth", &register_depth_);
ParseIntAttr(attrs, "shared_memory_depth", &shared_depth_);
ParseStringAttr(attrs, "shared_memory_tensors", &shared_tensors_);
@@ -305,6 +312,12 @@ class UserConfig {
void SetUnrollShared(const bool unroll_shared) { this->unroll_shared_ = unroll_shared; }
void SetDisableLoopFusion(const bool disable_loop_fusion) { this->disable_loop_fusion_ = disable_loop_fusion; }

// getter/setter for schedule_pass config
int GetConstrainScheduleVerbosity() const { return constrain_schedule_verbosity_; }
bool GetEnableMindTrick() const { return enable_mind_trick_; }
std::string GetMindTrick() const { return mind_trick_; }
void SetConstrainedSchedulingOutput(const std::string &output) { constrained_scheduling_output_ = output; }

// getter for schedule tree transform config
bool GetRemoveSelfDependence() const { return remove_self_dependence_; }
bool GetForceRemoveSelfDependence() const { return force_remove_self_dependence_; }
@@ -418,6 +431,12 @@ class UserConfig {
void SetEnableBankConflict(bool enable_bank_conflict) { enable_bank_conflict_ = enable_bank_conflict; }
bool GetEnableBankConflict() { return enable_bank_conflict_; }
int GetVectorLoadType() { return vector_load_type_; }
void SetSharedInversedThreadMap(bool shared_inversed_thread_map) {
shared_inversed_thread_map_ = shared_inversed_thread_map;
}
bool GetSharedInversedThreadMap() { return shared_inversed_thread_map_; }
void SetSharedVectorAlign(int shared_vector_align) { shared_vector_align_ = shared_vector_align; }
int GetSharedVectorAlign() { return shared_vector_align_; }

private:
// tools for parsing user config
@@ -576,6 +595,14 @@ class UserConfig {
int max_unroll_loop_{1};
bool unroll_shared_{false};
bool enable_bank_conflict_{false};
bool shared_inversed_thread_map_{false};
int shared_vector_align_{0};

// schedule_pass/mind_trick config
int constrain_schedule_verbosity_{-1};
bool enable_mind_trick_{true};
std::string mind_trick_{""};
std::string constrained_scheduling_output_{""};

// schedule tree transform config
bool remove_self_dependence_{true};


+ 3
- 0
tests/operators/gpu/.gitignore View File

@@ -0,0 +1,3 @@
opstest_*.log
*.cu
cuda_meta_*

+ 58
- 52
tests/operators/gpu/test_all.py View File

@@ -64,7 +64,7 @@ from tests.operators.gpu.test_fused_bn_update_grad import test_fused_bn_update_g
from tests.operators.gpu.test_fused_mul_div_rsqrt_mul_isfinite_red import test_fused_mul_div_rsqrt_mul_isfinite_red


def add(poly_sch, fuzz_shape=None):
def add(poly_sch, fuzz_shape=None, mind_trick_str=''):
if fuzz_shape:
test_ms_add(fuzz_shape, fuzz_shape, 'float32', poly_sch=poly_sch)
return
@@ -74,12 +74,12 @@ def add(poly_sch, fuzz_shape=None):
'float32', poly_sch=poly_sch)


def addn(poly_sch, fuzz_shape=None):
def addn(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_addn((1, 1024, 1024), "float32", 2, poly_sch=poly_sch)
test_ms_addn((1, 1024, 1024), "float16", 2, poly_sch=poly_sch)


def bmm(poly_sch, fuzz_shape=None):
def bmm(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_bmm((768, 768), (768, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT',
shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch,
dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 6", bind_thread="32 4")
@@ -158,19 +158,19 @@ def bmm(poly_sch, fuzz_shape=None):
dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 32", bind_thread="32 4")
"""

def cast(poly_sch, fuzz_shape=None):
def cast(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_cast((32, 32, 14, 14, 16), "float16", "float32", poly_sch=poly_sch)
test_ms_cast((32, 32, 14, 14, 16), "float32", "float16", poly_sch=poly_sch)


def exp(poly_sch, fuzz_shape=None):
def exp(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_exp((1024, 4096), 'float32', poly_sch=poly_sch)
test_ms_exp((1024, 4096), 'float16', poly_sch=poly_sch)
test_ms_exp((1024, 4095), 'float16', poly_sch=poly_sch)
test_ms_exp((1024, 799), 'float16', poly_sch=poly_sch)


def maximum(poly_sch, fuzz_shape=None):
def maximum(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_maximum((32, 1024, 1024), (32, 1024, 1024),
'float32', poly_sch=poly_sch)
test_ms_maximum((32, 1024, 1024), (1, 1024, 1024),
@@ -179,7 +179,7 @@ def maximum(poly_sch, fuzz_shape=None):
'float16', poly_sch=poly_sch)


def minimum(poly_sch, fuzz_shape=None):
def minimum(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_minimum((32, 1024, 1024), (32, 1024, 1024),
'float32', poly_sch=poly_sch)
test_ms_minimum((32, 1024, 1024), (1, 1024, 1024),
@@ -188,33 +188,33 @@ def minimum(poly_sch, fuzz_shape=None):
'float16', poly_sch=poly_sch)


def mul(poly_sch, fuzz_shape=None):
def mul(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_mul((1024, 4096), 'float32', poly_sch=poly_sch)


def divide(poly_sch, fuzz_shape=None):
def divide(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_divide((1024, 1024), 'float32', poly_sch=poly_sch)
test_ms_divide((1024, 1024), 'float16', poly_sch=poly_sch)


def reshape(poly_sch, fuzz_shape=None):
def reshape(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_reshape("float32", (64, 128, 1024),
(8192, 1024), poly_sch=poly_sch)
test_ms_reshape("float16", (64, 128, 1024),
(8192, 1024), poly_sch=poly_sch)


def rsqrt(poly_sch, fuzz_shape=None):
def rsqrt(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_rsqrt((32, 1024, 1024), 'float32', poly_sch=poly_sch)
test_ms_rsqrt((32, 1024, 1024), 'float16', poly_sch=poly_sch)


def sqrt(poly_sch, fuzz_shape=None):
def sqrt(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_sqrt((1024, 1024), "float32", poly_sch=poly_sch)
test_ms_sqrt((1024, 1024), "float16", poly_sch=poly_sch)


def sub(poly_sch, fuzz_shape=None):
def sub(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_sub((32, 1024, 1024), (32, 1024, 1024),
'float32', poly_sch=poly_sch)
test_ms_sub((32, 1024, 1024), (32, 1024, 1024),
@@ -224,36 +224,36 @@ def sub(poly_sch, fuzz_shape=None):
test_ms_sub((4, 4, 4), (1, 4, 4), 'float32', poly_sch=poly_sch)


def tile(poly_sch, fuzz_shape=None):
def tile(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_tile((1024, 4096), (3,), 'float32', poly_sch=poly_sch)
test_ms_tile((1024, 4096), (3,), 'float16', poly_sch=poly_sch)


def one_hot(poly_sch, fuzz_shape=None):
def one_hot(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_one_hot((1024,), 16, "int32", 1, 0, 0, poly_sch=poly_sch)
test_ms_one_hot((1024,), 16, "float32", 1, 0, 0, poly_sch=poly_sch)
test_ms_one_hot((32,), 16, "int32", 1, 0, 0, poly_sch=poly_sch)
test_ms_one_hot((32,), 16, "float32", 1, 0, 0, poly_sch=poly_sch)


def expand_dims(poly_sch, fuzz_shape=None):
def expand_dims(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_expand_dims((32, 1024, 1024), 1, 'float32', poly_sch=poly_sch)
test_expand_dims((32, 1024, 1024), 2, 'float16', poly_sch=poly_sch)


def trans_data(poly_sch, fuzz_shape=None):
def trans_data(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_trans_data((8, 24, 38, 38), (0, 2, 1, 3),
'float32', poly_sch=poly_sch)
test_ms_trans_data((8, 24, 38, 38), (0, 2, 1, 3),
'float16', poly_sch=poly_sch)


def log(poly_sch, fuzz_shape=None):
def log(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_log((9, 1024, 1024), 'float16', poly_sch=poly_sch)
test_ms_log((9, 1024, 1024), 'float32', poly_sch=poly_sch)


def pow(poly_sch, fuzz_shape=None):
def pow(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_pow((9, 1024, 1024), (9, 1024, 1024), 'float32', poly_sch=poly_sch)
test_ms_pow((9, 1024, 1024), (9, 1024, 1), 'float32', poly_sch=poly_sch)
test_ms_pow((9, 1024, 1024), (9, 1, 1), 'float32', poly_sch=poly_sch)
@@ -264,7 +264,7 @@ def pow(poly_sch, fuzz_shape=None):
test_ms_pow((9, 1024, 1024), (1, 1, 1), 'float16', poly_sch=poly_sch)


def abs(poly_sch, fuzz_shape=None):
def abs(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_abs((1024, 1024), "float32", poly_sch=poly_sch)
test_ms_abs((1024, 1024), "float16", poly_sch=poly_sch)
test_ms_abs((1, ), "float32", poly_sch=poly_sch)
@@ -273,7 +273,7 @@ def abs(poly_sch, fuzz_shape=None):
test_ms_abs((1, 1), "float16", poly_sch=poly_sch)


def neg(poly_sch, fuzz_shape=None):
def neg(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_neg((1024, 1024), "float32", poly_sch=poly_sch)
test_ms_neg((1024, 1024), "float16", poly_sch=poly_sch)
test_ms_neg((1, ), "float32", poly_sch=poly_sch)
@@ -282,7 +282,7 @@ def neg(poly_sch, fuzz_shape=None):
test_ms_neg((1, 1), "float16", poly_sch=poly_sch)


def round(poly_sch, fuzz_shape=None):
def round(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_round((1024, 1024), "float32", poly_sch=poly_sch)
test_ms_round((1024, 1024), "float16", poly_sch=poly_sch)
test_ms_round((1, ), "float32", poly_sch=poly_sch)
@@ -291,7 +291,7 @@ def round(poly_sch, fuzz_shape=None):
test_ms_round((1, 1), "float16", poly_sch=poly_sch)


def reduce_sum(poly_sch, fuzz_shape=None):
def reduce_sum(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_reduce_sum((256, 256), 'float32', axis=(1,),
keepdims=True, poly_sch=poly_sch)
test_ms_reduce_sum((9, 1024, 1024), 'float32', axis=None,
@@ -304,31 +304,31 @@ def reduce_sum(poly_sch, fuzz_shape=None):
keepdims=True, poly_sch=poly_sch)


def select(poly_sch, fuzz_shape=None):
def select(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_select((2, ), (2, 2, 2), "int8", "float16", poly_sch=poly_sch)


def equal(poly_sch, fuzz_shape=None):
def equal(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_equal(((1, 1024), (1, 1024)), 'float16', poly_sch=poly_sch)
test_ms_equal(((1, 1024), (1, 1024)), 'float32', poly_sch=poly_sch)


def less_equal(poly_sch, fuzz_shape=None):
def less_equal(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_less_equal((1, 1024), (1, 1024), 'float16', poly_sch=poly_sch)
test_ms_less_equal((1, 1024), (1, 1024), 'float32', poly_sch=poly_sch)


def greater_equal(poly_sch, fuzz_shape=None):
def greater_equal(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_greater_equal((1, 1024), (1, 1024), 'float16', poly_sch=poly_sch)
test_ms_greater_equal((1, 1024), (1, 1024), 'float32', poly_sch=poly_sch)


def reciprocal(poly_sch, fuzz_shape=None):
def reciprocal(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_reciprocal((1, 1024), 'float16', poly_sch=poly_sch)
test_ms_reciprocal((1, 1024), 'float32', poly_sch=poly_sch)


def reduce_min(poly_sch, fuzz_shape=None):
def reduce_min(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_reduce_min((9, 1024, 1024), 'float32', axis=None,
keepdims=False, poly_sch=poly_sch)
test_ms_reduce_min((9, 1024, 1024), 'float16', axis=None,
@@ -339,7 +339,7 @@ def reduce_min(poly_sch, fuzz_shape=None):
keepdims=False, poly_sch=poly_sch)


def reduce_max(poly_sch, fuzz_shape=None):
def reduce_max(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_ms_reduce_max((9, 1024, 1024), 'float32', axis=None,
keepdims=False, poly_sch=poly_sch)
test_ms_reduce_max((9, 1024, 1024), 'float16', axis=None,
@@ -350,77 +350,77 @@ def reduce_max(poly_sch, fuzz_shape=None):
keepdims=False, poly_sch=poly_sch)


def fused_pad(poly_sch, fuzz_shape=None):
def fused_pad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_pad((7, 7, 3, 64), (0, 0, 0, 0), (0, 0, 1, 0),
layout='NHWC', pad_value=0.0, poly_sch=poly_sch)


def fused_bn_reduce(poly_sch, fuzz_shape=None):
def fused_bn_reduce(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_reduce((256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch)


def fused_bn_update(poly_sch, fuzz_shape=None):
test_fused_bn_update((2048,), poly_sch=poly_sch)
def fused_bn_update(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_update((2048,), poly_sch=poly_sch, mind_trick=mind_trick_str)


def fused_bn_follow_relu(poly_sch, fuzz_shape=None):
def fused_bn_follow_relu(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_follow_relu(
(256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch)


def fused_bn_follow_relu_avgpool(poly_sch, fuzz_shape=None):
def fused_bn_follow_relu_avgpool(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_follow_relu_avgpool(
(256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch)


def fused_bn_double_follow_relu(poly_sch, fuzz_shape=None):
def fused_bn_double_follow_relu(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_double_follow_relu(
(256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch)


def fused_bn_reduce_grad(poly_sch, fuzz_shape=None):
def fused_bn_reduce_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_reduce_grad(
(256, 56, 56, 256), layout='NHWC', poly_sch=poly_sch)


def fused_relu_grad_bn_reduce_grad(poly_sch, fuzz_shape=None):
def fused_relu_grad_bn_reduce_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_relu_grad_bn_reduce_grad(
(64, ), (256, 112, 112, 64), layout='NHWC', poly_sch=poly_sch)


def fused_relu_grad_bn_double_reduce_grad(poly_sch, fuzz_shape=None):
def fused_relu_grad_bn_double_reduce_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_relu_grad_bn_double_reduce_grad(
(256,), (256, 56, 56, 256), layout="NHWC", poly_sch=poly_sch)


def fused_l2loss_grad(poly_sch, fuzz_shape=None):
test_fused_l2loss_grad((1, 1, 256, 1024), layout='NHWC', poly_sch=poly_sch)
def fused_l2loss_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_l2loss_grad((1, 1, 256, 1024), layout='NHWC', poly_sch=poly_sch, mind_trick=mind_trick_str)


def fused_is_finite(poly_sch, fuzz_shape=None):
def fused_is_finite(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_is_finite((1, 1, 256, 1024), layout='NHWC', poly_sch=poly_sch)


def fused_relu_grad_bn_update_grad(poly_sch, fuzz_shape=None):
def fused_relu_grad_bn_update_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_relu_grad_bn_update_grad(
(256, 112, 112, 64), (64,), layout="NHWC", poly_sch=poly_sch)


def fused_relu_grad_bn_double_update_grad(poly_sch, fuzz_shape=None):
def fused_relu_grad_bn_double_update_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_relu_grad_bn_double_update_grad(
(256, 56, 56, 256), (256, ), layout='NHWC', poly_sch=poly_sch)


def fused_relu_grad(poly_sch, fuzz_shape=None):
test_fused_relu_grad((256, 56, 56, 256), poly_sch=poly_sch)
def fused_relu_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_relu_grad((256, 56, 56, 256), poly_sch=poly_sch, mind_trick=mind_trick_str)


def fused_bn_update_grad(poly_sch, fuzz_shape=None):
def fused_bn_update_grad(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_bn_update_grad(
(256, 56, 56, 256), (256,), layout="NHWC", poly_sch=poly_sch)


def fused_mul_div_rsqrt_mul_isfinite_red(poly_sch, fuzz_shape=None):
def fused_mul_div_rsqrt_mul_isfinite_red(poly_sch, fuzz_shape=None, mind_trick_str=''):
test_fused_mul_div_rsqrt_mul_isfinite_red((64,), poly_sch=poly_sch)


@@ -490,11 +490,17 @@ if __name__ == '__main__':
sys.exit()

options, args = getopt.getopt(
sys.argv[1:], "f:", ["fuzz="])
sys.argv[1:], "f:", ["fuzz=", "mind-trick-string=", "mind-trick-file="])
mind_trick_str = ''
fuzz_dim = 0
for name, value in options:
if name in ("-f", "--fuzz"):
fuzz_dim = int(value)
if name == "--mind-trick-string":
mind_trick_str = value
if name == "--mind-trick-file":
with open(value, 'r') as f:
mind_trick_str = f.read()

fail_op_list = dict()
run_op_list = list()
@@ -520,7 +526,7 @@ if __name__ == '__main__':
print("Fuzz shape: {}".format(fuzz_shape))
try:
print("Time of auto schedule:")
op(poly_sch=True, fuzz_shape=fuzz_shape)
op(poly_sch=True, fuzz_shape=fuzz_shape, mind_trick_str=mind_trick_str)
except:
if op.__name__ in fail_op_list:
fail_op_list[op.__name__].extend(
@@ -535,4 +541,4 @@ if __name__ == '__main__':
for op, error_info in fail_op_list.items():
print("Run op %s error" % op)
for e in error_info:
print(e)
print(e)

+ 6
- 3
tests/operators/gpu/test_fused_bn_update.py View File

@@ -43,14 +43,17 @@ def compute_expect(input, c1, c2, c3, c4):

return [out1, out2, out3]

def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.001e-05, c3=1.00007975, c4=0.100000024, poly_sch=False):
def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.001e-05, c3=1.00007975, c4=0.100000024, poly_sch=False, mind_trick=''):
input = gen_data(shape, dtype)
expect = compute_expect(input, c1, c2, c3, c4)
attrs = [dtype, c1, c2, c3, c4]
shapes = [input[0].shape] * 4
dtypes = [dtype] * 4
if poly_sch:
mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, attrs={"target": "cuda"})
if mind_trick:
mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, attrs={"target": "cuda", "mind_trick": mind_trick})
else:
mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, attrs={"target": "cuda"})

outputs = [np.full(shape, np.nan, dtype)] * 3
attrs_list = input + outputs
@@ -64,4 +67,4 @@ def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.00

data = to_tvm_nd_array(input)
expect = to_tvm_nd_array(expect)
gpu_profiling(mod, *data, *expect, 400)
gpu_profiling(mod, *data, *expect, 400)

+ 6
- 3
tests/operators/gpu/test_fused_l2loss_grad.py View File

@@ -36,7 +36,7 @@ def compute_py(data_f16, data_f32, layout, fill_data):
output = np.full(np.shape(expect), np.nan, 'float32')
return expect, output
def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False):
def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False, mind_trick=''):
data_1 = gen_data(shape, 'float16')
data_2 = gen_data(shape, 'float32')
@@ -45,7 +45,10 @@ def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False):
dtype_list = ['float16', 'float32']
op_attrs = [layout, fill_data]
if poly_sch:
mod = utils.op_build_test(fused_l2loss_grad, input_list, dtype_list, kernel_name="fused_l2loss_grad", op_attrs=op_attrs, attrs={"target": "cuda"})
if mind_trick:
mod = utils.op_build_test(fused_l2loss_grad, input_list, dtype_list, kernel_name="fused_l2loss_grad", op_attrs=op_attrs, attrs={"target": "cuda", "mind_trick": mind_trick})
else:
mod = utils.op_build_test(fused_l2loss_grad, input_list, dtype_list, kernel_name="fused_l2loss_grad", op_attrs=op_attrs, attrs={"target": "cuda"})
args = [data_1, data_2, output]
output = utils.mod_launch(mod, args, expect = expect)
@@ -58,4 +61,4 @@ def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False):
data = to_tvm_nd_array([data_1, data_2])
expect = to_tvm_nd_array(expect)
gpu_profiling(mod, *data, expect, 400)
gpu_profiling(mod, *data, expect, 400)

+ 5
- 2
tests/operators/gpu/test_fused_relu_grad.py View File

@@ -38,7 +38,7 @@ def compute_expect(input, c1):

return np.where(cmp_zero, data_add, data_zero)

def test_fused_relu_grad(shape, c1=0, poly_sch=False):
def test_fused_relu_grad(shape, c1=0, poly_sch=False, mind_trick=''):
dtype='float16'
input = gen_data(shape, dtype)
expect = compute_expect(input, c1)
@@ -46,7 +46,10 @@ def test_fused_relu_grad(shape, c1=0, poly_sch=False):
dtypes = [dtype] * 3
attrs = [c1]
if poly_sch:
mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", attrs={"target": "cuda"})
if mind_trick:
mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", attrs={"target": "cuda", "mind_trick": mind_trick})
else:
mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", attrs={"target": "cuda"})

output = np.full(shape, np.nan, dtype)
output = utils.mod_launch(mod, (*input, output), expect=expect)


+ 14
- 0
tests/st/composite/test_composite_json.py View File

@@ -54,6 +54,12 @@ def print_usage():
logging.info(template.format("", "Set attribute of 'bind_block' when use '-f' command."))
logging.info(template.format("--bind_thread=<args>", ""))
logging.info(template.format("", "Set attribute of 'bind_thread' when use '-f' command."))
logging.info(template.format("--mind-trick-enable=<0|1>", ""))
logging.info(template.format("", "explicitly enable (--mind-trick-enable=1) or disable (--mind-trick-enable=0) mind tricks"))
logging.info(template.format("--mind-trick-file", ""))
logging.info(template.format("", "json mind trick file"))
logging.info(template.format("--mind-trick-string", ""))
logging.info(template.format("", "json mind-trick string"))
logging.info("\n")


@@ -212,6 +218,7 @@ def main(argv):
try:
options, args = getopt.getopt(argv, "adcf:mh", ["auto", "manual", "ci", "profile",
"enable_atomic_add=", "dim=", "bind_block=", "bind_thread=",
"mind-trick-enable=", "mind-trick-file=", "mind-trick-string=",
"help"])
poly = True
single_file = False
@@ -247,6 +254,13 @@ def main(argv):
attrs_list["bind_block"] = value
elif option == "--bind_thread":
attrs_list["bind_thread"] = value
elif option == "--mind-trick-enable":
attrs_list['enable_mind_trick'] = int(value)
elif option == "--mind-trick-file":
with open(value, 'r') as f:
attrs_list['mind_trick'] = f.read()
elif option == "--mind-trick-string":
attrs_list['mind_trick'] = value
except:
print_usage()
return


+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_AddN_fusion_9584919353229493170.info View File

@@ -0,0 +1 @@
{"composite":true,"composite_graph":"11380.22873","id":1054,"input_desc":[[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_0"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_1"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_2"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_3"}]],"op":"Fused_AddN_fusion_9584919353229493170","op_desc":[{"attr":[{"data_type":"listInt","name":"dyn_input_sizes","value":[4]},{"data_type":"int","name":"n","value":4}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_0"},{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_1"},{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_2"},{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_3"}]],"name":"AddN","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_0"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"output_0_0"}],"platform":"AKG","process":"cuda"}

+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806.info
File diff suppressed because it is too large
View File


+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231.info View File

@@ -0,0 +1 @@
{"composite":true,"composite_graph":"11380.21701","id":929,"input_desc":[[{"data_type":"float32","format":"DefaultFormat","shape":[768],"tensor_name":"input_4"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_5"}],[{"data_type":"float32","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_0"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_11"}]],"op":"Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231","op_desc":[{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_0"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_0"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"output_0_0"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[1],"tensor_name":"input_2","value":0.89990234375}]],"name":"LessEqual","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_1"}]},{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"output_0_1"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_2"}]},{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","na
me":"input_0","shape":[768],"tensor_name":"input_4"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[768],"tensor_name":"output_0_3"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_5"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[768],"tensor_name":"output_0_3"}]],"name":"TensorAdd","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_4"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[1],"tensor_name":"input_7","value":1.111328125}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[4096,768],"tensor_name":"output_0_4"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_5"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"output_0_5"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[4096,768],"tensor_name":"output_0_2"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_6"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_11"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[4096,768],"tensor_name":"output_0_6"}]],"name":"TensorAdd","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_7"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"output_0_2"},{"data_type":"float16","format":"Defa
ultFormat","shape":[4096,768],"tensor_name":"output_0_7"}],"platform":"AKG","process":"cuda"}

+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692.info
File diff suppressed because it is too large
View File


+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484.info
File diff suppressed because it is too large
View File


+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880.info
File diff suppressed because it is too large
View File


+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_GkDropout_2353362030752466006.info View File

@@ -0,0 +1 @@
{"composite":true,"composite_graph":"akg_decode_3954_20083.20131","id":925,"input_desc":[[{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_5"}],[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_0"}]],"op":"Fused_GkDropout_2353362030752466006","op_desc":[{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"input_0"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_0"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"output_0_0"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[1],"tensor_name":"input_2","value":0.89990234375}]],"name":"LessEqual","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_1"}]},{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"output_0_1"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_2"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[1],"tensor_name":"input_4","value":1.111328125}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"input_5"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","sha
pe":[32,12,128,128],"tensor_name":"output_0_3"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"output_0_3"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"output_0_2"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_4"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"output_0_4"},{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"output_0_2"}],"platform":"AKG","process":"cuda"}

+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Sub_Mul_Mul_split_9258187064108063363.info View File

@@ -0,0 +1 @@
{"composite":true,"composite_graph":"25320.25320","id":696,"input_desc":[[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,1],"tensor_name":"input_1"}],[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_2"}],[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_0"}]],"op":"Fused_Sub_Mul_Mul_split_9258187064108063363","op_desc":[{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"input_0"}],[{"data_type":"float32","format":"DefaultFormat","name":"input_1","shape":[32,12,128,1],"tensor_name":"input_1"}]],"name":"Sub","output_desc":[{"data_type":"float32","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_0"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"input_2"}],[{"data_type":"float32","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"output_0_0"}]],"name":"Mul","output_desc":[{"data_type":"float32","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_1"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[1],"tensor_name":"input_4","value":0.125}],[{"data_type":"float32","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"output_0_1"}]],"name":"Mul","output_desc":[{"data_type":"float32","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_2"}]}],"output_desc":[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"output_0_2"}],"platform":"AKG","process":"cuda"}

+ 1
- 0
tests/st/ops/gpu/mind-trick_cases/operators/Fused_Transpose_split_18185609042134105765.info View File

@@ -0,0 +1 @@
{"composite":true,"composite_graph":"27184.27184","id":926,"input_desc":[[{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,64],"tensor_name":"input_0"}]],"op":"Fused_Transpose_split_18185609042134105765","op_desc":[{"attr":[{"data_type":"listInt","name":"perm","value":[0,2,1,3]}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,64],"tensor_name":"input_0"}]],"name":"Transpose","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,128,12,64],"tensor_name":"output_0_0"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[32,128,12,64],"tensor_name":"output_0_0"}],"platform":"AKG","process":"cuda"}

+ 20
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_AddN_fusion_9584919353229493170.json View File

@@ -0,0 +1,20 @@
{
"name": "Fused_AddN_fusion_9584919353229493170_0",
"operator": "Fused_AddN_fusion_9584919353229493170_0",
"domain": "{ S_0[cc0] : 0 <= cc0 <= 3145727 }",
"schedule": [
"{ S_0[cc0] -> [cc0/4096] }",
"{ S_0[cc0] -> [(cc0 mod 4096)/4] }",
"{ S_0[cc0] -> [(cc0 mod 4096) mod 4] }",
"{ S_0[cc0] -> [0] }"
],
"gpu": {
"blocks": [0],
"threads": [1],
"compiler flags": ["--use_fast_math"]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 22
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806.json View File

@@ -0,0 +1,22 @@
{
"name": "Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806_0",
"operator": "Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806_0",
"domain": [
"{ S_0[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 12582911 }",
"{ S_1[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 12582911 }"
],
"schedule": [
"{ S_0[ax0_ax1_fused] -> [ax0_ax1_fused/3072]; S_1[ax0_ax1_fused] -> [ax0_ax1_fused/3072] }",
"{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072)/4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072)/4] }",
"{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072) mod 4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072) mod 4] }",
"{ S_0[ax0_ax1_fused] -> [0]; S_1[ax0_ax1_fused] -> [1] }"
],
"gpu": {
"blocks": [0],
"threads": [1]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 22
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231.json View File

@@ -0,0 +1,22 @@
{
"name": "Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231_0",
"operator": "Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231_0",
"domain": [
"{ S_0[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 3145727 }",
"{ S_1[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 3145727 }"
],
"schedule": [
"{ S_0[ax0_ax1_fused] -> [ax0_ax1_fused/768]; S_1[ax0_ax1_fused] -> [ax0_ax1_fused/768] }",
"{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768)/4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768)/4] }",
"{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768) mod 4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768) mod 4] }",
"{ S_0[ax0_ax1_fused] -> [0]; S_1[ax0_ax1_fused] -> [1] }"
],
"gpu": {
"blocks": [0],
"threads": [1]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 57
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692.json View File

@@ -0,0 +1,57 @@
{
"name": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692_0",
"operator": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692_0",
"domain": [
"{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }",
"{ S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }",
"{ S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }"
],
"pattern": "{ domain: \"{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071; S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071; S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }\", child: { sequence: [ { filter: \"{ S_0[ax0, ax1] }\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_1[ax0, ax1] }\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_2[ax0, ax1] }\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax1)] }]\" } } } ] } }",
"schedule": [
"{ S_0[ax0, ax1] -> [ax0]; S_1[ax0, ax1] -> [ax0]; S_2[ax0, ax1] -> [ax0] }",
"{ S_0[ax0, ax1] -> [ax1/4]; S_1[ax0, ax1] -> [ax1/4]; S_2[ax0, ax1] -> [ax1/4] }",
"{ S_0[ax0, ax1] -> [ax1 mod 4]; S_1[ax0, ax1] -> [ax1 mod 4]; S_2[ax0, ax1] -> [ax1 mod 4] }",
"{ S_0[ax0, ax1] -> [0]; S_1[ax0, ax1] -> [1]; S_2[ax0, ax1] -> [2] }"
],
"soft constraints": [
{
"statement": "S_0",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 0]"
]
},
{
"statement": "S_1",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 1]"
]
},
{
"statement": "S_2",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 2]"
]
}
],
"gpu": {
"blocks": [0],
"threads": [1],
"compiler flags": ["--use_fast_math"]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 57
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484.json View File

@@ -0,0 +1,57 @@
{
"name": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484_0",
"operator": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484_0",
"domain": [
"{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }",
"{ S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }",
"{ S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }"
],
"pattern": "{ domain: \"{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767; S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767; S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }\", child: { sequence: [ { filter: \"{ S_0[ax0, ax1] }\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_1[ax0, ax1] }\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_2[ax0, ax1] }\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax1)] }]\" } } } ] } }",
"schedule": [
"{ S_0[ax0, ax1] -> [ax0]; S_1[ax0, ax1] -> [ax0]; S_2[ax0, ax1] -> [ax0] }",
"{ S_0[ax0, ax1] -> [ax1/4]; S_1[ax0, ax1] -> [ax1/4]; S_2[ax0, ax1] -> [ax1/4] }",
"{ S_0[ax0, ax1] -> [ax1 mod 4]; S_1[ax0, ax1] -> [ax1 mod 4]; S_2[ax0, ax1] -> [ax1 mod 4] }",
"{ S_0[ax0, ax1] -> [0]; S_1[ax0, ax1] -> [1]; S_2[ax0, ax1] -> [2] }"
],
"soft constraints": [
{
"statement": "S_0",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 0]"
]
},
{
"statement": "S_1",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 1]"
]
},
{
"statement": "S_2",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 2]"
]
}
],
"gpu": {
"blocks": [0],
"threads": [1],
"compiler flags": ["--use_fast_math"]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 57
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880.json View File

@@ -0,0 +1,57 @@
{
"name": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880_0",
"operator": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880_0",
"domain": [
"{ S_0[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }",
"{ S_1[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }",
"{ S_2[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }"
],
"pattern": "{ domain: \"{ S_0[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767; S_2[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767; S_1[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }\", child: { sequence: [ { filter: \"{ S_0[ax0, ax1] }\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_1[ax0, ax1] }\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_2[ax0, ax1] }\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax1)] }]\" } } } ] } }",
"schedule": [
"{ S_0[ax0, ax1] -> [ax0]; S_1[ax0, ax1] -> [ax0]; S_2[ax0, ax1] -> [ax0] }",
"{ S_0[ax0, ax1] -> [ax1/4]; S_1[ax0, ax1] -> [ax1/4]; S_2[ax0, ax1] -> [ax1/4] }",
"{ S_0[ax0, ax1] -> [ax1 mod 4]; S_1[ax0, ax1] -> [ax1 mod 4]; S_2[ax0, ax1] -> [ax1 mod 4] }",
"{ S_0[ax0, ax1] -> [0]; S_1[ax0, ax1] -> [1]; S_2[ax0, ax1] -> [2] }"
],
"soft constraints": [
{
"statement": "S_0",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 0]"
]
},
{
"statement": "S_1",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 1]"
]
},
{
"statement": "S_2",
"meta": [2, 0],
"coefficients": [
"[?1, 0, ?3]",
"[?1, 1, ?3] (/4)",
"[?1, 1, ?3] (%4)",
"[0, 0, 2]"
]
}
],
"gpu": {
"blocks": [0],
"threads": [1],
"compiler flags": ["--use_fast_math"]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 19
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_GkDropout_2353362030752466006.json View File

@@ -0,0 +1,19 @@
{
"name": "Fused_GkDropout_2353362030752466006_0",
"operator": "Fused_GkDropout_2353362030752466006_0",
"domain": "{ S_0[cc0] : 0 <= cc0 <= 6291455; S_1[cc0] : 0 <= cc0 <= 6291455 }",
"schedule": [
"{ S_0[cc0] -> [cc0/3072]; S_1[cc0] -> [cc0/3072] }",
"{ S_0[cc0] -> [(cc0 mod 3072)/4]; S_1[cc0] -> [(cc0 mod 3072)/4] }",
"{ S_0[cc0] -> [(cc0 mod 3072) mod 4]; S_1[cc0] -> [(cc0 mod 3072) mod 4] }",
"{ S_0[cc0] -> [0]; S_1[cc0] -> [1] }"
],
"gpu": {
"blocks": [0],
"threads": [1]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 37
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/Fused_Transpose_split_18185609042134105765.json View File

@@ -0,0 +1,37 @@
{
"name": "Fused_Transpose_split_18185609042134105765_0",
"operator": "Fused_Transpose_split_18185609042134105765_0",
"domain": "{ S_0[ax0, ax1, ax2, ax3] : 0 <= ax0 <= 31 and 0 <= ax1 <= 127 and 0 <= ax2 <= 11 and 0 <= ax3 <= 63 }",
"pattern": "{ domain: \"{ S_0[ax0, ax1, ax2, ax3] : 0 <= ax0 <= 31 and 0 <= ax1 <= 127 and 0 <= ax2 <= 11 and 0 <= ax3 <= 63 }\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax1)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax2)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax3)] }]\" } } } } }",
"schedule": [
"{ S_0[ax0, ax1, ax2, ax3] -> [ax0] }",
"{ S_0[ax0, ax1, ax2, ax3] -> [ax1] }",
"{ S_0[ax0, ax1, ax2, ax3] -> [ax2] }",
"{ S_0[ax0, ax1, ax2, ax3] -> [ax3/4] }",
"{ S_0[ax0, ax1, ax2, ax3] -> [ax3 mod 4] }",
"{ S_0[ax0, ax1, ax2, ax3] -> [0] }"
],
"soft constraints": [
{
"statement": "S_0",
"meta": [4, 0],
"coefficients": [
"[?, ?, ?, 0, ?]",
"[?, ?, ?, 0, ?]",
"[?, ?, ?, 0, ?]",
"[?, ?, ?, 1, ?] (/4)",
"[?, ?, ?, 1, ?] (%4)",
"[0, 0, 0, 0, 0]"
]
}
],
"gpu": {
"blocks": [0, 1],
"threads": [2, 3],
"compiler flags": ["--use_fast_math"]
},
"disable": [
"GpuDmaAnalysis",
"TileOuterBand"
]
}

+ 7
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/fused_bn_update.json View File

@@ -0,0 +1,7 @@
{
"name": "fused_bn_update",
"operator": "fused_bn_update_auto_float32_2048_float32_2048_float32_2048_float32_2048_float32_7_971938775510203e_05_1_001e_05_1_00007975_0_100000024",
"pattern": "{ domain: \"{ S_2[ax0] : 0 <= ax0 <= 2047; S_0[i0] : 0 <= i0 <= 2047; S_1[ax0] : 0 <= ax0 <= 2047 }\", child: { sequence: [ { filter: \"{ S_0[i0] }\", child: { schedule: \"[{ S_0[i0] -> [(i0)] }]\" } }, { filter: \"{ S_1[ax0] }\", child: { schedule: \"[{ S_1[ax0] -> [(ax0)] }]\"} }, { filter: \"{ S_2[ax0] }\", child: { schedule: \"[{ S_2[ax0] -> [(ax0)] }]\" } } ] } }",
"tree": "{ \"domain\": \"{ S_2[ax0] : 0 <= ax0 <= 2047; S_0[i0] : 0 <= i0 <= 2047; S_1[ax0] : 0 <= ax0 <= 2047 }\", \"child\": { \"context\": \"[b0, t0] -> {[]: 0 <= b0 < 2 and 0 <= t0 < 256 }\", \"child\": { \"mark\": \"block_marker\", \"child\": { \"filter\": \"[b0] -> { S_0[i]: b0 = floor(i / 1024); S_1[i]: b0 = floor(i / 1024); S_2[i]: b0 = floor(i / 1024) }\", \"child\": { \"schedule\": \"[{ S_0[i] -> [i/1024]; S_1[i] -> [i/1024]; S_2[i] -> [i/1024] }]\", \"child\": { \"mark\": \"thread_marker\", \"child\": { \"filter\": \"[t0] -> { S_0[i]: t0 = floor((i / 4) mod 256); S_1[i]: t0 = floor((i / 4) mod 256); S_2[i]: t0 = floor((i / 4) mod 256) }\", \"child\": { \"schedule\": \"[{ S_0[i] -> [(i/4) mod 256]; S_1[i] -> [(i/4) mod 256]; S_2[i] -> [(i/4) mod 256] }]\", \"child\": { \"schedule\": \"[{S_0[i] -> [i mod 4]; S_1[i] -> [i mod 4]; S_2[i] -> [i mod 4] }]\", \"child\": { \"schedule\": \"[{ S_0[i] -> [2]; S_1[i] -> [1]; S_2[i] -> [0] }]\" } } } } } } } } } }",
"disable": [ "GpuDmaAnalysis", "TileOuterBand", "MappingOuterBand" ]
}

+ 34
- 0
tests/st/ops/gpu/mind-trick_cases/tricks/fused_relu_grad.json View File

@@ -0,0 +1,34 @@
{
"name": "fused_relu_grad",
"operator": "fused_relu_grad_auto_float16_256_56_56_256_float16_256_56_56_256_float16_256_56_56_256_0",
"pattern": "{ domain: \"{ S_0[ax0, ax1, ax2, ax3] : 0 <= ax0 <= 255 and 0 <= ax1 <= 55 and 0 <= ax2 <= 55 and 0 <= ax3 <= 255 }\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax1)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax2)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax3)] }]\" } } } } }",
"disable": [ "GpuDmaAnalysis", "TileOuterBand" ],
"gpu": {
"blocks": [0],
"threads": [1]
},
"domain": "{ S_0[i0, i1, i2, i3] : 0 <= i0 <= 255 and 0 <= i1 <= 55 and 0 <= i2 <= 55 and 0 <= i3 <= 255 }",
"schedule": [
"{ S_0[i0, i1, i2, i3] -> [(i0 * 56 + i1) * 14 + i2 / 4] }",
"{ S_0[i0, i1, i2, i3] -> [(i2 mod 4) * 64 + (i3 / 4)] }",
"{ S_0[i0, i1, i2, i3] -> [i3 mod 4] }"
],
"soft constraints": [
{
"statement": "S_0",
"meta": [4, 0],
"coefficients": [
"[1, 0, 0, 0, 0]",
"[0, 1, 0, 0, 0]",
"[0, 0, 1, 0, 0] (/4)",
"[0, 0, 1, 0, 0] (%4)",
"[0, 0, 0, 1, 0] (/4)",
"[0, 0, 0, 1, 0] (%4)"
]
}
],
"post": [
{ "collapse": [3, 4] },
{ "collapse": [0, 2] }
]
}

+ 9
- 0
tests/st/ops/gpu/test_all.py View File

@@ -62,6 +62,7 @@ from tests.st.ops.gpu.test_fused_relu_grad import test_fused_relu_grad
from tests.st.ops.gpu.test_fused_bn_update_grad import test_fused_bn_update_grad
from tests.st.ops.gpu.test_fused_mul_div_rsqrt_mul_isfinite_red import test_fused_mul_div_rsqrt_mul_isfinite_red
from tests.st.ops.gpu.test_ms_composite_stitch import test_composite_stitch
from tests.st.ops.gpu.test_ms_mindtricks import test_mindtricks


@pytest.mark.level0
@@ -505,6 +506,14 @@ def test_ms_composite_buffer_stitch():
return True


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ms_mindtricks():
    # Smoke-test the MindTrick injection pipeline: runs both the hand-written
    # GPU cases and the composite-operator cases via test_mindtricks().
    test_mindtricks()
    return True


class Logger(object):
def __init__(self, filename, stream):
self.terminal = stream


+ 5
- 2
tests/st/ops/gpu/test_fused_bn_update.py View File

@@ -44,15 +44,18 @@ def compute_expect(input, c1, c2, c3, c4):


def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.001e-05, c3=1.00007975, c4=0.100000024,
poly_sch=False):
poly_sch=False, mind_trick=''):
input = gen_data(shape, dtype)
expect = compute_expect(input, c1, c2, c3, c4)
attrs = [dtype, c1, c2, c3, c4]
shapes = [input[0].shape] * 4
dtypes = [dtype] * 4
if poly_sch:
mod_attrs = { "target": "cuda" }
if mind_trick:
mod_attrs["mind_trick"] = mind_trick
mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs,
attrs={"target": "cuda"})
attrs=mod_attrs)

outputs = [np.full(shape, np.nan, dtype)] * 3
attrs_list = input + outputs


+ 5
- 2
tests/st/ops/gpu/test_fused_relu_grad.py View File

@@ -38,7 +38,7 @@ def compute_expect(input, c1):
return np.where(cmp_zero, data_add, data_zero)


def test_fused_relu_grad(shape, c1=0, poly_sch=False):
def test_fused_relu_grad(shape, c1=0, poly_sch=False, mind_trick=''):
dtype = 'float16'
input = gen_data(shape, dtype)
expect = compute_expect(input, c1)
@@ -46,8 +46,11 @@ def test_fused_relu_grad(shape, c1=0, poly_sch=False):
dtypes = [dtype] * 3
attrs = [c1]
if poly_sch:
mod_attrs = { "target": "cuda" }
if mind_trick:
mod_attrs["mind_trick"] = mind_trick
mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad",
attrs={"target": "cuda"})
attrs=mod_attrs)

output = np.full(shape, np.nan, dtype)
output = utils.mod_launch(mod, (*input, output), expect=expect)


+ 159
- 0
tests/st/ops/gpu/test_ms_mindtricks.py View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python3
# coding: utf-8
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License


########################################################################################################################

# Basic libraries
import os
import sys
import json
import pytest
import logging

# For composite cases
from akg import composite
from akg.utils import kernel_exec as utils
from akg.utils.result_analysis import gpu_profiling
from akg.utils.format_transform import to_tvm_nd_array
from tests.common.gen_json_data import gen_json_data
from tests.common.base import get_rtol_atol
from tests.common.tensorio import compare_tensor

# Specific GPU tests
from tests.st.ops.gpu.test_fused_bn_update import test_fused_bn_update
from tests.st.ops.gpu.test_fused_relu_grad import test_fused_relu_grad

########################################################################################################################

# Root directory holding the trick JSON files and the composite operator
# descriptions used by these tests (relative to the test working directory).
mind_trick_cases_dir = "./mind-trick_cases"

# Note: no need to hardcode trick paths for composite_operators unless the trick's name differs from the operator name
tricks_dir = mind_trick_cases_dir + "/tricks"
# Maps an operator name to its trick filename inside tricks_dir.
tricks_filenames = {
    "fused_bn_update": "fused_bn_update.json",
    "fused_relu_grad": "fused_relu_grad.json",
}

# Each entry is the basename of one description file (<name>.info or
# <name>.json) under composite_operators_dir.
composite_operators_dir = mind_trick_cases_dir + "/operators"
composite_operators = [
    "Fused_AddN_fusion_9584919353229493170",
    "Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806",
    "Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231",
    "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692",
    "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484",
    "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880",
    "Fused_GkDropout_2353362030752466006",
    "Fused_Transpose_split_18185609042134105765",
]

########################################################################################################################

def _get_json_dict(desc):
return json.loads(desc) if isinstance(desc, str) else desc

def _get_backend(desc):
    """Return the "process" (backend) field of an operator description, or None."""
    parsed = _get_json_dict(desc)
    if "process" in parsed:
        return parsed["process"]
    logging.info("Can't identify the backend.")
    return None

def _compare_func(output, expect):
    """Compare one output tensor against its expectation using FUSED tolerances."""
    rel_tol, abs_tol = get_rtol_atol("FUSED", str(output.dtype))
    return compare_tensor(output, expect, rtol=rel_tol, atol=abs_tol)

def get_result(desc, poly, attrs=None):
    """Build the operator described by desc, run it, and check its outputs.

    Returns False on a precision mismatch (dumping the generated source),
    True otherwise; on the cuda backend the module is also profiled.
    """
    backend = _get_backend(desc)
    if attrs is None:
        attrs = {}

    build_attrs = attrs if attrs else None
    mod = composite.build(desc, build_attrs, poly=poly)

    input_for_mod, expect, output_indexes = gen_json_data(desc)
    output = utils.mod_launch(mod, input_for_mod, output_indexes)

    # Normalize single outputs/expectations to sequences before comparing.
    outputs = output if isinstance(output, (list, tuple)) else [output]
    expects = expect if isinstance(expect, (list, tuple)) else [expect]
    if not all(map(_compare_func, outputs, expects)):
        logging.info(mod.imported_modules[0].get_source())
        return False

    if backend == "cuda":
        inputs = to_tvm_nd_array(input_for_mod)
        expect = to_tvm_nd_array(expect)
        gpu_profiling(mod, *inputs, *expect, repeat_time=400)
    return True

def test_gpu_cases():
    """Run the hand-written GPU cases with their associated mind tricks."""
    def _read_trick(name):
        # Load the trick JSON registered for this operator name.
        with open(tricks_dir + "/" + tricks_filenames[name], "r") as trick:
            return trick.read()

    test_fused_bn_update((2048,), poly_sch=True,
                         mind_trick=_read_trick("fused_bn_update"))
    test_fused_relu_grad((256, 56, 56, 256), poly_sch=True,
                         mind_trick=_read_trick("fused_relu_grad"))

    return True

def test_single_composite_file(input_file, attrs, poly):
    """Build and check one composite operator description file.

    Raises ValueError when the computed outputs do not match expectations.
    """
    with open(input_file, 'r') as f:
        desc = f.read()
    if not get_result(desc, poly, attrs):
        logging.info("Precision Error")
        raise ValueError("Precision Error")
    logging.info("Run Pass!")

def test_composite_cases(operators=composite_operators):
    """Run every composite operator that has both a description and a trick.

    For each operator name, locate its description (<name>.info preferred
    over <name>.json) and its trick JSON; operators missing either file are
    logged and skipped.

    operators: iterable of operator basenames (defaults to the module-level
    composite_operators list, which is never mutated here).
    """
    for operator in operators:
        # Locate the operator description.
        operator_path = ""
        if os.path.isfile(composite_operators_dir + "/" + operator + ".info"):
            operator_path = composite_operators_dir + "/" + operator + ".info"
        elif os.path.isfile(composite_operators_dir + "/" + operator + ".json"):
            operator_path = composite_operators_dir + "/" + operator + ".json"
        else:
            logging.info("could not find desc for operator: " + operator)

        # Locate the trick: explicit mapping first, then <operator>.json.
        trick_path = ""
        if operator in tricks_filenames:
            trick_path = tricks_dir + "/" + tricks_filenames[operator]
        elif os.path.isfile(tricks_dir + "/" + operator + ".json"):
            trick_path = tricks_dir + "/" + operator + ".json"
        else:
            logging.info("could not find trick for operator: " + operator)

        if os.path.isfile(operator_path) and os.path.isfile(trick_path):
            # Use a context manager so the trick file is always closed;
            # the previous code opened it without ever closing the handle.
            with open(trick_path, "r") as trick:
                attrs = {
                    "target": "cuda",
                    "mind_trick": trick.read(),
                }
            test_single_composite_file(operator_path, attrs, poly=True)

    return True

########################################################################################################################

def test_mindtricks(cases=["gpu", "composite"]):
    """Entry point: run the requested categories of mind-trick tests.

    cases may contain "gpu" and/or "composite". The shared default list is
    only read, never mutated, so the mutable default is safe here.
    """
    run_gpu = "gpu" in cases
    run_composite = "composite" in cases
    if run_gpu:
        test_gpu_cases()
    if run_composite:
        test_composite_cases(composite_operators)

    return True

########################################################################################################################

if __name__ == '__main__':
    # Allow running this test file directly, outside of the pytest harness.
    test_mindtricks()

+ 2
- 2
tests/test_env.sh View File

@@ -20,8 +20,8 @@ if [ $# -eq 1 ] && [ $1 = "gpu_ci" ]; then
echo "Argument gpu_ci is used in CI and will be deprecated."
else
CUR_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
AKG_DIR="${CUR_DIR}/.."
AKG_BUILD_DIR="${AKG_DIR}/build"
AKG_DIR="${AKG_DIR:-${CUR_DIR}/..}"
AKG_BUILD_DIR="${AKG_BUILD_DIR:-${AKG_DIR}/build}"
TVM_ROOT="${AKG_DIR}/third_party/incubator-tvm"

export LD_LIBRARY_PATH=${AKG_BUILD_DIR}:${LD_LIBRARY_PATH}


+ 5
- 1
third_party/incubator-tvm/include/tvm/ir.h View File

@@ -583,6 +583,8 @@ class Call : public ExprNode {
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
static constexpr const char* isnan = "isnan";
static constexpr const char* reinterpret_cast_op = "reinterpret_cast_op";
static constexpr const char* ldg = "ldg";

/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
@@ -1107,7 +1109,9 @@ enum class ForType : int {
/*! \brief Vector SIMD loop annotaion. */
Vectorized = 2,
/*! \brief Unroll annotation. */
Unrolled = 3
Unrolled = 3,
/*! \brief Swizzle loop on GPU. */
Swizzled = 4
};

// Device api of for loop


+ 15
- 0
third_party/incubator-tvm/python/tvm/contrib/nvcc.py View File

@@ -92,6 +92,21 @@ def compile_cuda(code,
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]

mind_trick_gpu_flags = ".akg_gpu_compiler_flags_{}".format(os.getpid())
if os.path.isfile(mind_trick_gpu_flags):
with open(mind_trick_gpu_flags, 'r') as f:
# This file should contain one flag per line
for line in f:
# Ensure we only use flags we explicitly allow and avoid arbitrary flags users may try to inject
flag = line.rstrip()
if flag == "-use_fast_math":
cmd += flag
try:
os.remove(mind_trick_gpu_flags)
except:
warnings.warn("Could not delete file {}".format(mind_trick_gpu_flags))

if isinstance(arch, list):
cmd += arch
else:


+ 170
- 2
third_party/incubator-tvm/src/codegen/codegen_cuda.cc View File

@@ -60,6 +60,7 @@
#include <cmath>
#include <vector>
#include <string>
#include <tvm/ir_pass.h>
#include "literal/cuda_half_t.h"
#include "codegen_cuda.h"

@@ -130,6 +131,33 @@ std::string CodeGenCUDA::Finish() {
}
}

// TODO add condition
decl_stream << "// built-in for half swizzle\n"
"#include <cuda_fp16.h>\n"
"struct __device_builtin__ __align__(8) half4 { half x, y, z, w; };\n"
"\n"
"#if defined(__CUDACC_RTC__)\n"
"#define __CUDA_FP16_DECL__ __host__ __device__\n"
"#else\n"
"#define __CUDA_FP16_DECL__ static __device__ __inline__\n"
"#endif\n"
"\n"
"// half4 ldg function support\n"
"#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)\n"
"#define __LDG_PTR \"l\"\n"
"#else\n"
"// not sure about this one, it was copied from the half2 ldg() function\n"
"#define __LDG_PTR \"r\"\n"
"#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/\n"
"\n"
"#define __HALF4_TO_UI(var) *(reinterpret_cast<unsigned long *>(&(var)))\n"
"__CUDA_FP16_DECL__ half4 __ldg(const half4 *ptr)\n"
"{\n"
" half4 ret;\n"
" asm (\"ld.global.nc.b64 %0, [%1];\" : \"=l\"(__HALF4_TO_UI(ret)) : __LDG_PTR(ptr));\n"
" return ret;\n"
"}\n\n";

return CodeGenC::Finish();
}

@@ -138,6 +166,11 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) {
PrintIndent();
stream << "#pragma unroll\n";
}
else if (op->for_type == ir::ForType::Swizzled) {
// remove this loop
PrintStmt(op->body);
return;
}
CodeGenC::VisitStmt_(op);
}

@@ -163,7 +196,11 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
os << "half";
} else if (lanes <= 8) {
CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "float" << lanes / 2;
if (lanes <= 4) { // added for swizzle
os << "half" << lanes;
} else {
os << "float" << lanes / 2;
}
} else {
fail = true;
}
@@ -291,6 +328,7 @@ void CodeGenCUDA::PrintVecElemLoad(
}
}


void CodeGenCUDA::PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) {
this->PrintIndent();
@@ -516,11 +554,119 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
return;
}
CodeGenC::VisitExpr_(op, os);
} else if (op->is_intrinsic(Call::reinterpret_cast_op)) {
os << "*(reinterpret_cast<";
PrintType(op->args[1].type(), os);
if (op->args[0].as<IntImm>()->value > 0) os << op->args[0];
os << "*>(&";
auto ld = op->args[1].as<Load>();
if (ld) {
auto var_name = ld->buffer_var->name_hint;
if (std::find_if(vec_loads.begin(), vec_loads.end(),
[var_name](const Variable *v) { return (v->name_hint == var_name); }) != vec_loads.end()) {
os << "sw_" + var_name;
} else this->PrintExpr(op->args[1], os);
} else this->PrintExpr(op->args[1], os);
os << "))";
} else if (op->is_intrinsic(Call::ldg)) {
os << "__ldg((";
PrintType(op->args[1].type(), os);
os << this->PrintExpr(op->args[0]) << " *)(&" << this->PrintExpr(op->args[1]) << "))";
} else {
CodeGenC::VisitExpr_(op, os);
}
}

void CodeGenCUDA::VisitStmt_(const LetStmt* op) {
if (no_init_value){
no_init_value = false;
PrintIndent();
if (op->var.type() == Handle() &&
handle_data_type_.count(op->var.get())) {
PrintType(handle_data_type_.at(op->var.get()), stream);
stream << "* "
<< AllocVarID(op->var.get())
<< ";\n";
} else {
PrintType(op->var.type(), this->stream);
this->stream << ' '
<< AllocVarID(op->var.get())
<< ";\n";
}
PrintStmt(op->body);
} else {
CodeGenC::VisitStmt_(op);
}
}

void CodeGenCUDA::VisitExpr_(const Variable* op, std::ostream& os) {
// replace cce var with const index
if(replace_cce && loop_var == op){
os << current_index;
}
else {
CodeGenC::VisitExpr_(op, os);
}
}

void CodeGenCUDA::VisitExpr_(const Load* op, std::ostream& os) {
int lanes = op->type.lanes();
if (vec_store) {
static const char access[] = {'x', 'y', 'z', 'w', 'a', 'b', 'c', 'd'};
if (lanes == 2 || lanes == 4) {
os << op->buffer_var->name_hint << "." << access[current_index];
} else if(std::find_if(vec_loads.begin(), vec_loads.end(),
[op] (const Variable* v) { return (v->name_hint==op->buffer_var->name_hint); }) != vec_loads.end()){
os << "sw_" << op->buffer_var->name_hint << "." << access[current_index];
} else{
// temp variable
os << op->buffer_var->name_hint << "[";
PrintExpr(op->index, os);
os << "]";
}
} else {
CodeGenC::VisitExpr_(op, os);
}
}

void CodeGenCUDA::VisitStmt_(const Store* op) {
Type t = op->value.type();
if (is_reinterpret && t.lanes() == 1) {
is_reinterpret = false;

std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
this->PrintIndent();

Type elem_type = t.element_of();
stream << "*(reinterpret_cast<";
PrintType(elem_type, stream);
stream << loop_extent << "*>(&" << ref << ")) = " << value << ";\n";

} else if (vec_store) {
replace_cce = true;
static const char access[] = {'x', 'y', 'z', 'w', 'a', 'b', 'c', 'd'};
int lanes = op->buffer_var.type().lanes();
loop_extent = lanes;
for (int i = 0; i < lanes; i++){
this->PrintIndent();
current_index = i;
stream << op->buffer_var->name_hint << "." << access[i] << " = ";
PrintExpr(op->value, stream);
stream << ";\n";
}
replace_cce = false;
vec_store = false;
} else if (simple_store){
simple_store = false;
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
stream << op->buffer_var->name_hint << " = " << value << ";\n";
} else{
CodeGenC::VisitStmt_(op);
}
}

void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
if (op->attr_key == attr::wmma_scope) {
const StringImm* scope_str = op->value.as<StringImm>();
@@ -548,12 +694,30 @@ void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
const Variable* buffer = op->node.as<Variable>();
int offset = op->value.as<IntImm>()->value;
sm_offsets[buffer] = offset;
} else if (op->attr_key == "reinterpret_store") {
loop_extent = op->value;
// mark next store statement to be a reinterpret cast
is_reinterpret = true;
} else if (op->attr_key == "vec_store") {
loop_var = op->value.as<Variable>();
// mark next store statement to be a vector store
vec_store = true;
} else if (op->attr_key == "simple_store") {
// mark next store statement to be a basic store of type a = b
simple_store = true;
} else if (op->attr_key == "vec_load") {
vec_loads.insert(op->value.as<Variable>());
} else if (op->attr_key == "no_init_value") {
// mark next let statement to be a simple, empty declaration
no_init_value = true;
}
CodeGenC::VisitStmt_(op);
}

void CodeGenCUDA::VisitStmt_(const Allocate* op) {
CHECK(!is_zero(op->condition));
if (is_zero(op->condition)) {
stream << "// ";
}
std::string vid = AllocVarID(op->buffer_var.get());
if (op->new_expr.defined()) {
// Prefer global static allocation for the program
@@ -637,6 +801,10 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
}

void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
if (vec_store) {
PrintExpr(op->value, os);
return;
}
if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) {
// make_int8x4
const int64_t *p = as_const_int(op->value);


+ 25
- 0
third_party/incubator-tvm/src/codegen/codegen_cuda.h View File

@@ -94,6 +94,10 @@ class CodeGenCUDA final : public CodeGenC {
void VisitStmt_(const Evaluate *op) final;
void VisitStmt_(const Allocate *op) final;
void VisitStmt_(const AttrStmt *op) final;
void VisitStmt_(const LetStmt *op) final;
void VisitExpr_(const Variable *op, std::ostream &os) final;
void VisitExpr_(const Load *op, std::ostream &os) final;
void VisitStmt_(const Store *op) final;

private:
// Handle volatile loads.
@@ -120,6 +124,27 @@ class CodeGenCUDA final : public CodeGenC {
// whether need mma.h
bool need_mma_h_{false};

// whether next store will be a reinterpret_cast
bool is_reinterpret{false};
// the extent of the swizzled loop
Expr loop_extent{0};
// whether next store is not an array
bool simple_store{false};
// whether store uses vector format
bool vec_store{false};
// var from Loads to modify with vector load
std::set<const Variable*> vec_loads;
// whether replace cce variable with constant in vec_store
bool replace_cce{false};
// variable to replace
const Variable* loop_var;
// index to replace cce with
int current_index;
// do not set value to next LetStmt if true
bool no_init_value{false};
// ignore next allocate stmt if true (trick to bypass some tests)
// bool ignore_next_allocate{false};

// add for TensorCore
// warp tile size for TensorCore interface
Expr warp_tile_m = IntImm::make(Int(32), 1);


+ 3
- 0
third_party/incubator-tvm/src/lang/ir.cc View File

@@ -923,6 +923,9 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*)
case ForType::Vectorized:
out << "vectorized";
break;
case ForType::Swizzled:
out << "swizzled";
break;
}
return out;
}


+ 573
- 0
third_party/patch/isl/isl-influence.patch View File

@@ -0,0 +1,573 @@
unchanged:
--- isl_0.22/include/isl/options.h 2021-01-26 13:29:02.345994411 +0100
+++ isl/include/isl/options.h 2021-01-26 13:25:07.713990422 +0100
@@ -49,6 +49,13 @@ int isl_options_get_coalesce_bounded_wra
isl_stat isl_options_set_coalesce_preserve_locals(isl_ctx *ctx, int val);
int isl_options_get_coalesce_preserve_locals(isl_ctx *ctx);
+/* ====================== AKG influence patch -- start ====================== */
+isl_stat isl_options_set_akg_print_debug(isl_ctx *ctx, int val);
+int isl_options_get_akg_print_debug(isl_ctx *ctx);
+isl_stat isl_options_set_akg_influence_scheduler(isl_ctx *ctx, int val);
+int isl_options_get_akg_influence_scheduler(isl_ctx *ctx);
+/* ======================= AKG influence patch -- end ======================= */
+
#if defined(__cplusplus)
}
#endif
unchanged:
--- isl_0.22/include/isl/schedule.h 2021-01-26 13:29:02.345994411 +0100
+++ isl/include/isl/schedule.h 2021-01-26 13:25:07.713990422 +0100
@@ -9,7 +9,7 @@
#include <isl/set_type.h>
#include <isl/list.h>
#include <isl/printer_type.h>
-
+#include <isl/vec.h>
#if defined(__cplusplus)
extern "C" {
#endif
@@ -209,6 +209,57 @@ __isl_give isl_printer *isl_printer_prin
void isl_schedule_dump(__isl_keep isl_schedule *schedule);
__isl_give char *isl_schedule_to_str(__isl_keep isl_schedule *schedule);
+/* ====================== AKG influence patch -- start ====================== */
+
+/* --- Exports --- */
+
+/* export hidden types */
+struct isl_sched_node;
+typedef struct isl_sched_node isl_sched_node;
+struct isl_sched_edge;
+typedef struct isl_sched_edge isl_sched_edge;
+struct isl_sched_graph;
+typedef struct isl_sched_graph isl_sched_graph;
+
+/* isl_sched_node functions */
+int isl_sched_node_par_coef_offset(struct isl_sched_node *node);
+int isl_sched_node_cst_coef_offset(struct isl_sched_node *node);
+__isl_give isl_map *isl_sched_node_extract_schedule(struct isl_sched_node *node);
+
+/* isl vec functions */
+int isl_inf_vec_get_size( isl_vec* vec);
+/* export isl_sched_edge functions */
+/* export isl_sched_graph functions */
+isl_stat isl_sched_graph_init(struct isl_sched_graph *graph,
+ __isl_keep isl_schedule_constraints *sc);
+void isl_sched_graph_free(isl_ctx *ctx, struct isl_sched_graph *graph);
+__isl_give isl_schedule_node *isl_schedule_node_compute_schedule(isl_schedule_node *node,
+ struct isl_sched_graph *graph);
+
+/* --- New functions --- */
+
+/* isl_sched_node functions */
+int isl_sched_node_get_nparam(const struct isl_sched_node *node);
+int isl_sched_node_get_nvar(const struct isl_sched_node *node);
+
+/* isl_vec functions */
+int isl_influence_int_eq(isl_vec* v, int pos1, int pos2);
+isl_val* isl_influence_vec_get_elem(isl_vec*, int pos);
+/* isl_sched_graph functions */
+struct isl_sched_node* isl_sched_graph_get_node(struct isl_sched_graph *graph, int i);
+
+/* --- AKG influence --- */
+extern isl_basic_set* (*isl_influence_set_coef)(
+ isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset);
+extern isl_basic_set* (*isl_influence_set_equal)(
+ isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset);
+extern int (*isl_influence_maxvar)(struct isl_sched_graph* graph);
+extern int (*isl_influence_check_coincident)(struct isl_sched_graph* graph, isl_vec* sol);
+extern struct isl_sched_graph* (*isl_influence_sol_list_free)(struct isl_sched_graph* graph);
+extern struct isl_sched_graph* (*isl_influence_sol_add_elem)(isl_vec* sol, struct isl_sched_graph* graph);
+extern int(*isl_influence_sol_get_elem)(int sched, int pos, struct isl_sched_graph* graph);
+/* ======================= AKG influence patch -- end ======================= */
+
#if defined(__cplusplus)
}
#endif
unchanged:
--- isl_0.22/isl_options.c 2021-01-26 13:29:02.313994410 +0100
+++ isl/isl_options.c 2021-01-26 13:25:07.713990422 +0100
@@ -228,6 +228,12 @@ ISL_ARG_BOOL(struct isl_options, print_s
"print statistics for every isl_ctx")
ISL_ARG_ULONG(struct isl_options, max_operations, 0,
"max-operations", 0, "default number of maximal operations per isl_ctx")
+/* ====================== AKG influence patch -- start ====================== */
+ISL_ARG_BOOL(struct isl_options, akg_print_debug, 0, "print-debug", 0,
+ "print debug info")
+ISL_ARG_BOOL(struct isl_options, akg_influence_scheduler, 0, "influence-schedule", 0,
+ "update scheduler coefficients")
+/* ======================= AKG influence patch -- end ======================= */
ISL_ARG_VERSION(print_version)
ISL_ARGS_END
@@ -402,3 +408,14 @@ ISL_CTX_SET_BOOL_DEF(isl_options, struct
ast_build_allow_or)
ISL_CTX_GET_BOOL_DEF(isl_options, struct isl_options, isl_options_args,
ast_build_allow_or)
+
+/* ====================== AKG influence patch -- start ====================== */
+ISL_CTX_SET_BOOL_DEF(isl_options, struct isl_options, isl_options_args,
+ akg_print_debug)
+ISL_CTX_GET_BOOL_DEF(isl_options, struct isl_options, isl_options_args,
+ akg_print_debug)
+ISL_CTX_SET_BOOL_DEF(isl_options, struct isl_options, isl_options_args,
+ akg_influence_scheduler)
+ISL_CTX_GET_BOOL_DEF(isl_options, struct isl_options, isl_options_args,
+ akg_influence_scheduler)
+/* ======================= AKG influence patch -- end ======================= */
unchanged:
--- isl_0.22/isl_options_private.h 2021-01-26 13:29:02.317994410 +0100
+++ isl/isl_options_private.h 2021-01-26 13:25:07.713990422 +0100
@@ -71,6 +71,10 @@ struct isl_options {
int print_stats;
unsigned long max_operations;
+/* ====================== AKG influence patch -- start ====================== */
+ int akg_print_debug;
+ int akg_influence_scheduler;
+/* ======================= AKG influence patch -- end ======================= */
};
#endif
diff -u isl/isl_scheduler.c isl/isl_scheduler.c
--- isl/isl_scheduler.c 2021-02-05 09:56:15.186697283 +0100
+++ isl/isl_scheduler.c 2021-03-31 14:49:38.568015437 +0200
@@ -39,7 +39,6 @@
#include <isl_morph.h>
#include <isl/ilp.h>
#include <isl_val_private.h>
-
/*
* The scheduling algorithm implemented in this file was inspired by
* Bondhugula et al., "Automatic Transformations for Communication-Minimized
@@ -303,6 +302,12 @@
return is_condition(edge) || is_conditional_validity(edge);
}
+/* ====================== AKG influence patch -- start ====================== */
+struct isl_influence_list;
+struct isl_influence_equal_list;
+struct isl_influence_sol_list;
+/* ======================= AKG influence patch -- end ======================= */
+
/* Internal information about the dependence graph used during
* the construction of the schedule.
*
@@ -395,6 +400,11 @@
int weak;
int max_weight;
+/* ====================== AKG influence patch -- start ====================== */
+ struct isl_influence_list *inf_list;
+ struct isl_influence_equal_list *inf_equal_list;
+ struct isl_influenec_sol_list* inf_sol_list;
+/* ======================= AKG influence patch -- end ======================= */
};
/* Initialize node_table based on the list of nodes.
@@ -757,6 +767,8 @@
isl_hash_table_free(ctx, graph->edge_table[i]);
isl_hash_table_free(ctx, graph->node_table);
isl_basic_set_free(graph->lp);
+ if(isl_influence_sol_list_free)
+ graph=isl_influence_sol_list_free(graph);
}
/* For each "set" on which this function is called, increment
@@ -1233,7 +1245,6 @@
return isl_stat_error;
if (compressed && (!hull || !compress || !decompress))
return isl_stat_error;
-
return isl_stat_ok;
error:
isl_set_free(set);
@@ -3222,6 +3233,7 @@
* In particular, the non-triviality region enforces that at least
* one of the linear combinations in the rows of node->indep is non-zero.
*/
+
static __isl_give isl_vec *solve_lp(isl_ctx *ctx, struct isl_sched_graph *graph)
{
int i;
@@ -3237,16 +3249,68 @@
trivial = construct_trivial(node, node->indep);
else
trivial = isl_mat_zero(ctx, 0, 0);
+
graph->region[i].trivial = trivial;
}
lp = isl_basic_set_copy(graph->lp);
+/* ====================== AKG influence patch -- start ====================== */
+ isl_basic_set *lp_backup;
+ isl_basic_set *lp_backup_inf;
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+ if (akg_influence)
+ {
+ lp_backup = isl_basic_set_copy(graph->lp);
+ lp = isl_influence_set_coef(ctx, graph, lp);
+ lp = isl_influence_set_equal(ctx, graph, lp);
+ lp_backup_inf = isl_basic_set_copy(lp);
+ }
+/* ======================= AKG influence patch -- end ======================= */
sol = isl_tab_basic_set_non_trivial_lexmin(lp, 2, graph->n,
graph->region, &check_conflict, graph);
+/* ====================== AKG influence patch -- start ====================== */
+ if (akg_influence)
+ {
+ if(sol->size == 0)
+ {
+ if (akg_debug >= 1) fprintf(stderr, "\ninfluence schedule did not find solution, relax region and try again::\n");
+ isl_vec_free(sol);
+
+ for(int i=0; i< graph->n;i++)
+ {
+ isl_mat_free(graph->region[i].trivial);
+ graph->region[i].trivial = isl_mat_zero(ctx,0,0);
+ }
+
+ sol = isl_tab_basic_set_non_trivial_lexmin(lp_backup_inf, 2, graph->n, graph->region, &check_conflict, graph);
+
+ if (sol->size == 0)
+ {
+ if (akg_debug >= 1) fprintf(stderr, "\ninfluence schedule did not find solution, another try with use_coincidence=0:\n");
+ } else
+ {
+ isl_basic_set_free(lp_backup);
+ }
+
+ }
+ else
+ {
+ isl_basic_set_free(lp_backup);
+ isl_basic_set_free(lp_backup_inf);
+ }
+ }
+
+ if (!sol->size)
+ {
+ if (akg_debug >= 1)
+ fprintf(stderr, "solve_lp did not find solution for dimension: %d\n",graph->n_row);
+ }
+
+/* ======================= AKG influence patch -- end ======================= */
for (i = 0; i < graph->n; ++i)
isl_mat_free(graph->region[i].trivial);
return sol;
}
-
/* Extract the coefficients for the variables of "node" from "sol".
*
* Each schedule coefficient c_i_x is represented as the difference
@@ -3271,6 +3335,12 @@
if (!csol)
return NULL;
+/* ====================== AKG influence patch -- start ====================== */
+ isl_ctx* const ctx = isl_vec_get_ctx(sol);
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+
+/* ======================= AKG influence patch -- end ======================= */
for (i = 0; i < node->nvar; ++i) {
pos = 1 + node_var_coef_pos(node, i);
if (node->nonneg)
@@ -3304,10 +3374,15 @@
if (sol->size == 0)
isl_die(sol->ctx, isl_error_internal,
"no solution found", goto error);
- if (graph->n_total_row >= graph->max_row)
+/* ====================== AKG influence patch -- start ====================== */
+ isl_ctx* const ctx = isl_vec_get_ctx(sol);
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+
+ if (!akg_influence && graph->n_total_row >= graph->max_row)
isl_die(sol->ctx, isl_error_internal,
"too many schedule rows", goto error);
-
+ /* ======================= AKG influence patch -- end ======================= */
for (i = 0; i < graph->n; ++i) {
struct isl_sched_node *node = &graph->node[i];
int pos;
@@ -3325,11 +3400,13 @@
goto error;
pos = node_cst_coef_offset(node);
node->sched = isl_mat_set_element(node->sched,
- row, 0, sol->el[1 + pos]);
+ row, 0, sol->el[1 + pos]);
pos = node_par_coef_offset(node);
for (j = 0; j < node->nparam; ++j)
+ {
node->sched = isl_mat_set_element(node->sched,
row, 1 + j, sol->el[1 + pos + j]);
+ }
for (j = 0; j < node->nvar; ++j)
node->sched = isl_mat_set_element(node->sched,
row, 1 + node->nparam + j, csol->el[j]);
@@ -3340,6 +3417,14 @@
graph->n_row++;
graph->n_total_row++;
+/* ====================== AKG influence patch -- start ====================== */
+ if (akg_influence && graph->n_row == graph->maxvar)
+ {
+ if(graph->n_row < isl_influence_maxvar(graph)){
+ graph->maxvar++;
+ }
+ }
+/* ======================= AKG influence patch -- end ======================= */
return 0;
error:
@@ -3917,14 +4002,16 @@
static int compute_maxvar(struct isl_sched_graph *graph)
{
int i;
-
graph->maxvar = 0;
for (i = 0; i < graph->n; ++i) {
struct isl_sched_node *node = &graph->node[i];
int nvar;
- if (node_update_vmap(node) < 0)
+/* ====================== AKG influence patch -- start ====================== */
+ if (node_update_vmap(node) < 0) {
return -1;
+ }
+/* ======================= AKG influence patch -- end ======================= */
nvar = node->nvar + graph->n_row - node->rank;
if (nvar > graph->maxvar)
graph->maxvar = nvar;
@@ -3968,6 +4055,10 @@
sub->max_row = graph->max_row;
sub->n_total_row = graph->n_total_row;
sub->band_start = graph->band_start;
+/* ====================== AKG influence patch -- start ====================== */
+ sub->inf_list = graph->inf_list;
+ sub->inf_equal_list = graph->inf_equal_list;
+/* ======================= AKG influence patch -- end ======================= */
return isl_stat_ok;
}
@@ -3997,6 +4088,11 @@
{
struct isl_sched_graph split = { 0 };
+/* ====================== AKG influence patch -- start ====================== */
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+/* ======================= AKG influence patch -- end ======================= */
+
if (extract_sub_graph(ctx, graph, node_pred, edge_pred, data,
&split) < 0)
goto error;
@@ -5396,6 +5492,30 @@
return NULL;
lp = isl_basic_set_copy(graph->lp);
+/* ====================== AKG influence patch -- start ====================== */
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+ if (akg_influence)
+ {
+ isl_basic_set *lp_backup = isl_basic_set_copy(graph->lp);
+ lp = isl_influence_set_coef(ctx, graph, lp);
+ lp = isl_influence_set_equal(ctx, graph, lp);
+ isl_vec *sol = non_neg_lexmin(graph, lp, n_edge, want_integral);
+ if (sol->size == 0)
+ {
+ if (akg_debug >= 1) {
+ fprintf(stderr, "\ncarry_lp hack failed, restoring previous lp problem\n");
+ }
+ sol = non_neg_lexmin(graph, lp_backup, n_edge, want_integral);
+ }
+ else
+ {
+ isl_basic_set_free(lp_backup);
+ }
+
+ return sol;
+ }
+/* ======================= AKG influence patch -- end ======================= */
return non_neg_lexmin(graph, lp, n_edge, want_integral);
}
@@ -5991,6 +6111,11 @@
int use_coincidence;
int force_coincidence = 0;
int check_conditional;
+ int coincidence_relaxed=0;
+/* ====================== AKG influence patch -- start ====================== */
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+/* ======================= AKG influence patch -- end ======================= */
if (sort_sccs(graph) < 0)
return isl_stat_error;
@@ -5998,10 +6123,17 @@
clear_local_edges(graph);
check_conditional = need_condition_check(graph);
has_coincidence = has_any_coincidence(graph);
-
if (ctx->opt->schedule_outer_coincidence)
force_coincidence = 1;
+
+ if(akg_influence){
+ int previous_coincidence = has_coincidence;
+ int isl_maxvar = isl_influence_maxvar(graph);
+ if(graph->maxvar > isl_maxvar)
+ graph->maxvar = isl_maxvar;
+ }
+
use_coincidence = has_coincidence;
while (graph->n_row < graph->maxvar) {
isl_vec *sol;
@@ -6014,19 +6146,34 @@
if (setup_lp(ctx, graph, use_coincidence) < 0)
return isl_stat_error;
sol = solve_lp(ctx, graph);
- if (!sol)
+ if(akg_influence && sol->size)
+ graph=isl_influence_sol_add_elem(sol,graph);
+ if(!sol)
+ {
return isl_stat_error;
+ }
if (sol->size == 0) {
+ /* ====================== AKG influence patch -- start ====================== */
+ if (akg_influence && akg_debug >= 1) {
+ fprintf(stderr, "No solution!\n");
+ }
+ /* ======================= AKG influence patch -- end ======================= */
int empty = graph->n_total_row == graph->band_start;
isl_vec_free(sol);
if (use_coincidence && (!force_coincidence || !empty)) {
use_coincidence = 0;
+ coincidence_relaxed=1;
continue;
}
return isl_stat_ok;
}
coincident = !has_coincidence || use_coincidence;
+
+ if(akg_influence && graph->n > 1){
+ coincident=isl_influence_check_coincident(graph,sol);
+ }
+
if (update_schedule(graph, sol, coincident) < 0)
return isl_stat_error;
@@ -7688,6 +7835,10 @@
if (graph->scc <= 1 || isl_options_get_schedule_whole_component(ctx))
return compute_schedule_wcc_whole(node, graph);
+/* ====================== AKG influence patch -- start ====================== */
+ else if (isl_options_get_akg_influence_scheduler(ctx))
+ return compute_schedule_wcc_whole(node, graph);
+/* ======================= AKG influence patch -- end ======================= */
else
return compute_schedule_wcc_clustering(node, graph);
}
@@ -7735,6 +7886,10 @@
else
node = isl_schedule_node_insert_sequence(node, filters);
+/* ====================== AKG influence patch -- start ====================== */
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+/* ======================= AKG influence patch -- end ======================= */
for (component = 0; component < graph->scc; ++component) {
node = isl_schedule_node_child(node, component);
node = isl_schedule_node_child(node, 0);
@@ -7775,7 +7930,10 @@
return isl_schedule_node_free(node);
}
- if (graph->scc > 1)
+/* ====================== AKG influence patch -- start ====================== */
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ if (graph->scc > 1 && akg_influence == 0)
+/* ======================= AKG influence patch -- end ======================= */
return compute_component_schedule(node, graph, 1);
return compute_schedule_wcc(node, graph);
@@ -7806,6 +7964,11 @@
isl_union_set *domain;
isl_size n;
+/* ====================== AKG influence patch -- start ====================== */
+ const int akg_influence = isl_options_get_akg_influence_scheduler(ctx);
+ const int akg_debug = isl_options_get_akg_print_debug(ctx);
+/* ======================= AKG influence patch -- end ======================= */
+
sc = isl_schedule_constraints_align_params(sc);
domain = isl_schedule_constraints_get_domain(sc);
@@ -7852,0 +8016,62 @@
+
+/* ====================== AKG influence patch -- start ====================== */
+isl_basic_set* (*isl_influence_set_coef)(
+ isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) = { 0 };
+isl_basic_set* (*isl_influence_set_equal)(
+ isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) = { 0 };
+int (*isl_influence_maxvar)(struct isl_sched_graph* graph) = { 0 };
+int (*isl_influence_check_coincident)(struct isl_sched_graph *graph,isl_vec* sol) = { 0 };
+struct isl_sched_graph* (*isl_influence_sol_list_free)(struct isl_sched_graph* graph) = { 0 } ;
+struct isl_sched_graph* (*isl_influence_sol_add_elem)(isl_vec* sol, struct isl_sched_graph* graph) = { 0 };
+int (*isl_influence_sol_get_elem)(int sched, int pos, struct isl_sched_graph* graph) = { 0 };
/* Exported accessor for the file-static node_par_coef_offset():
 * position of "node"'s parameter coefficients in the scheduler's
 * coefficient vectors. */
int isl_sched_node_par_coef_offset(struct isl_sched_node *node)
{
	const int offset = node_par_coef_offset(node);
	return offset;
}
+
/* Exported accessor for the file-static node_cst_coef_offset():
 * position of "node"'s constant-term coefficient in the scheduler's
 * coefficient vectors. */
int isl_sched_node_cst_coef_offset(struct isl_sched_node *node)
{
	const int offset = node_cst_coef_offset(node);
	return offset;
}
+
+__isl_give isl_map *isl_sched_node_extract_schedule(struct isl_sched_node *node) {
+ return node_extract_schedule(node);
+}
+
+int isl_sched_node_get_nparam(const struct isl_sched_node *node) {
+ return node->nparam;
+}
+
+int isl_sched_node_get_nvar(const struct isl_sched_node *node) {
+ return node->nvar;
+}
+
+int isl_influence_int_eq(isl_vec* v, int pos1, int pos2)
+{
+ return isl_int_eq(v->el[pos1],v->el[pos2]);
+}
+
+isl_val* isl_influence_vec_get_elem(isl_vec* v, int pos)
+{
+ return isl_vec_get_element_val(v,pos);
+}
+isl_stat isl_sched_graph_init(struct isl_sched_graph *graph,
+ __isl_keep isl_schedule_constraints *sc) {
+ return graph_init(graph, sc);
+}
+
+void isl_sched_graph_free(isl_ctx *ctx, struct isl_sched_graph *graph) {
+ graph_free(ctx, graph);
+}
+
+__isl_give isl_schedule_node *isl_schedule_node_compute_schedule(isl_schedule_node *node,
+ struct isl_sched_graph *graph) {
+ return compute_schedule(node, graph);
+}
+
+struct isl_sched_node* isl_sched_graph_get_node(struct isl_sched_graph *graph, int i) {
+ return &graph->node[i];
+}
+
+int isl_inf_vec_get_size(isl_vec* vec){
+ return(int) vec->size;
+}
+/* ======================= AKG influence patch -- end ======================= */

Loading…
Cancel
Save