From: @harenome Reviewed-by: @dylangeng Signed-off-by:pull/45/MERGE
| @@ -17,7 +17,7 @@ akg_add_pkg(isl | |||
| URL ${ISL_URL} | |||
| MD5 ${ISL_MD5} | |||
| CUSTOM_CMAKE ${AKG_SOURCE_DIR}/third_party/isl_wrap | |||
| PATCHES ${AKG_SOURCE_DIR}/third_party/patch/isl/isl.patch | |||
| PATCHES ${AKG_SOURCE_DIR}/third_party/patch/isl/isl.patch ${AKG_SOURCE_DIR}/third_party/patch/isl/isl-influence.patch | |||
| CMAKE_OPTION " ") | |||
| include_directories("${AKG_SOURCE_DIR}/third_party/isl_wrap/include") | |||
| include_directories("${isl_INC}/include") | |||
| @@ -85,5 +85,7 @@ REGISTER_PASS(SubstituteDivVar); | |||
| REGISTER_PASS(UnrollNonConstantExtent) | |||
| REGISTER_PASS(ValueNumbering); | |||
| REGISTER_PASS(TensorAccessRewrite); | |||
| REGISTER_PASS(SwizzleGPU); | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -564,6 +564,13 @@ NodeRef LowerStmt(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeR | |||
| stmt = NEXT_PASS(InjectDoubleBuffer, stmt, config->double_buffer_split_loop, | |||
| global_attrs.GetBoolAttr(kEnableTransferBuffer, false)); | |||
| stmt = NEXT_PASS(StorageRewrite, stmt); | |||
| if (target_platform->device_type == kDLGPU && polyhedral) { | |||
| if (global_attrs.GetBoolAttr(kEnableSwizzleGPU, true)) { | |||
| stmt = NEXT_PASS(SwizzleGPU, stmt, global_attrs); | |||
| } | |||
| } | |||
| stmt = NEXT_PASS(UnrollLoop, stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, | |||
| config->auto_unroll_max_extent, config->unroll_explicit); | |||
| @@ -102,6 +102,7 @@ constexpr auto kAllocBits = "alloc_bits"; | |||
| constexpr auto kEnablePolySch = "enable_poly_sch"; | |||
| constexpr auto kEnableFuseAxis = "enable_fuse_axis"; | |||
| constexpr auto kEnableAtomicAdd = "enable_atomic_add"; | |||
| constexpr auto kEnableSwizzleGPU = "enable_swizzle_gpu"; | |||
| static std::unordered_map<std::string, int> help_tiling_level = { | |||
| {"None", 0}, | |||
| @@ -206,6 +206,8 @@ Stmt ValueNumbering(Stmt stmt); | |||
| Stmt TensorAccessRewrite(const Stmt stmt); | |||
| Stmt SwizzleGPU(const Stmt &stmt, const Map<std::string, NodeRef> &attrs); | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -0,0 +1,730 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <map> | |||
| #include <tvm/ir.h> | |||
| #include <tvm/ir_mutator.h> | |||
| #include <tvm/ir_pass.h> | |||
| #include <tvm.h> | |||
| #include <string> | |||
| #include "pass/utils.h" | |||
| #include "src/common/util.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| class SwizzleFinder : public IRVisitor { | |||
| public: | |||
| SwizzleFinder() = default; | |||
| ~SwizzleFinder() override = default; | |||
| void Visit_(const AttrStmt *op) final { | |||
| if (op->attr_key == air::ir::attr::thread_extent) { | |||
| if (auto value = op->value.as<IntImm>()) { | |||
| std::string name = op->node.as<IterVarNode>()->var->name_hint; | |||
| LOG(DEBUG) << "Thread extent (" << name << ") : " << value->value; | |||
| thread_extent[name] = value->value; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } else if (op->attr_key == air::ir::attr::realize_scope) { | |||
| LOG(WARNING) << "Realize storage scope not implemented in swizzle pass (may not work as expected) : " | |||
| << op->value.as<StringImm>()->value; | |||
| Visit(op->body); | |||
| } else { | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| } | |||
| void Visit_(const For *op) final { | |||
| swizzle_length = GetExtent(op); | |||
| if (swizzle_length == 2 || swizzle_length == 4) { | |||
| LOG(INFO) << "Swizzle for loop ? var : " << op->loop_var->name_hint; | |||
| loop_var = op->loop_var; | |||
| swizzlable = true; | |||
| swizzle_candidate = true; | |||
| loop_loads = {}; | |||
| loop_stores = {}; | |||
| Visit(op->body); | |||
| if (!swizzle_candidate) { | |||
| // loop contains another loop | |||
| LOG(DEBUG) << op->loop_var->name_hint << " not swizzlable : loop contains another loop"; | |||
| swizzlable = false; | |||
| return; | |||
| } | |||
| swizzle_candidate = false; | |||
| // get all load and store from this loop and test if their range is 2, 4 or 0 (const) | |||
| for (auto l : loop_loads) { | |||
| int ext = load_indexes[l].second - load_indexes[l].first; | |||
| if (ext != (int)swizzle_length && ext != 0) { | |||
| swizzlable = false; | |||
| LOG(DEBUG) << op->loop_var->name_hint | |||
| << " not swizzlable : range of load is not 0, 2 or 4 : " << ExprToString(load_indexes[l].first); | |||
| break; | |||
| } else if (ext == 0) { | |||
| // do not swizzle variables that are constant inside loop | |||
| if (swizzle_temp_vars.count(l->buffer_var->name_hint) > 0 && | |||
| set_temp_vars.count(l->buffer_var->name_hint) == 0) { | |||
| swizzle_temp_vars.erase(l->buffer_var->name_hint); | |||
| temp_vars.insert(l->buffer_var->name_hint); | |||
| } | |||
| } | |||
| } | |||
| for (auto s : loop_stores) { | |||
| int ext = store_indexes[s].second - store_indexes[s].first; | |||
| if (ext != (int)swizzle_length && ext != 0) { | |||
| swizzlable = false; | |||
| LOG(DEBUG) << op->loop_var->name_hint | |||
| << " not swizzlable : range of store is not 0, 2 or 4 : " << ExprToString(store_indexes[s].first); | |||
| break; | |||
| } else if (ext == 0) { | |||
| // cannot swizzle if variable is an argument with range 0 (disable loop swizzle for safety) | |||
| if (swizzle_temp_vars.count(s->buffer_var->name_hint) == 0) { | |||
| swizzlable = false; | |||
| LOG(DEBUG) << op->loop_var->name_hint << " not swizzlable : range of store is 0 on variable : " | |||
| << ExprToString(store_indexes[s].first); | |||
| break; | |||
| } | |||
| // check if store value contains var | |||
| // Warning : this check does not verify loop variables that are not in a load | |||
| if (!store_loads[s].empty() && swizzle_stores.count(s) == 0) { | |||
| // Variable value changes at each loop iteration | |||
| swizzle_stores.insert(s); | |||
| } else { | |||
| // remove from swizzle_temp_vars | |||
| swizzle_temp_vars.erase(s->buffer_var->name_hint); | |||
| temp_vars.insert(s->buffer_var->name_hint); | |||
| } | |||
| } | |||
| } | |||
| if (swizzlable) { | |||
| swizzle_loops.push_back(op); | |||
| } | |||
| } else { | |||
| Visit(op->body); | |||
| } | |||
| } | |||
| // check potential temp variable to convert, making sure it only contains the loop var | |||
| template <typename T> | |||
| void checkSwizzleVar(const T *op, air::DataType t) { | |||
| if (swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) { | |||
| auto *v = op->index.template as<Variable>(); | |||
| auto *i = op->index.template as<IntImm>(); | |||
| if (((v == nullptr || v->name_hint != loop_var->name_hint) && (i == nullptr || i->value != 0)) || | |||
| !((t.is_float() || t.is_int()) && t.bits() <= 32)) { | |||
| // this temp var does not look like a swizzle var, treat it as any regular temp var | |||
| LOG(DEBUG) << "Irregular potential swizzle var : " << op->buffer_var->name_hint; | |||
| temp_vars.insert(op->buffer_var->name_hint); | |||
| swizzle_temp_vars.erase(op->buffer_var->name_hint); | |||
| } else { | |||
| if (i && i->value == 0) { | |||
| if (var_size.find(op->buffer_var->name_hint) == var_size.end() || | |||
| var_size[op->buffer_var->name_hint] < (int)swizzle_length) { | |||
| var_size[op->buffer_var->name_hint] = swizzle_length; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void Visit_(const Load *op) final { | |||
| if (swizzle_candidate && swizzlable) { | |||
| LOG(DEBUG) << "Load : " << op->buffer_var->name_hint << " index : " << ExprToString(op->index); | |||
| loop_loads.push_back(op); | |||
| compute_var = true; | |||
| current_min = 1; | |||
| current_max = 1; | |||
| checkSwizzleVar(op, op->type); | |||
| IRVisitor::Visit(op->index); | |||
| // add this load to array | |||
| compute_var = false; | |||
| if (current_min > current_max) std::swap(current_min, current_max); | |||
| load_indexes.insert(std::make_pair(op, std::make_pair(current_min, current_max))); | |||
| if (current_store != nullptr && | |||
| ((contains_iterator && (op->type.is_float() || op->type.is_int()) && op->type.bits() <= 32) || | |||
| swizzle_temp_vars.count(op->buffer_var->name_hint) > 0)) { | |||
| LOG(DEBUG) << "Load contains iterator or swizzle temp var"; | |||
| store_loads[current_store].insert(op); | |||
| } | |||
| contains_iterator = false; | |||
| LOG(DEBUG) << "End Load : " << op->buffer_var->name_hint << " range estimation : " << current_max - current_min; | |||
| IRVisitor::Visit(op->predicate); | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| // visit each operator for each store and load | |||
| // compute value for each of its elements (a, b or array) | |||
| // add map<Store, pair> to evaluate extent inside each load/store | |||
| void Visit_(const Store *op) final { | |||
| if (swizzle_candidate && swizzlable) { | |||
| LOG(DEBUG) << "Store : " << op->buffer_var->name_hint << " index : " << ExprToString(op->index) | |||
| << " value : " << ExprToString(op->value); | |||
| loop_stores.push_back(op); | |||
| compute_var = true; | |||
| contains_iterator = false; | |||
| current_store = op; | |||
| store_loads[op] = {}; | |||
| current_min = 1; | |||
| current_max = 1; | |||
| checkSwizzleVar(op, op->value.type()); | |||
| if (swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) { | |||
| set_temp_vars.insert(op->buffer_var->name_hint); | |||
| } | |||
| IRVisitor::Visit(op->index); | |||
| // check for single Load (allows inplace swizzle) | |||
| const Load *ld = op->value.as<Load>(); | |||
| if (ld && ld->type == op->value.type() && (ld->type.is_float() || ld->type.is_int()) && ld->type.bits() <= 32) { | |||
| LOG(DEBUG) << "Single store Value : " << op->value << std::endl | |||
| << "ld : " << ld->buffer_var << "[" << ld->index << "]"; | |||
| single_stores.insert(current_store); | |||
| } | |||
| // add this store to array | |||
| compute_var = false; | |||
| if (current_min > current_max) std::swap(current_min, current_max); | |||
| store_indexes.insert(std::make_pair(op, std::make_pair(current_min, current_max))); | |||
| if (contains_iterator && (op->value.type().is_float() || op->value.type().is_int()) && | |||
| op->value.type().bits() <= 32) { | |||
| swizzle_stores.insert(op); | |||
| } | |||
| contains_iterator = false; | |||
| LOG(DEBUG) << "End Store index : " << op->buffer_var->name_hint | |||
| << " range estimation : " << current_max - current_min; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| current_store = nullptr; | |||
| LOG(DEBUG) << "End Store " << op->buffer_var->name_hint; | |||
| } | |||
| void Visit_(const Allocate *op) final { | |||
| int size = op->constant_allocation_size(); | |||
| LOG(DEBUG) << "Allocate : " << op->buffer_var->name_hint << " size " << size; | |||
| if (size == 1 || size == 2 || size == 4) { | |||
| swizzle_temp_vars.insert(op->buffer_var->name_hint); | |||
| var_size[op->buffer_var->name_hint] = size; | |||
| } else { | |||
| temp_vars.insert(op->buffer_var->name_hint); | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const FloatImm *op) final { | |||
| if (compute_var) { | |||
| current_min = (int)op->value; | |||
| current_max = current_min; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const IntImm *op) final { | |||
| if (compute_var) { | |||
| current_min = op->value; | |||
| current_max = current_min; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const Variable *op) final { | |||
| if (swizzle_candidate && swizzlable && compute_var) { | |||
| if (op->name_hint == loop_var->name_hint) { | |||
| contains_iterator = true; | |||
| current_min = 0; | |||
| current_max = (int)swizzle_length; | |||
| } else { | |||
| auto it = thread_extent.find(op->name_hint); | |||
| if (it != thread_extent.end()) { | |||
| current_min = it->second; | |||
| current_max = current_min; | |||
| } else { | |||
| // unknown variable value, consider it constant | |||
| current_min = 1; | |||
| current_max = 1; | |||
| } | |||
| } | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const Mul *op) final { | |||
| if (swizzle_candidate && swizzlable) { | |||
| if (compute_var) { | |||
| Visit(op->a); | |||
| int tmp_min = current_min, tmp_max = current_max; | |||
| Visit(op->b); | |||
| int tmp = std::min(std::min(current_min * tmp_min, current_max * tmp_min), | |||
| std::min(current_min * tmp_max, current_max * tmp_max)); | |||
| current_max = std::max(std::max(current_min * tmp_min, current_max * tmp_min), | |||
| std::max(current_min * tmp_max, current_max * tmp_max)); | |||
| current_min = tmp; | |||
| return; | |||
| } | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const Add *op) final { | |||
| if (swizzle_candidate && swizzlable && compute_var) { | |||
| Visit(op->a); | |||
| int tmp_min = current_min, tmp_max = current_max; | |||
| Visit(op->b); | |||
| int tmp = std::min(std::min(current_min + tmp_min, current_max + tmp_min), | |||
| std::min(current_min + tmp_max, current_max + tmp_max)); | |||
| current_max = std::max(std::max(current_min + tmp_min, current_max + tmp_min), | |||
| std::max(current_min + tmp_max, current_max + tmp_max)); | |||
| current_min = tmp; | |||
| return; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const Sub *op) final { | |||
| if (swizzle_candidate && swizzlable && compute_var) { | |||
| Visit(op->b); | |||
| int tmp_min = current_min, tmp_max = current_max; | |||
| Visit(op->a); | |||
| int tmp = std::min(std::min(current_min - tmp_min, current_max - tmp_min), | |||
| std::min(current_min - tmp_max, current_max - tmp_max)); | |||
| current_max = std::max(std::max(current_min - tmp_min, current_max - tmp_min), | |||
| std::max(current_min - tmp_max, current_max - tmp_max)); | |||
| current_min = tmp; | |||
| return; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const Div *op) final { | |||
| if (swizzle_candidate && swizzlable) { | |||
| if (compute_var) { | |||
| Visit(op->a); | |||
| int tmp_min = current_min, tmp_max = current_max; | |||
| Visit(op->b); | |||
| if (current_min == 0 || current_max == 0) { | |||
| swizzlable = false; | |||
| LOG(WARNING) << "Possible division by zero detected in " << getDebugInfo(op); | |||
| } | |||
| int tmp = std::min(std::min(tmp_min / current_min, tmp_min / current_max), | |||
| std::min(tmp_max / current_min, tmp_max / current_max)); | |||
| current_max = std::max(std::max(tmp_min / current_min, tmp_min / current_max), | |||
| std::max(tmp_max / current_min, tmp_max / current_max)); | |||
| current_min = tmp; | |||
| return; | |||
| } | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| void Visit_(const IfThenElse *op) final { | |||
| if (swizzle_candidate) { | |||
| swizzlable = false; | |||
| } | |||
| IRVisitor::Visit_(op); | |||
| } | |||
| // returns the extent of the loop if it's a constant integer, otherwise return -1 | |||
| static int GetExtent(const For *op) { | |||
| // constant folding. | |||
| Expr extent = Simplify(op->extent); | |||
| const auto *v1 = extent.as<IntImm>(); | |||
| const auto *v2 = extent.as<UIntImm>(); | |||
| int value = -1; | |||
| if (v1 != nullptr) { | |||
| value = static_cast<int>(v1->value); | |||
| } | |||
| if (v2 != nullptr) { | |||
| value = static_cast<int>(v2->value); | |||
| } | |||
| return value; | |||
| } | |||
| std::set<const Store *> single_stores{}; | |||
| std::vector<const For *> swizzle_loops{}; | |||
| std::set<std::basic_string<char>> temp_vars{}; | |||
| std::set<std::basic_string<char>> swizzle_temp_vars{}; | |||
| std::set<const Load *> input_arguments{}; | |||
| std::set<const Store *> output_arguments{}; | |||
| std::set<const Store *> swizzle_stores{}; | |||
| std::set<const Load *> swizzle_loads{}; | |||
| bool swizzlable{false}; | |||
| std::map<const Store *, std::set<const Load *>> store_loads; | |||
| std::map<std::basic_string<char>, int> var_size{}; | |||
| private: | |||
| // print a nice representation of what is happening inside the code | |||
| template <typename T> | |||
| std::string getDebugInfo(T *op) { | |||
| Expr a = (Expr)(op->a); | |||
| Expr b = (Expr)(op->b); | |||
| std::string a_str, b_str; | |||
| if (!a.defined()) { | |||
| a_str = a->GetTypeKey(); | |||
| } else { | |||
| a_str = "(" + a->GetTypeKey() + ") " + ExprToString(a); | |||
| } | |||
| if (!b.defined()) { | |||
| b_str = b->GetTypeKey(); | |||
| } else { | |||
| b_str = "(" + b->GetTypeKey() + ") " + ExprToString(b); | |||
| } | |||
| return a_str + " " + b_str; | |||
| } | |||
| const Store *current_store{}; | |||
| bool compute_var = false; | |||
| int current_min = 0; | |||
| int current_max = 0; | |||
| bool swizzle_candidate{false}; | |||
| bool contains_iterator{false}; | |||
| unsigned int swizzle_length{0}; | |||
| Var loop_var; | |||
| // temp vars that are set inside a loop | |||
| std::set<std::basic_string<char>> set_temp_vars{}; | |||
| std::vector<const Load *> loop_loads; | |||
| std::vector<const Store *> loop_stores; | |||
| std::unordered_map<const Load *, std::pair<int, int>> load_indexes{}; | |||
| std::unordered_map<const Store *, std::pair<int, int>> store_indexes{}; | |||
| std::map<std::string, int64_t> thread_extent = {{"blockIdx.x", 0}, {"threadIdx.x", 0}}; | |||
| }; | |||
| class Swizzle : public IRMutator { | |||
| public: | |||
| explicit Swizzle(std::basic_string<char> name) : finder() { kernel_name = std::move(name); }; | |||
| ~Swizzle() override = default; | |||
| Stmt VisitAndMutate(Stmt stmt) { | |||
| LOG(DEBUG) << "Visit statement"; | |||
| finder.Visit(stmt); | |||
| if (!finder.swizzle_loops.empty()) { | |||
| auto ret = Mutate(stmt); | |||
| if (!ret.same_as(stmt)) { | |||
| LOG(INFO) << "Total swizzled loops for " << kernel_name << " : " << finder.swizzle_loops.size(); | |||
| return ret; | |||
| } | |||
| } | |||
| LOG(INFO) << "Total swizzled loops for " << kernel_name << " : 0"; | |||
| return stmt; | |||
| } | |||
| Stmt Mutate_(const For *op, const Stmt &s) final { | |||
| auto f = std::find(std::begin(finder.swizzle_loops), std::end(finder.swizzle_loops), op); | |||
| if (f != std::end(finder.swizzle_loops)) { | |||
| swizzling = true; | |||
| LOG(DEBUG) << "Swizzle " << op->loop_var->name_hint; | |||
| Stmt s2 = swizzle_loop(op, s); | |||
| swizzling = false; | |||
| return s2; | |||
| } | |||
| auto body = Mutate(op->body); | |||
| // auto unroll loops | |||
| ForType t = op->for_type; | |||
| int ext = SwizzleFinder::GetExtent(op); | |||
| if (ext > 0 && t == ForType::Serial) t = ForType::Unrolled; | |||
| return For::make(op->loop_var, op->min, op->extent, t, op->device_api, body); | |||
| } | |||
| // modify loop to apply swizzle | |||
| Stmt swizzle_loop(const For *op, const Stmt &s) { | |||
| auto min = op->min.as<IntImm>(); | |||
| loop_extent = SwizzleFinder::GetExtent(op); | |||
| if (min && loop_extent > 0) { | |||
| currentLoop = op; | |||
| auto body = Mutate(op->body); | |||
| // remove and unroll the loop | |||
| return For::make(op->loop_var, op->min, op->extent, ForType::Swizzled, op->device_api, body); | |||
| } | |||
| // something wrong happened during extent evaluation, we do not mutate | |||
| LOG(WARNING) << "Could not mutate loop (invalid loop extent)"; | |||
| ForType t = op->for_type; | |||
| if (loop_extent > 0 && t == ForType::Serial) t = ForType::Unrolled; | |||
| return For::make(op->loop_var, op->min, op->extent, t, op->device_api, op->body); | |||
| } | |||
| Stmt Mutate_(const Store *op, const Stmt &s) final { | |||
| if (swizzling) { | |||
| loop_extent = SwizzleFinder::GetExtent(currentLoop); | |||
| if (std::find(finder.single_stores.begin(), finder.single_stores.end(), op) != finder.single_stores.end()) { | |||
| // this store has only one load attached, we can swizzle in place | |||
| LOG(DEBUG) << "Swizzle in place " << op->buffer_var->name_hint; | |||
| // modify store value | |||
| Expr value; | |||
| auto buffer_var = op->buffer_var; | |||
| auto v = air::ir::Substitute(op->value, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}}); | |||
| v = Simplify_cce(v); | |||
| Array<Expr> value_args{make_const(Int(32), loop_extent), v}; | |||
| // check if variable is a temp var or an output var | |||
| if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) { | |||
| // (swizzle temp var) sw_var ... ldg | |||
| // check temp variable is declared | |||
| CHECK(replace_vars.find(op->buffer_var->name_hint) != replace_vars.end()); | |||
| buffer_var = replace_vars[op->buffer_var->name_hint]; | |||
| value = Call::make(op->value.type(), Call::ldg, value_args, Call::Intrinsic); | |||
| Stmt s2 = Store::make(buffer_var, value, op->index, op->predicate); | |||
| return AttrStmt::make(s, "simple_store", Expr(0), s2); | |||
| } else if (finder.temp_vars.count(op->buffer_var->name_hint) > 0) { | |||
| // (temp var) reinterpret cast ... ldg | |||
| value = Call::make(op->value.type(), Call::ldg, value_args, Call::Intrinsic); | |||
| } else { | |||
| // (output var) reinterpret cast ... reinterpret cast | |||
| value = Call::make(op->value.type(), Call::reinterpret_cast_op, value_args, Call::Intrinsic); | |||
| } | |||
| auto index = air::ir::Substitute(op->index, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}}); | |||
| index = Simplify_cce(index); | |||
| Stmt s2 = Store::make(buffer_var, value, index, op->predicate); | |||
| return AttrStmt::make(s, "reinterpret_store", Expr(loop_extent), s2); | |||
| } else { | |||
| // store with != 1 load | |||
| Expr value = Mutate(op->value); | |||
| LOG(DEBUG) << "check type: " << op->value.type() << " " << op->value; | |||
| CHECK(op->value.type().is_float() || op->value.type().is_int()); | |||
| CHECK_LE(op->value.type().bits(), 32); | |||
| if (std::find(finder.swizzle_stores.begin(), finder.swizzle_stores.end(), op) != finder.swizzle_stores.end()) { | |||
| LOG(DEBUG) << "Mutate Store : index contains loop var " << currentLoop->loop_var->name_hint << std::endl | |||
| << "Value :" << op->value; | |||
| Stmt new_stmt, end_stmt; | |||
| Expr index; | |||
| Array<Expr> value_args; | |||
| air::DataType t; | |||
| Var new_var; | |||
| if (replace_vars.find(op->buffer_var->name_hint) != replace_vars.end()) { | |||
| new_var = replace_vars[op->buffer_var->name_hint]; | |||
| t = new_var.type(); | |||
| } else { | |||
| if (op->value.type().is_int()) { | |||
| t = Int(op->value.type().bits(), loop_extent); | |||
| } else { | |||
| // Float(16, 4) -> half4 | |||
| // 1. Generate the right DataType -> half4 | |||
| t = Float(op->value.type().bits(), loop_extent); | |||
| } | |||
| new_var = Variable::make(t, "sw_" + op->buffer_var->name_hint); | |||
| } | |||
| LOG(DEBUG) << "(Vector store) replace previous buffer var : " << op->buffer_var | |||
| << " type : " << op->buffer_var->type << " with " << new_var << " type : " << new_var->type; | |||
| Expr new_value = Broadcast::make(value, loop_extent); | |||
| index = Ramp::make(0, 1, loop_extent); | |||
| Expr predicate = Broadcast::make(Expr(1), loop_extent); | |||
| new_stmt = Store::make(new_var, new_value, index, predicate); | |||
| new_stmt = AttrStmt::make(s, "vec_store", Expr(currentLoop->loop_var), new_stmt); | |||
| // do not reinterpret cast if var is in swizzle_temp_vars | |||
| if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) == 0) { | |||
| // last statement (store) : set value to initial value | |||
| index = air::ir::Substitute(op->index, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}}); | |||
| index = Simplify_cce(index); | |||
| value_args = {make_const(Int(32), 0), new_var}; | |||
| Expr reinterpret = Call::make(op->value.type(), Call::reinterpret_cast_op, value_args, Call::Intrinsic); | |||
| // replace with new_var to remove second reinterpret (produces type error) | |||
| end_stmt = Store::make(op->buffer_var, reinterpret, index, op->predicate); | |||
| end_stmt = AttrStmt::make(new_stmt, "reinterpret_store", Expr(loop_extent), end_stmt); | |||
| } | |||
| // declare load variables | |||
| LOG(DEBUG) << "Mutate Load : index contains loop var " << currentLoop->loop_var->name_hint | |||
| << ", Loop extent : " << loop_extent; | |||
| for (auto ld : finder.store_loads[op]) { | |||
| if (std::find(finder.temp_vars.begin(), finder.temp_vars.end(), ld->buffer_var->name_hint) == | |||
| finder.temp_vars.end()) { | |||
| // declare swizzle variable | |||
| Var new_var2; | |||
| air::DataType t2; | |||
| if (replace_vars.find(ld->buffer_var->name_hint) == replace_vars.end()) { | |||
| if (ld->buffer_var->name_hint == op->buffer_var->name_hint) { | |||
| // this variable is both stored AND loaded for the first time | |||
| t2 = t; | |||
| new_var2 = new_var; | |||
| } else { | |||
| if (ld->type.is_int()) { | |||
| t2 = Int(ld->type.bits(), loop_extent); | |||
| } else { | |||
| // Float(16, 4) -> half4 | |||
| // 1. Generate the right DataType -> half4 | |||
| t2 = Float(ld->type.bits(), loop_extent); | |||
| } | |||
| new_var2 = Variable::make(t2, "sw_" + ld->buffer_var->name_hint); | |||
| } | |||
| replace_vars[ld->buffer_var->name_hint] = new_var2; | |||
| declared.insert(ld->buffer_var->name_hint); | |||
| // 2. declaration half4 sw_T_where = __ldg( ... ) | |||
| index = air::ir::Substitute(ld->index, {{Var{currentLoop->loop_var}, make_const(Int(32), 0)}}); | |||
| index = Simplify_cce(index); | |||
| Expr expr_ld = Load::make(ld->type, ld->buffer_var, index, ld->predicate); | |||
| value_args = {make_const(Int(32), loop_extent), expr_ld}; | |||
| Expr ldg_input = Call::make(t2, Call::ldg, value_args, Call::Intrinsic); | |||
| LOG(DEBUG) << "(Vector load) replace previous buffer var : " << ld->buffer_var | |||
| << " type : " << ld->buffer_var->type << " with " << new_var2 | |||
| << " type : " << new_var2->type; | |||
| new_stmt = LetStmt::make(new_var2, ldg_input, new_stmt); | |||
| new_stmt = AttrStmt::make(s, "vec_load", ld->buffer_var, new_stmt); | |||
| } else { | |||
| new_stmt = AttrStmt::make(s, "vec_load", ld->buffer_var, new_stmt); | |||
| } | |||
| } | |||
| } | |||
| // declare store variable | |||
| if (std::find(declared.begin(), declared.end(), op->buffer_var->name_hint) == declared.end()) { | |||
| // declare swizzle variable | |||
| declared.insert(op->buffer_var->name_hint); | |||
| new_stmt = LetStmt::make(new_var, Broadcast::make(make_const(op->value.type(), 0), loop_extent), new_stmt); | |||
| new_stmt = AttrStmt::make(s, "no_init_value", Expr(0), new_stmt); | |||
| } | |||
| if (replace_vars.find(op->buffer_var->name_hint) == replace_vars.end()) { | |||
| replace_vars[op->buffer_var->name_hint] = new_var; | |||
| } | |||
| if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) == 0) { | |||
| Stmt new_block = Block::make(new_stmt, end_stmt); | |||
| return new_block; | |||
| } else { | |||
| return new_stmt; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return IRMutator::Mutate_(op, s); | |||
| } | |||
| Stmt Mutate_(const Allocate *op, const Stmt &s) final { | |||
| if (finder.swizzle_temp_vars.count(op->buffer_var->name_hint) > 0) { | |||
| // replace var with its swizzle counterpart | |||
| int size = finder.var_size[op->buffer_var->name_hint]; | |||
| air::DataType t; | |||
| if (op->type.is_int()) { | |||
| t = Int(op->type.bits(), size); | |||
| } else { | |||
| t = Float(op->type.bits(), size); | |||
| } | |||
| Var new_var = Variable::make(t, "sw_" + op->buffer_var->name_hint); | |||
| LOG(DEBUG) << "Allocate : replace previous buffer var : " << op->buffer_var << " type : " << op->type << " with " | |||
| << new_var << " type : " << new_var->type; | |||
| replace_vars[op->buffer_var->name_hint] = new_var; | |||
| declared.insert(op->buffer_var->name_hint); | |||
| Stmt body = Mutate(op->body); | |||
| Stmt new_stmt; | |||
| Stmt let_stmt = LetStmt::make(new_var, Broadcast::make(make_const(op->type, 0), size), body); | |||
| Stmt attr = AttrStmt::make(s, "no_init_value", Expr(0), let_stmt); | |||
| Stmt new_allocate = Allocate::make(op->buffer_var, op->type, op->extents, const_false(), attr); | |||
| return new_allocate; // Block::make(attr, attr1); | |||
| } | |||
| return IRMutator::Mutate_(op, s); | |||
| } | |||
| private: | |||
| // int nb_loops; | |||
| SwizzleFinder finder; | |||
| const For *currentLoop{}; | |||
| bool swizzling = false; | |||
| int loop_extent{}; | |||
| std::set<std::basic_string<char>> declared{}; | |||
| std::unordered_map<std::basic_string<char>, Var> replace_vars{}; | |||
| std::basic_string<char> kernel_name; | |||
| }; | |||
| static void ParseStringAttr(const Map<std::string, NodeRef> &attrs, const std::string &attr_name, | |||
| std::string *attr_to_set) { | |||
| CHECK(attr_to_set != nullptr); | |||
| if (attrs.count(attr_name) == 0) return; | |||
| const NodeRef &e = attrs.at(attr_name); | |||
| if (auto val = e.as<StringImm>()) { | |||
| *attr_to_set = val->value; | |||
| } else { | |||
| LOG(FATAL) << "Failed to parse attribute: " << attr_name << " = " << e << " as string"; | |||
| } | |||
| } | |||
| static void ParseIntAttr(const Map<std::string, NodeRef> &attrs, const std::string &attr_name, int *attr_to_set) { | |||
| CHECK(attr_to_set != nullptr); | |||
| if (attrs.count(attr_name) == 0) return; | |||
| const NodeRef &e = attrs.at(attr_name); | |||
| if (auto i = e.as<IntImm>()) { | |||
| *attr_to_set = static_cast<int>(i->value); | |||
| } else if (auto ui = e.as<UIntImm>()) { | |||
| *attr_to_set = static_cast<int>(ui->value); | |||
| } else { | |||
| LOG(FATAL) << "Failed to parse attribute: " << attr_name << " = " << e << " as integer"; | |||
| } | |||
| } | |||
| static void ParseBoolAttr(const Map<std::string, NodeRef> &attrs, const std::string &attr_name, bool *attr_to_set) { | |||
| const int invalid_value = -1; | |||
| int attr = invalid_value; | |||
| ParseIntAttr(attrs, attr_name, &attr); | |||
| if (attr != invalid_value) { | |||
| CHECK(attr == 0 || attr == 1) << "Bool attribute " << attr_name << " must be 0 or 1, but found " | |||
| << attrs.at(attr_name); | |||
| *attr_to_set = static_cast<bool>(attr); | |||
| } | |||
| } | |||
| Stmt SwizzleGPU(const Stmt &stmt, const Map<std::string, NodeRef> &attrs) { | |||
| bool disable_swizzle = false; | |||
| ParseBoolAttr(attrs, "disable_swizzle", &disable_swizzle); | |||
| if (const char *env_p = std::getenv("MS_AKG_DISABLE_SWIZZLE")) | |||
| if (!strcmp(env_p, "1")) disable_swizzle = true; | |||
| if (disable_swizzle) { | |||
| LOG(INFO) << "SwizzleGPU pass disabled"; | |||
| return stmt; | |||
| } | |||
| std::string kernel_name_; | |||
| ParseStringAttr(attrs, "kernel_name", &kernel_name_); | |||
| if (kernel_name_.empty()) | |||
| LOG(WARNING) << "Kernel name not found !"; | |||
| else | |||
| LOG(INFO) << "BEGIN_PASS SwizzleGPU on " << kernel_name_; | |||
| auto sw = Swizzle(kernel_name_); | |||
| Stmt s = sw.VisitAndMutate(stmt); | |||
| LOG(INFO) << "END_PASS"; | |||
| return s; | |||
| } | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -51,6 +51,7 @@ void DsaMgrStrategy::RegisterMemPromPasses() { RegisterPass(std::make_shared<Mem | |||
| void DsaMgrStrategy::RegisterPasses() { | |||
| passes_.clear(); | |||
| RegisterNormalizationPasses(); | |||
| RegisterConstrainedScheduling(); | |||
| if (!scop_info_.user_config_.GetDisableGroup()) { | |||
| RegisterPass(std::make_shared<GroupStatements>(pass_info_)); | |||
| } | |||
| @@ -36,7 +36,7 @@ void DumpSchTreeToFile(std::FILE *fp, const isl::schedule &sch) { | |||
| CHECK(sch.get()); | |||
| printer = isl_printer_to_file(isl_schedule_ctx(sch.get()), fp); | |||
| printer = isl_printer_to_file(isl_schedule_get_ctx(sch.get()), fp); | |||
| printer = isl_printer_set_yaml_style(printer, ISL_YAML_STYLE_BLOCK); | |||
| printer = isl_printer_print_schedule(printer, sch.get()); | |||
| @@ -438,6 +438,11 @@ void ScopInfo::DumpScopDataAdvanced(std::ofstream &of) { | |||
| } | |||
| void UserConfig::DumpScopDataScheduleAttrs(std::ofstream &of) { | |||
| if (constrained_scheduling_output_ != "") { | |||
| PrintHeader(of, "constrained scheduling"); | |||
| of << constrained_scheduling_output_ << std::endl; | |||
| } | |||
| PrintHeader(of, "schedule attrs"); | |||
| of << "dump_poly_dir : " << GetDumpPolyDir() << std::endl; | |||
| @@ -36,6 +36,7 @@ void GPUMgrStrategy::RegisterMemPromPasses() { | |||
| void GPUMgrStrategy::RegisterPasses() { | |||
| passes_.clear(); | |||
| RegisterNormalizationPasses(); | |||
| RegisterConstrainedScheduling(); | |||
| RegisterSchedulingPasses(); | |||
| RegisterPass(std::make_shared<GpuDmaAnalysis>(scop_info_)); | |||
| RegisterTilingPasses(); | |||
| @@ -0,0 +1,891 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "poly/isl_influence.h" | |||
| // ISL | |||
| #include <isl/aff.h> | |||
| #include <isl/mat.h> | |||
| #include <isl/hash.h> | |||
| #include <isl/schedule.h> | |||
| #include <isl/schedule_node.h> | |||
| #include <isl/constraint.h> | |||
| #include <isl/map_to_basic_set.h> | |||
| // ISL private headers | |||
| #include <isl_int.h> | |||
| #include <isl_tab.h> | |||
| extern "C" { | |||
| #include <isl_aff_private.h> | |||
| #include <isl_mat_private.h> | |||
| #include <isl_schedule_constraints.h> | |||
| } | |||
| #include "poly/log_util.h" | |||
// Local definition of the scheduler graph structure that is handed to the
// influence callbacks in this file.
// NOTE(review): this presumably has to match, field for field, the definition
// used inside the patched isl scheduler (isl-influence.patch) - confirm before
// reordering, adding or removing any member.
struct isl_sched_graph {
  isl_map_to_basic_set *intra_hmap;
  isl_map_to_basic_set *intra_hmap_param;
  isl_map_to_basic_set *inter_hmap;
  struct isl_sched_node *node;
  int n;  // number of nodes; used in this file as the bound when iterating over statements
  int maxvar;
  int max_row;
  int n_row;  // used in this file as the schedule dimension currently being computed
  int *sorted;
  int n_total_row;
  int band_start;
  struct isl_sched_graph *root;
  struct isl_sched_edge *edge;
  int n_edge;
  int max_edge[isl_edge_last + 1];
  struct isl_hash_table *edge_table[isl_edge_last + 1];
  struct isl_hash_table *node_table;
  struct isl_trivial_region *region;
  isl_basic_set *lp;
  int src_scc;
  int dst_scc;
  int scc;
  int weak;
  int max_weight;
  /* AKG isl_influence patch - start */
  // Hard coefficient constraints, read by akg_isl_influence_set_coef.
  akg::ir::poly::isl_influence_list *inf_list;
  // Equality ("soft") constraints, read by akg_isl_influence_set_equal.
  akg::ir::poly::isl_influence_equal_list *inf_equal_list;
  // Per-dimension solutions recorded by akg_isl_influence_sol_add_elem.
  akg::ir::poly::isl_influence_sol_list *inf_sol_list;
  /* AKG isl_influence patch - end */
};
| /////////////////////////////////////////////////////////////////////////// | |||
| // CODE | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| static inline void akg_isl_influence_set_function_pointers(void) { | |||
| isl_influence_set_coef = akg_isl_influence_set_coef; | |||
| isl_influence_set_equal = akg_isl_influence_set_equal; | |||
| isl_influence_maxvar = akg_isl_influence_maxvar; | |||
| isl_influence_check_coincident = akg_isl_influence_check_coincident; | |||
| isl_influence_sol_list_free = akg_isl_influence_sol_list_free; | |||
| isl_influence_sol_add_elem = akg_isl_influence_sol_add_elem; | |||
| isl_influence_sol_get_elem = akg_isl_influence_sol_get_elem; | |||
| } | |||
| static inline void akg_isl_influence_unset_function_pointers(void) { | |||
| isl_influence_set_coef = 0; | |||
| isl_influence_set_equal = 0; | |||
| isl_influence_maxvar = 0; | |||
| isl_influence_check_coincident = 0; | |||
| isl_influence_sol_list_free = 0; | |||
| isl_influence_sol_add_elem = 0; | |||
| isl_influence_sol_get_elem = 0; | |||
| } | |||
| void akg_isl_influence_enable(isl_ctx *ctx) { | |||
| isl_options_set_akg_print_debug(ctx, 1); | |||
| isl_options_set_akg_influence_scheduler(ctx, 1); | |||
| akg_isl_influence_set_function_pointers(); | |||
| } | |||
| void akg_isl_influence_disable(isl_ctx *ctx) { | |||
| isl_options_set_akg_print_debug(ctx, 0); | |||
| isl_options_set_akg_influence_scheduler(ctx, 0); | |||
| akg_isl_influence_unset_function_pointers(); | |||
| } | |||
// Printable name of each influence coefficient type, indexed by the enumerator
// value (callers in this file index it with the constraint's `type` field).
// NOTE: array designators ([index] = value) are not standard C++; this relies
// on a compiler extension.
static const char *inf_types_str[] = {
  [isl_cst] = "isl_cst",
  [isl_param] = "isl_param",
  [isl_var] = "isl_var",
};
| struct isl_sched_graph *akg_isl_influence_sol_list_free(isl_sched_graph *graph) { | |||
| log::Info(log::Verbosity::high, "Entering isl_influence_sol_list_free"); | |||
| isl_influence_sol_list *inf = graph->inf_sol_list; | |||
| if (inf) { | |||
| log::Info(log::Verbosity::high, "inf-mem: " + std::to_string(inf->mem)); | |||
| for (int i = 0; i < inf->mem; i++) { | |||
| isl_influence_sol *p = &inf->data[i]; | |||
| p->mem = 0; | |||
| p->size = 0; | |||
| free(p->data); | |||
| p->data = NULL; | |||
| } | |||
| free(inf->data); | |||
| inf->data = NULL; | |||
| inf->size = 0; | |||
| inf->mem = 0; | |||
| free(inf); | |||
| } | |||
| graph->inf_sol_list = NULL; | |||
| log::Info(log::Verbosity::high, "Leaving isl_influence_sol_list_free"); | |||
| return graph; | |||
| } | |||
| struct isl_sched_graph *akg_isl_influence_sol_add_elem(isl_vec *sol, struct isl_sched_graph *graph) { | |||
| log::Info(log::Verbosity::high, "Entering isl_influence_sol_add_elem"); | |||
| if (!sol) { | |||
| return graph; | |||
| } | |||
| isl_influence_sol_list *inf = graph->inf_sol_list; | |||
| if (!inf) { | |||
| inf = (isl_influence_sol_list *)calloc(1, sizeof(isl_influence_sol_list)); | |||
| if (!inf) { | |||
| log::Info(log::Verbosity::high, | |||
| "MEM ERROR: isl_influence_sol_add_elem could not allocate memory for isl_influence_sol_list"); | |||
| return graph; | |||
| } | |||
| inf->mem = 0; | |||
| inf->size = 0; | |||
| inf->data = NULL; | |||
| } | |||
| if (inf->mem == 0) { | |||
| inf->mem = INF_BLOCK_SIZE; | |||
| inf->data = (isl_influence_sol *)calloc(inf->mem, sizeof(isl_influence_sol)); | |||
| if (inf->data == NULL) { | |||
| log::Info(log::Verbosity::high, | |||
| "MEM ERROR: isl_influence_sol_add_elem could not allocate memory for isl_influence_sol_list"); | |||
| return graph; | |||
| } | |||
| } | |||
| if (inf->mem == inf->size) { | |||
| inf->mem += INF_BLOCK_SIZE; | |||
| inf->data = (isl_influence_sol *)realloc(inf->data, inf->mem * sizeof(isl_influence_sol)); | |||
| if (inf->data == NULL) { | |||
| log::Info(log::Verbosity::high, | |||
| "MEM ERROR: isl_influence_sol_add_elem could not reallocate memory isl_influence_sol"); | |||
| return graph; | |||
| } | |||
| } | |||
| isl_influence_sol *p = &inf->data[inf->size]; | |||
| const int sol_size = isl_inf_vec_get_size(sol); | |||
| p->data = (int *)calloc((size_t)sol_size, sizeof(int)); | |||
| if (p->data == NULL) { | |||
| log::Info(log::Verbosity::high, "MEM ERROR: isl_influence_sol_add_elem could not allocate memory"); | |||
| return graph; | |||
| } | |||
| p->mem = sol_size; | |||
| isl_ctx *ctx = isl_vec_get_ctx(sol); | |||
| for (int i = 0; i < p->mem; i++) { | |||
| isl_printer *printer = isl_printer_to_str(ctx); | |||
| isl_val *const val = isl_influence_vec_get_elem(sol, i); | |||
| printer = isl_printer_print_val(printer, val); | |||
| char *const str = isl_printer_get_str(printer); | |||
| const int value = atoi(isl_printer_get_str(printer)); | |||
| p->data[i] = value; | |||
| p->size++; | |||
| free(str); | |||
| isl_printer_free(printer); | |||
| isl_val_free(val); | |||
| } | |||
| inf->size++; | |||
| log::Info(log::Verbosity::high, "inf->mem: " + std::to_string(inf->mem)); | |||
| log::Info(log::Verbosity::high, "inf->size: " + std::to_string(inf->size)); | |||
| log::Info(log::Verbosity::high, "Leaving isl_influence_sol_add_elem"); | |||
| graph->inf_sol_list = inf; | |||
| return graph; | |||
| } | |||
| int akg_isl_influence_sol_get_elem(int dim, int pos, struct isl_sched_graph *graph) { | |||
| isl_influence_sol_list *inf = graph->inf_sol_list; | |||
| if (NULL == inf) { | |||
| return -1; | |||
| } | |||
| isl_influence_sol *p = NULL; | |||
| int retval; | |||
| if (dim < inf->size) { | |||
| p = &inf->data[dim]; | |||
| } else { | |||
| std::string message = "ERROR: isl_influence_sol_get_elem index out of range isl_influence_sol_list size : "; | |||
| message += std::to_string(inf->size) + " < dim: " + std::to_string(dim); | |||
| log::Info(log::Verbosity::high, message); | |||
| retval = -1; | |||
| } | |||
| if (1 + pos < p->size) { | |||
| retval = p->data[1 + pos]; | |||
| } else { | |||
| std::string message = "ERROR: isl_influence_sol_get_elem index out of range isl_ifnluence_sol size: "; | |||
| message += std::to_string(p->size) + " < pos: " + std::to_string(pos); | |||
| retval = -1; | |||
| } | |||
| return retval; | |||
| } | |||
| void print_basic_set(isl_basic_set *set, const char *str) { | |||
| isl_ctx *const ctx = isl_basic_set_get_ctx(set); | |||
| const bool print_debug = isl_options_get_akg_print_debug(ctx); | |||
| if (print_debug) { | |||
| isl_printer *printer = isl_printer_to_str(ctx); | |||
| printer = isl_printer_print_basic_set(printer, set); | |||
| char *const printed_str = isl_printer_get_str(printer); | |||
| log::Info(log::Verbosity::high, str); | |||
| log::Info(log::Verbosity::high, printed_str); | |||
| isl_printer_free(printer); | |||
| free(printed_str); | |||
| } | |||
| } | |||
| void *isl_influence_equal_list_free(isl_influence_equal_list *inf_equal_list) { | |||
| if (inf_equal_list->data) { | |||
| for (int i = 0; i < inf_equal_list->size; ++i) { | |||
| free(inf_equal_list->data[i].statement1); | |||
| free(inf_equal_list->data[i].statement2); | |||
| } | |||
| free(inf_equal_list->data); | |||
| inf_equal_list->data = NULL; | |||
| inf_equal_list->mem = 0; | |||
| inf_equal_list->size = 0; | |||
| } | |||
| free(inf_equal_list); | |||
| return NULL; | |||
| } | |||
| void *isl_influence_list_free(isl_influence_list *inf_list) { | |||
| if (inf_list->data) { | |||
| for (int i = 0; i < inf_list->size; ++i) { | |||
| free(inf_list->data[i].statement_name); | |||
| } | |||
| free(inf_list->data); | |||
| inf_list->data = NULL; | |||
| inf_list->mem = 0; | |||
| inf_list->size = 0; | |||
| } | |||
| free(inf_list); | |||
| return NULL; | |||
| } | |||
| isl_basic_set *hack_coefficients(isl_basic_set *coef, const char *msg, int pos, int lb, int ub) { | |||
| isl_ctx *const ctx = isl_basic_set_get_ctx(coef); | |||
| isl_val *const v_ub = isl_val_int_from_si(ctx, ub); | |||
| isl_val *const v_lb = isl_val_int_from_si(ctx, lb); | |||
| isl_size dim = isl_basic_set_n_dim(coef); | |||
| log::Info(log::Verbosity::high, "pos: " + std::to_string(pos)); | |||
| if (pos < (int)dim) { | |||
| const bool influence_schedule = isl_options_get_akg_influence_scheduler(ctx); | |||
| if (influence_schedule) { | |||
| coef = isl_basic_set_upper_bound_val(coef, isl_dim_set, pos, isl_val_copy(v_ub)); | |||
| coef = isl_basic_set_lower_bound_val(coef, isl_dim_set, pos, isl_val_copy(v_lb)); | |||
| std::string message = " -> i" + std::to_string(pos); | |||
| message += " = (" + std::to_string(lb); | |||
| message += ", " + std::to_string(ub) + ")"; | |||
| log::Info(log::Verbosity::high, message); | |||
| } | |||
| } | |||
| isl_val_free(v_ub); | |||
| isl_val_free(v_lb); | |||
| return coef; | |||
| } | |||
| isl_basic_set *create_constraint(isl_basic_set *coef, const char *msg, int pos1, int pos2) { | |||
| isl_ctx *ctx = isl_basic_set_get_ctx(coef); | |||
| const bool influence_schedule = isl_options_get_akg_influence_scheduler(ctx); | |||
| if (influence_schedule) { | |||
| isl_local_space *ls = isl_basic_set_get_local_space(coef); | |||
| isl_constraint *c = isl_constraint_alloc_equality(isl_local_space_copy(ls)); | |||
| c = isl_constraint_set_coefficient_si(c, isl_dim_set, pos1, 1); | |||
| c = isl_constraint_set_coefficient_si(c, isl_dim_set, pos2, -1); | |||
| coef = isl_basic_set_add_constraint(coef, c); | |||
| print_basic_set(coef, msg); | |||
| log::Info(log::Verbosity::high, "--> i" + std::to_string(pos1) + " = i" + std::to_string(pos2)); | |||
| isl_local_space_free(ls); | |||
| } | |||
| return coef; | |||
| } | |||
| isl_basic_set *graph_find_basic_set_by_statement_name(struct isl_sched_graph *graph, const char *name, | |||
| int *node_index) { | |||
| isl_basic_set *bset = NULL; | |||
| for (int i = 0; i < graph->n && bset == NULL; ++i) { | |||
| struct isl_sched_node *node = isl_sched_graph_get_node(graph, i); | |||
| isl_map *ma = isl_sched_node_extract_schedule(node); | |||
| const char *strstat = isl_map_get_tuple_name(ma, isl_dim_in); | |||
| if (strcmp(strstat, name) == 0) { | |||
| bset = graph->lp; | |||
| *node_index = i; | |||
| } | |||
| isl_map_free(ma); | |||
| } | |||
| return bset; | |||
| } | |||
| int get_pos_from_bset(isl_basic_set *bset, struct isl_sched_node *node, isl_influence_equal *inf, int coef_dim) { | |||
| int pos = -1; | |||
| switch (inf->type) { | |||
| case isl_cst: | |||
| pos = isl_sched_node_cst_coef_offset(node); | |||
| break; | |||
| case isl_param: | |||
| pos = isl_sched_node_par_coef_offset(node) - coef_dim; | |||
| break; | |||
| case isl_var: | |||
| pos = isl_sched_node_par_coef_offset(node) - isl_sched_node_get_nparam(node) - coef_dim * 2 - 1; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| log::Info(log::Verbosity::high, | |||
| "coefficient position for coef_dim=" + std::to_string(coef_dim) + ": " + std::to_string(pos)); | |||
| return pos; | |||
| } | |||
| void print_influence(isl_influence_equal_list *equal, isl_influence_list *coef) { | |||
| log::Info(log::Verbosity::high, "start printing isl_influence hard constraints"); | |||
| for (int i = 0; coef && i < coef->size; i++) { | |||
| isl_influence *inf = &coef->data[i]; | |||
| log::Info(log::Verbosity::high, | |||
| "hard constraint\t\t[+" + std::to_string(i + 1) + ":" + std::to_string(coef->size) + "]"); | |||
| log::Info(log::Verbosity::high, "statement:\t\t" + std::string(inf->statement_name)); | |||
| log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf->type])); | |||
| log::Info(log::Verbosity::high, "sched_dim:\t\t" + std::to_string(inf->sched_dim)); | |||
| log::Info(log::Verbosity::high, "coef_dim:\t\t" + std::to_string(inf->coef_dim)); | |||
| log::Info(log::Verbosity::high, "val:\t\t\t" + std::to_string(inf->val)); | |||
| } | |||
| log::Info(log::Verbosity::high, "End printing isl_influence hard constraints"); | |||
| log::Info(log::Verbosity::high, "Start printing isl_influence soft constraints"); | |||
| for (int i = 0; equal && i < equal->size; i++) { | |||
| isl_influence_equal *inf_equal = &equal->data[i]; | |||
| log::Info(log::Verbosity::high, | |||
| "soft constraint\t\t[" + std::to_string(i + 1) + ":" + std::to_string(equal->size) + "]"); | |||
| log::Info(log::Verbosity::high, "statement1:\t\t" + std::string(inf_equal->statement1)); | |||
| log::Info(log::Verbosity::high, "statement2:\t\t" + std::string(inf_equal->statement2)); | |||
| log::Info(log::Verbosity::high, "sched_dim1:\t\t" + std::to_string(inf_equal->sched_dim1)); | |||
| log::Info(log::Verbosity::high, "sched_dim2:\t\t" + std::to_string(inf_equal->sched_dim2)); | |||
| log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf_equal->type])); | |||
| log::Info(log::Verbosity::high, "coef_dim1:\t\t" + std::to_string(inf_equal->coef_dim1)); | |||
| log::Info(log::Verbosity::high, "coef_dim2:\t\t" + std::to_string(inf_equal->coef_dim2)); | |||
| } | |||
| log::Info(log::Verbosity::high, "End printing isl_influence soft constraints"); | |||
| } | |||
| int report_influence(isl_influence_equal_list *equal, isl_influence_list *coef, isl_influence_sol_list *sol, | |||
| int maxvar) { | |||
| int bad_equal = 0; | |||
| int bad_coef = 0; | |||
| int result = 1; | |||
| for (int i = 0; i < coef->size; i++) { | |||
| isl_influence *inf = &coef->data[i]; | |||
| if (inf->scanned != 1) { | |||
| log::Info(log::Verbosity::high, | |||
| "warning: influence hard constraint [" + std::to_string(i) + "] was not processed"); | |||
| log::Info(log::Verbosity::high, "statement:\t\t" + std::string(inf->statement_name)); | |||
| log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf->type])); | |||
| log::Info(log::Verbosity::high, "sched_dim:\t\t" + std::to_string(inf->sched_dim)); | |||
| log::Info(log::Verbosity::high, "coef_dim:\t\t" + std::to_string(inf->coef_dim)); | |||
| log::Info(log::Verbosity::high, "val:\t\t\t" + std::to_string(inf->val)); | |||
| bad_coef++; | |||
| result = 0; | |||
| } | |||
| } | |||
| for (int i = 0; i < equal->size; i++) { | |||
| isl_influence_equal *inf_equal = &equal->data[i]; | |||
| if (inf_equal->scanned != 1) { | |||
| log::Info(log::Verbosity::high, | |||
| "warning: influence soft constraint [" + std::to_string(i) + "] was not processed"); | |||
| log::Info(log::Verbosity::high, "statement1:\t\t" + std::string(inf_equal->statement1)); | |||
| log::Info(log::Verbosity::high, "statement2:\t\t" + std::string(inf_equal->statement2)); | |||
| log::Info(log::Verbosity::high, "sched_dim1:\t\t" + std::to_string(inf_equal->sched_dim1)); | |||
| log::Info(log::Verbosity::high, "sched_dim2:\t\t" + std::to_string(inf_equal->sched_dim2)); | |||
| log::Info(log::Verbosity::high, "type:\t\t\t" + std::string(inf_types_str[(int)inf_equal->type])); | |||
| log::Info(log::Verbosity::high, "coef_dim1:\t\t" + std::to_string(inf_equal->coef_dim1)); | |||
| log::Info(log::Verbosity::high, "coef_dim2:\t\t" + std::to_string(inf_equal->coef_dim2)); | |||
| bad_equal++; | |||
| result = 0; | |||
| } | |||
| } | |||
| if (bad_coef == 0) | |||
| log::Info(log::Verbosity::high, std::to_string(coef->size) + "influence hard coef constraints processed correctly"); | |||
| else | |||
| log::Info(log::Verbosity::high, std::to_string(bad_coef) + "influence hard coef constraints were not processed"); | |||
| if (bad_equal == 0) | |||
| log::Info(log::Verbosity::high, std::to_string(equal->size) + "influence equal constraints processed correctly"); | |||
| else | |||
| log::Info(log::Verbosity::high, std::to_string(bad_equal) + "influence equal constraints were not processed"); | |||
| if (!sol || sol->size != maxvar) { | |||
| log::Info(log::Verbosity::high, "isl influence could not find solution for all dimensions"); | |||
| result = 0; | |||
| } | |||
| return result; | |||
| } | |||
| int set_params(isl_influence_equal *inf, int *sched_from, int *sched_to, int *coef_from, int *coef_to, int actual_dim) { | |||
| int retval = 0; | |||
| if (actual_dim == 0 && inf->sched_dim1 != inf->sched_dim2) { | |||
| return retval; | |||
| } else if (actual_dim == inf->sched_dim1) { | |||
| if (inf->sched_dim1 >= inf->sched_dim2) { | |||
| *sched_from = inf->sched_dim2; | |||
| *sched_to = inf->sched_dim1; | |||
| *coef_from = inf->coef_dim2; | |||
| *coef_to = inf->coef_dim1; | |||
| retval = 1; | |||
| } else { | |||
| log::Info(log::Verbosity::high, | |||
| "cannot set future coef for dimension: " + std::to_string(actual_dim) + " and inf_equal:"); | |||
| log::Info(log::Verbosity::high, "inf->sched_dim1: " + std::to_string(inf->sched_dim1)); | |||
| log::Info(log::Verbosity::high, "inf->sched_dim2: " + std::to_string(inf->sched_dim2)); | |||
| } | |||
| } else if (actual_dim == inf->sched_dim2) { | |||
| if (inf->sched_dim2 >= inf->sched_dim1) { | |||
| *sched_from = inf->sched_dim1; | |||
| *sched_to = inf->sched_dim2; | |||
| *coef_from = inf->coef_dim1; | |||
| *coef_to = inf->coef_dim2; | |||
| retval = 1; | |||
| } else { | |||
| log::Info(log::Verbosity::high, | |||
| "cannot set future coef for dimension: " + std::to_string(actual_dim) + " and inf_equal:"); | |||
| log::Info(log::Verbosity::high, "inf->sched_dim1: " + std::to_string(inf->sched_dim1)); | |||
| log::Info(log::Verbosity::high, "inf->sched_dim2: " + std::to_string(inf->sched_dim2)); | |||
| } | |||
| } | |||
| return retval; | |||
| } | |||
| isl_basic_set *akg_isl_influence_set_equal(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) { | |||
| // loop over iinfluence list equal | |||
| log::Info(log::Verbosity::high, "Enter isl_influence_set_equal for dimension --> " + std::to_string(graph->n_row)); | |||
| isl_influence_equal_list *inf_list = graph->inf_equal_list; | |||
| for (int i = 0; i < inf_list->size; i++) { | |||
| isl_influence_equal *inf_equal = &inf_list->data[i]; | |||
| if (inf_equal->sched_dim1 > inf_equal->sched_dim2) { | |||
| std::string message0 = "ERROR: isl cannot compute soft constraint"; | |||
| message0 += "[" + std::to_string(i) + " from dimension " + std::to_string(inf_equal->sched_dim2) + | |||
| " to dimension " + std::to_string(inf_equal->sched_dim1); | |||
| std::string message1 = "Reason: destination coefficient unknown"; | |||
| log::Info(log::Verbosity::high, message0); | |||
| log::Info(log::Verbosity::high, message1); | |||
| } | |||
| int sched_from; | |||
| int sched_to; | |||
| int coef_from; | |||
| int coef_to; | |||
| if (!set_params(inf_equal, &sched_from, &sched_to, &coef_from, &coef_to, graph->n_row)) continue; | |||
| isl_basic_set *bset_from; | |||
| isl_basic_set *bset_to; | |||
| int node_index_from = -1; | |||
| int node_index_to = -1; | |||
| bset_from = graph_find_basic_set_by_statement_name(graph, inf_equal->statement1, &node_index_from); | |||
| bset_to = graph_find_basic_set_by_statement_name(graph, inf_equal->statement2, &node_index_to); | |||
| if (bset_from != NULL && bset_to != NULL && node_index_from != -1 && node_index_to != -1) { | |||
| int pos_from; | |||
| int pos_to; | |||
| log::Info(log::Verbosity::high, "scanning equal constraint for influence equal constraint[" + | |||
| std::to_string(i + 1) + ":" + std::to_string(inf_list->size) + "]"); | |||
| log::Info(log::Verbosity::high, "statement from:\t" + std::string(inf_equal->statement2)); | |||
| log::Info(log::Verbosity::high, "statement to:\t" + std::string(inf_equal->statement1)); | |||
| log::Info(log::Verbosity::high, "sched_dim from:\t" + std::to_string(sched_from)); | |||
| log::Info(log::Verbosity::high, "sched_dim to:\t" + std::to_string(sched_to)); | |||
| log::Info(log::Verbosity::high, "coef_dim from\t" + std::to_string(coef_from)); | |||
| log::Info(log::Verbosity::high, "coef_dim to\t" + std::to_string(coef_to)); | |||
| log::Info(log::Verbosity::high, "type:\t\t" + std::string(inf_types_str[(int)(inf_equal->type)])); | |||
| log::Info(log::Verbosity::high, "isl_influence_set_equal: copying coef from " + | |||
| std::string(inf_equal->statement2) + " to " + | |||
| std::string(inf_equal->statement1)); | |||
| print_basic_set(bset_from, "bset from:"); | |||
| print_basic_set(bset_to, "bset to:"); | |||
| log::Info(log::Verbosity::high, "node from:\t" + std::to_string(node_index_from)); | |||
| log::Info(log::Verbosity::high, "node to:\t" + std::to_string(node_index_to)); | |||
| pos_from = get_pos_from_bset(bset_from, isl_sched_graph_get_node(graph, node_index_from), inf_equal, coef_from); | |||
| pos_to = get_pos_from_bset(bset_to, isl_sched_graph_get_node(graph, node_index_to), inf_equal, coef_to); | |||
| if (inf_equal->sched_dim2 == inf_equal->sched_dim1) { | |||
| bset = create_constraint(bset, "constraint created", pos_from, pos_to); | |||
| if (inf_equal->type == isl_var) bset = create_constraint(bset, "constraint created", pos_from - 1, pos_to - 1); | |||
| } else { | |||
| int val = isl_influence_sol_get_elem(sched_from, pos_from, graph); | |||
| log::Info(log::Verbosity::high, "val=" + std::to_string(val)); | |||
| bset = hack_coefficients(bset, "isl_equal", pos_to, val, val); | |||
| if (inf_equal->type == isl_var) { | |||
| val = isl_influence_sol_get_elem(sched_from, pos_from - 1, graph); | |||
| bset = hack_coefficients(bset, "is_equal", pos_to - 1, val, val); | |||
| } | |||
| } | |||
| inf_equal->scanned = 1; | |||
| } | |||
| } | |||
| log::Info(log::Verbosity::high, "Leave isl_influence_set_equal"); | |||
| return bset; | |||
| } | |||
| __isl_give isl_schedule *akg_isl_schedule_constraints_compute_schedule_influence( | |||
| __isl_take isl_schedule_constraints *sc, isl_influence_list *inf_coef, isl_influence_equal_list *inf_equal) { | |||
| isl_ctx *ctx = isl_schedule_constraints_get_ctx(sc); | |||
| struct isl_sched_graph graph = {0}; | |||
| isl_schedule *sched; | |||
| isl_schedule_node *node; | |||
| isl_union_set *domain; | |||
| isl_size n; | |||
| log::Info(log::Verbosity::high, "isl_schedule_constraints_compute_schedule : start printing constraints"); | |||
| isl_printer *p; | |||
| p = isl_printer_to_str(ctx); | |||
| p = isl_printer_set_yaml_style(p, ISL_YAML_STYLE_BLOCK); | |||
| p = isl_printer_print_schedule_constraints(p, sc); | |||
| char *log_string = isl_printer_get_str(p); | |||
| log::Info(log::Verbosity::high, std::string(log_string)); | |||
| isl_printer_free(p); | |||
| free(log_string); | |||
| log::Info(log::Verbosity::high, "isl_schedule_constraints_compute_schedule : end printing constraints"); | |||
| print_influence(inf_equal, inf_coef); | |||
| graph.inf_list = inf_coef; | |||
| graph.inf_equal_list = inf_equal; | |||
| sc = isl_schedule_constraints_align_params(sc); | |||
| domain = isl_schedule_constraints_get_domain(sc); | |||
| n = isl_union_set_n_set(domain); | |||
| if (n == 0) { | |||
| isl_schedule_constraints_free(sc); | |||
| return isl_schedule_from_domain(domain); | |||
| } | |||
| if (n < 0 || isl_sched_graph_init(&graph, sc) < 0) { | |||
| domain = isl_union_set_free(domain); | |||
| } | |||
| node = isl_schedule_node_from_domain(domain); | |||
| node = isl_schedule_node_child(node, 0); | |||
| if (graph.n > 0) { | |||
| node = isl_schedule_node_compute_schedule(node, &graph); | |||
| } | |||
| sched = isl_schedule_node_get_schedule(node); | |||
| int result = report_influence(inf_equal, inf_coef, graph.inf_sol_list, graph.maxvar); | |||
| isl_schedule_node_free(node); | |||
| isl_sched_graph_free(ctx, &graph); | |||
| isl_schedule_constraints_free(sc); | |||
| if (!result) { | |||
| log::Info(log::Verbosity::high, "isl_influence failed, will fallback to default isl"); | |||
| isl_schedule_free(sched); | |||
| sched = NULL; | |||
| } | |||
| return sched; | |||
| } | |||
| isl_basic_set *akg_isl_influence_set_coef(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) { | |||
| log::Info(log::Verbosity::high, "Enter isl_influence_set_coef for dimension --> " + std::to_string(graph->n_row)); | |||
| isl_influence_list *inf_list = graph->inf_list; | |||
| int dimension = graph->n_row; | |||
| for (int i = 0; i < graph->n; ++i) { | |||
| struct isl_sched_node *node = isl_sched_graph_get_node(graph, i); | |||
| isl_map *ma = isl_sched_node_extract_schedule(node); | |||
| log::Info(log::Verbosity::high, "statement:"); | |||
| if (ma != NULL) { | |||
| isl_printer *p; | |||
| p = isl_printer_to_str(ctx); | |||
| p = isl_printer_print_map(p, ma); | |||
| char *log_str = isl_printer_get_str(p); | |||
| log::Info(log::Verbosity::high, std::string(log_str)); | |||
| isl_printer_free(p); | |||
| free(log_str); | |||
| } | |||
| log::Info(log::Verbosity::high, "end statement"); | |||
| const char *strstat = isl_map_get_tuple_name(ma, isl_dim_in); | |||
| isl_map_free(ma); | |||
| for (int j = 0; j < inf_list->size; j++) { | |||
| isl_influence *inf = &inf_list->data[j]; | |||
| if (inf->sched_dim == dimension && strcmp(strstat, inf->statement_name) == 0) { | |||
| int pos; | |||
| int ub = inf->val; | |||
| int lb = inf->val; | |||
| log::Info(log::Verbosity::high, "scanning isl coefficients for influence[" + std::to_string(j + 1) + ":" + | |||
| std::to_string(inf_list->size) + "]:"); | |||
| // S_0,S_1,...,S_n-1,S_n | |||
| log::Info(log::Verbosity::high, "statement_name:\t\t" + std::string(inf->statement_name)); | |||
| // i,j,k.... (only apply for type=isl_var_plus | isl_var_minus | |||
| log::Info(log::Verbosity::high, "sched dim:\t\t" + std::to_string(inf->sched_dim)); | |||
| // coefficient index | |||
| log::Info(log::Verbosity::high, "rank variable:\t\t" + std::to_string(inf->coef_dim)); | |||
| int nparam = isl_sched_node_get_nparam(node); | |||
| int nvar = isl_sched_node_get_nvar(node); | |||
| log::Info(log::Verbosity::high, "statement variables:\t" + std::to_string(nvar)); | |||
| // coefficient value | |||
| log::Info(log::Verbosity::high, "coefficient value:\t" + std::to_string(inf->val)); | |||
| log::Info(log::Verbosity::high, "type:\t\t" + std::string(inf_types_str[(int)(inf->type)])); | |||
| print_basic_set(bset, "lp problem to influence:"); | |||
| switch (inf->type) { | |||
| case isl_cst: | |||
| pos = isl_sched_node_cst_coef_offset(node); | |||
| bset = hack_coefficients(bset, "isl_cst", pos, ub, lb); | |||
| inf->scanned = 1; | |||
| break; | |||
| case isl_param: | |||
| log::Info(log::Verbosity::high, "Statement param (node->nparam): " + std::to_string(nparam)); | |||
| pos = isl_sched_node_cst_coef_offset(node) - (nparam - inf->coef_dim); | |||
| bset = hack_coefficients(bset, "isl_param", pos, ub, lb); | |||
| inf->scanned = 1; | |||
| break; | |||
| case isl_var: | |||
| // dim_+ coefficient ; | |||
| if (inf->coef_dim <= nvar) { | |||
| pos = isl_sched_node_cst_coef_offset(node) - nparam - 1 - inf->coef_dim * 2; | |||
| if (inf->val > 0) { | |||
| bset = hack_coefficients(bset, "isl_var_+", pos, ub, lb); | |||
| // dim_- coeffiecient | |||
| pos--; | |||
| bset = hack_coefficients(bset, "isl_var_-", pos, 0, 0); | |||
| inf->scanned = 1; | |||
| } else if (inf->val == 0) { | |||
| // dim_+ coefficient | |||
| bset = hack_coefficients(bset, "isl_var_+", pos, 0, 0); | |||
| // dim_- coeffiecient | |||
| pos--; | |||
| bset = hack_coefficients(bset, "isl_var_-", pos, 0, 0); | |||
| inf->scanned = 1; | |||
| } else if (inf->val < 0) { | |||
| bset = hack_coefficients(bset, "isl_var_+", pos, 0, 0); | |||
| // dim_- coeffiecient | |||
| pos--; | |||
| bset = hack_coefficients(bset, "isl_var_-", pos, -ub, -lb); | |||
| inf->scanned = 1; | |||
| } else { | |||
| log::Info(log::Verbosity::high, "invalid inf->val: " + std::to_string(inf->val)); | |||
| } | |||
| } else { | |||
| log::Info(log::Verbosity::high, | |||
| "Warning: dimension overflow --> dimension required: " + std::to_string(inf->coef_dim) + | |||
| " max dimensions: " + std::to_string(nvar)); | |||
| } | |||
| break; | |||
| default: | |||
| log::Info(log::Verbosity::high, "unknown influence coef type"); | |||
| break; | |||
| } | |||
| print_basic_set(bset, "lp influenced problem:"); | |||
| } | |||
| } | |||
| } | |||
| log::Info(log::Verbosity::high, "Leave isl_influence_set_coef"); | |||
| return bset; | |||
| } | |||
| // Returns the largest (schedule dimension + 1) referenced by any influence constraint | |||
| // attached to `graph`, or 0 when neither list holds an entry. | |||
| int akg_isl_influence_maxvar(struct isl_sched_graph *graph) { | |||
|   log::Info(log::Verbosity::high, "Entering akg_isl_influence_maxar"); | |||
|   const int initial = 0; | |||
|   int result = initial; | |||
|   // Coefficient constraints: each entry names a single schedule dimension. | |||
|   if (const isl_influence_list *coefs = graph->inf_list) { | |||
|     for (int idx = 0; idx < coefs->size; ++idx) { | |||
|       const int needed = coefs->data[idx].sched_dim + 1; | |||
|       result = (result < needed) ? needed : result; | |||
|     } | |||
|   } | |||
|   // Equality constraints: each entry names two schedule dimensions. | |||
|   if (const isl_influence_equal_list *equals = graph->inf_equal_list) { | |||
|     for (int idx = 0; idx < equals->size; ++idx) { | |||
|       const isl_influence_equal &entry = equals->data[idx]; | |||
|       const int needed_first = entry.sched_dim1 + 1; | |||
|       result = (result < needed_first) ? needed_first : result; | |||
|       const int needed_second = entry.sched_dim2 + 1; | |||
|       result = (result < needed_second) ? needed_second : result; | |||
|     } | |||
|   } | |||
|   log::Info(log::Verbosity::high, "isl_influence_maxvar : " + std::to_string(result) + | |||
|                                     " (previous maxvar: " + std::to_string(initial) + ")"); | |||
|   log::Info(log::Verbosity::high, "Leaving akg_isl_influence_maxvar"); | |||
|   return result; | |||
| } | |||
| // Checks whether the schedule row stored in `sol` is identical for every statement of `graph`. | |||
| // | |||
| // For each pair of nodes the constant coefficient, the parameter coefficients (only when both | |||
| // nodes have the same non-zero number of parameters) and the variable coefficients are compared | |||
| // with isl_influence_int_eq(). Two nodes with a different number of variables never match. | |||
| // | |||
| // Returns 1 when every pair matches and 0 otherwise. A graph with fewer than two nodes yields 0, | |||
| // as does a failed allocation of the scratch table. The comparison stops at the first mismatch. | |||
| int akg_isl_influence_check_coincident(struct isl_sched_graph *graph, isl_vec *sol) { | |||
|   int coincident = graph->n > 1 ? 1 : 0; | |||
|   log::Info(log::Verbosity::high, "Entering isl_influene_check_coincidence for dimension " + | |||
|                                     std::to_string(graph->n_row) + | |||
|                                     ", number of nodes (statements): " + std::to_string(graph->n)); | |||
|   // Scratch table with three ints per node: [cst_offset, nparam, nvar]. | |||
|   int *vec = static_cast<int *>(calloc(graph->n * 3, sizeof(int))); | |||
|   if (vec == NULL) { | |||
|     // Either the allocation failed or the graph has no node (calloc may return NULL for a | |||
|     // zero-sized request). In both cases nothing can be compared. | |||
|     log::Info(log::Verbosity::high, "isl_influence_check_coincident: no node table available, result: 0"); | |||
|     return 0; | |||
|   } | |||
|   for (int i = 0; i < graph->n; ++i) { | |||
|     isl_sched_node *node = isl_sched_graph_get_node(graph, i); | |||
|     int cst_offset = isl_sched_node_cst_coef_offset(node); | |||
|     int nparam = isl_sched_node_get_nparam(node); | |||
|     int nvar = isl_sched_node_get_nvar(node); | |||
|     log::Info(log::Verbosity::high, "graph node:\t" + std::to_string(i)); | |||
|     log::Info(log::Verbosity::high, "cst_offset:\t" + std::to_string(cst_offset)); | |||
|     log::Info(log::Verbosity::high, "nparam:\t\t" + std::to_string(nparam)); | |||
|     log::Info(log::Verbosity::high, "var:\t\t" + std::to_string(nvar)); | |||
|     vec[i * 3] = cst_offset; | |||
|     vec[i * 3 + 1] = nparam; | |||
|     vec[i * 3 + 2] = nvar; | |||
|   } | |||
|   // `coincident` guards every loop: once a mismatch is found the answer cannot change, | |||
|   // so all remaining comparisons are skipped. | |||
|   for (int i = 0; coincident && i < graph->n; ++i) { | |||
|     for (int j = i + 1; coincident && j < graph->n; ++j) { | |||
|       log::Info(log::Verbosity::high, | |||
|                 "calculating coincidence for statement " + std::to_string(i) + " and " + std::to_string(j)); | |||
|       const int cst_offset_S0 = vec[i * 3]; | |||
|       const int cst_offset_S1 = vec[j * 3]; | |||
|       // Positions in `sol` are shifted by one with respect to the coefficient offsets. | |||
|       int is_equal = isl_influence_int_eq(sol, 1 + cst_offset_S0, 1 + cst_offset_S1); | |||
|       log::Info(log::Verbosity::high, "S_" + std::to_string(i) + "_i" + std::to_string(cst_offset_S0) + ", S_" + | |||
|                                         std::to_string(j) + "_i" + std::to_string(cst_offset_S1) + | |||
|                                         ", cst coefficient equal: " + std::to_string(is_equal)); | |||
|       if (is_equal == 0) { | |||
|         coincident = 0; | |||
|         break; | |||
|       } | |||
|       log::Info(log::Verbosity::high, "param coefficient(s) equal:"); | |||
|       const int nparam_S0 = vec[i * 3 + 1]; | |||
|       const int nparam_S1 = vec[j * 3 + 1]; | |||
|       if (nparam_S0 != 0 && nparam_S1 != 0 && nparam_S0 == nparam_S1) { | |||
|         // Parameter coefficients sit directly before the constant coefficient. | |||
|         const int nparam_pos_0 = cst_offset_S0 - nparam_S0; | |||
|         const int nparam_pos_1 = cst_offset_S1 - nparam_S1; | |||
|         for (int k = nparam_pos_0, l = nparam_pos_1; coincident && k < nparam_pos_0 + nparam_S0; k++, l++) { | |||
|           is_equal = isl_influence_int_eq(sol, 1 + k, 1 + l); | |||
|           log::Info(log::Verbosity::high, "S_" + std::to_string(i) + "_i" + std::to_string(k) + ", S_" + | |||
|                                             std::to_string(j) + "_i" + std::to_string(l) + | |||
|                                             " equal: " + std::to_string(is_equal)); | |||
|           if (is_equal == 0) { | |||
|             coincident = 0; | |||
|           } | |||
|         } | |||
|       } else { | |||
|         log::Info(log::Verbosity::high, " no parameteres found or number of parameters distinct for each statement."); | |||
|       } | |||
|       if (!coincident) { | |||
|         break; | |||
|       } | |||
|       log::Info(log::Verbosity::high, "variable coefficients equal:"); | |||
|       const int nvar_S0 = vec[i * 3 + 2]; | |||
|       const int nvar_S1 = vec[j * 3 + 2]; | |||
|       if (nvar_S0 != nvar_S1) { | |||
|         coincident = 0; | |||
|         break; | |||
|       } | |||
|       // Each variable owns two consecutive slots located before the parameter coefficients. | |||
|       const int nvar_pos_0 = cst_offset_S0 - nparam_S0 - 2 * nvar_S0; | |||
|       const int nvar_pos_1 = cst_offset_S1 - nparam_S1 - 2 * nvar_S1; | |||
|       for (int k = nvar_pos_0, l = nvar_pos_1; coincident && k < nvar_pos_0 + 2 * nvar_S0; k++, l++) { | |||
|         is_equal = isl_influence_int_eq(sol, 1 + k, 1 + l); | |||
|         log::Info(log::Verbosity::high, "S_" + std::to_string(i) + "_i" + std::to_string(k) + ", S_" + | |||
|                                           std::to_string(j) + "_i" + std::to_string(l) + | |||
|                                           " equal: " + std::to_string(is_equal)); | |||
|         if (is_equal == 0) { | |||
|           coincident = 0; | |||
|         } | |||
|       } | |||
|     } | |||
|   } | |||
|   free(vec); | |||
|   log::Info(log::Verbosity::high, "Leaving isl_check_coindicent result: " + std::to_string(coincident)); | |||
|   return coincident; | |||
| } | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef POLY_ISL_INFLUENCE_H_ | |||
| #define POLY_ISL_INFLUENCE_H_ | |||
| // STL | |||
| #include <tuple> | |||
| #include <vector> | |||
| // ISL | |||
| #include <isl/cpp.h> | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| // Typedefs for soft constraints | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| #define INF_BLOCK_SIZE 8 /* NOTE(review): presumably the allocation granularity of the *_list arrays declared below; its use is not visible in this header -- confirm. */ | |||
| enum isl_influence_coeff_type { isl_cst, isl_param, isl_var }; /* Kind of schedule coefficient a constraint targets; judging by the names: constant term, parameter coefficient, variable coefficient. */ | |||
| struct isl_influence_equal { /* One soft constraint involving two statements. NOTE(review): fields without a cited use are undocumented on purpose; confirm their meaning against akg_isl_influence_set_equal(). */ | |||
| char *statement1; /* NOTE(review): presumably the name of the first statement -- confirm, including who owns the string */ | |||
| char *statement2; /* NOTE(review): presumably the name of the second statement -- confirm, including who owns the string */ | |||
| int sched_dim; | |||
| int sched_dim1; /* read by akg_isl_influence_maxvar(), which folds sched_dim1 + 1 into its maximum */ | |||
| int sched_dim2; /* read by akg_isl_influence_maxvar(), which folds sched_dim2 + 1 into its maximum */ | |||
| isl_influence_coeff_type type; /* kind of coefficient the constraint applies to */ | |||
| int coef_dim1; | |||
| int coef_dim2; | |||
| int scanned; /* NOTE(review): presumably a "processed" flag like isl_influence::scanned -- confirm */ | |||
| }; | |||
| typedef struct isl_influence_equal isl_influence_equal; | |||
| struct isl_influence_equal_list { /* Array of isl_influence_equal entries. */ | |||
| int size; /* number of entries in `data`; akg_isl_influence_maxvar() iterates over indices 0 .. size - 1 */ | |||
| int mem; /* NOTE(review): presumably the allocated capacity of `data` -- confirm */ | |||
| struct isl_influence_equal *data; /* entry storage, indexed from 0 */ | |||
| }; | |||
| typedef struct isl_influence_equal_list isl_influence_equal_list; | |||
| struct isl_influence { /* One soft constraint on a single schedule coefficient of one statement. */ | |||
| char *statement_name; /* NOTE(review): presumably the name of the targeted statement -- confirm, including who owns the string */ | |||
| isl_influence_coeff_type type; /* kind of coefficient the constraint applies to */ | |||
| int sched_dim; /* schedule dimension; akg_isl_influence_maxvar() folds sched_dim + 1 into its maximum */ | |||
| int coef_dim; /* for isl_var constraints: index of the targeted variable, checked against the node's variable count before use */ | |||
| int val; /* requested value; for isl_var its sign selects whether the "+" or the "-" coefficient slot receives the bounds */ | |||
| int scanned; /* set to 1 by akg_isl_influence_set_coef() once the constraint has been applied */ | |||
| }; | |||
| typedef struct isl_influence isl_influence; | |||
| struct isl_influence_list { /* Array of isl_influence entries. */ | |||
| int size; /* number of entries in `data`; akg_isl_influence_maxvar() iterates over indices 0 .. size - 1 */ | |||
| int mem; /* NOTE(review): presumably the allocated capacity of `data` -- confirm */ | |||
| struct isl_influence *data; /* entry storage, indexed from 0 */ | |||
| }; | |||
| typedef struct isl_influence_list isl_influence_list; | |||
| struct isl_influence_sol { /* NOTE(review): presumably one recorded solution vector (see akg_isl_influence_sol_add_elem) -- confirm. */ | |||
| int size; /* NOTE(review): presumably the number of values stored in `data` -- confirm */ | |||
| int mem; /* NOTE(review): presumably the allocated capacity of `data` -- confirm */ | |||
| int *data; | |||
| }; | |||
| typedef struct isl_influence_sol isl_influence_sol; | |||
| struct isl_influence_sol_list { /* Array of isl_influence_sol entries. */ | |||
| int size; /* NOTE(review): presumably the number of entries in `data` -- confirm */ | |||
| int mem; /* NOTE(review): presumably the allocated capacity of `data` -- confirm */ | |||
| isl_influence_sol *data; | |||
| }; | |||
| typedef struct isl_influence_sol_list isl_influence_sol_list; | |||
| void akg_isl_influence_enable(isl_ctx *ctx); /* NOTE(review): enable/disable presumably toggle the influence mechanism for `ctx`; definitions not visible here -- confirm. */ | |||
| void akg_isl_influence_disable(isl_ctx *ctx); | |||
| int akg_isl_influence_maxvar(struct isl_sched_graph *graph); /* Returns the maximum of (sched_dim + 1) over all entries of both influence lists of `graph`, 0 when there are none. */ | |||
| isl_basic_set *akg_isl_influence_set_coef(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset); /* Applies the coefficient constraints to the LP `bset` and returns the updated set. */ | |||
| isl_basic_set *akg_isl_influence_set_equal(isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset); /* NOTE(review): presumably the counterpart of set_coef for isl_influence_equal entries -- confirm. */ | |||
| int akg_isl_influence_check_coincident(struct isl_sched_graph *graph, isl_vec *sol); /* Returns 1 when all statement pairs have equal coefficients in `sol`, 0 otherwise (also 0 for fewer than two nodes). */ | |||
| __isl_give isl_schedule *akg_isl_schedule_constraints_compute_schedule_influence( /* __isl_take: `sc` is consumed; __isl_give: the caller owns the returned schedule. */ | |||
| __isl_take isl_schedule_constraints *sc, isl_influence_list *inf_coef, isl_influence_equal_list *inf_equal); | |||
| void *isl_influence_list_free(isl_influence_list *inf_list); /* NOTE(review): return value semantics not visible here (isl *_free functions conventionally return NULL) -- confirm. */ | |||
| void *isl_influence_equal_list_free(isl_influence_equal_list *inf_equal_list); | |||
| struct isl_sched_graph *akg_isl_influence_sol_list_free(struct isl_sched_graph *graph); | |||
| struct isl_sched_graph *akg_isl_influence_sol_add_elem(isl_vec *sol, struct isl_sched_graph *graph); | |||
| int akg_isl_influence_sol_get_elem(int sched, int pos, struct isl_sched_graph *graph); | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| #endif // POLY_ISL_INFLUENCE_H_ | |||
| @@ -0,0 +1,184 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef POLY_ISL_UTIL_H_ | |||
| #define POLY_ISL_UTIL_H_ | |||
| #include <vector> | |||
| #include "isl/cpp.h" | |||
| #include "poly/pass_info.h" | |||
| #include "poly/scop_info.h" | |||
| // Hardcore isl functions | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Misc. | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| long isl_set_plain_get_num_si(const isl::set &s, int dim); /* C++-wrapper overload of the raw-pointer version below. NOTE(review): presumably returns the fixed integer value of dimension `dim` -- confirm against the definition. */ | |||
| long isl_set_plain_get_num_si(__isl_keep isl_set *const set, int dim); /* __isl_keep: `set` is only borrowed, the caller keeps ownership. */ | |||
| std::vector<int> extract_upper_bounds(const isl::set &s, const std::vector<int> &dimensions); /* NOTE(review): presumably one upper bound per entry of `dimensions` -- confirm. */ | |||
| std::vector<int> extract_upper_bounds(__isl_keep isl_set *const set, const std::vector<int> &dimensions); | |||
| // Combining isl_spaces | |||
| isl::space isl_space_set_cat(const isl::space &left, const isl::space &right); /* C++-wrapper overload; inputs are taken by const reference. */ | |||
| __isl_give isl_space *isl_space_set_cat(__isl_take isl_space *left, __isl_take isl_space *right); /* __isl_take: both inputs are consumed; __isl_give: the caller owns the result. */ | |||
| // Utilities for isl_multi_union_pw_aff* | |||
| isl::multi_union_pw_aff isl_multi_union_pw_aff_cat(const isl::multi_union_pw_aff &left, /* NOTE(review): presumably concatenates the dimensions of `left` and `right` -- confirm. */ | |||
| const isl::multi_union_pw_aff &right); | |||
| __isl_give isl_multi_union_pw_aff *isl_multi_union_pw_aff_cat(__isl_take isl_multi_union_pw_aff *left, /* raw-pointer overload: consumes both inputs, caller owns the result */ | |||
| __isl_take isl_multi_union_pw_aff *right); | |||
| isl::multi_union_pw_aff isl_multi_union_pw_aff_insert(const isl::multi_union_pw_aff &aff, unsigned pos, /* NOTE(review): presumably inserts `el` at position `pos` -- confirm. */ | |||
| const isl::union_pw_aff &el); | |||
| __isl_give isl_multi_union_pw_aff *isl_multi_union_pw_aff_insert(__isl_take isl_multi_union_pw_aff *aff, unsigned pos, /* raw-pointer overload: consumes `aff` and `el`, caller owns the result */ | |||
| __isl_take isl_union_pw_aff *el); | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // isl_schedule_node utilities | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| bool isl_schedule_node_is_band(const isl::schedule_node &node); /* The four C++-wrapper predicates return bool; the raw-pointer versions below return isl_bool, which also has an error value. */ | |||
| bool isl_schedule_node_is_sequence(const isl::schedule_node &node); | |||
| bool isl_schedule_node_has_single_child(const isl::schedule_node &node); | |||
| bool isl_schedule_node_band_can_unsplit(const isl::schedule_node &node); | |||
| isl_bool isl_schedule_node_is_band(__isl_keep isl_schedule_node *const node); /* __isl_keep: the node is only borrowed */ | |||
| isl_bool isl_schedule_node_is_sequence(__isl_keep isl_schedule_node *const node); | |||
| isl_bool isl_schedule_node_has_single_child(__isl_keep isl_schedule_node *const node); | |||
| isl_bool isl_schedule_node_band_can_unsplit(__isl_keep isl_schedule_node *const band); | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // isl_schedule_node_band utilities | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| __isl_give isl_bool *isl_schedule_node_band_get_coincidence(__isl_keep isl_schedule_node *const band); /* NOTE(review): returns a pointer marked __isl_give, presumably an array with one entry per band member that the caller must release -- confirm how. */ | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_set_coincidence(__isl_take isl_schedule_node *band, /* consumes `band` and `coincidence` (__isl_take) */ | |||
| __isl_take isl_bool *const coincidence); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_copy_properties(__isl_take isl_schedule_node *band, /* `original` is only borrowed (__isl_keep) */ | |||
| __isl_keep isl_schedule_node *const original); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_replace_partial_schedule( /* consumes `band` and `schedule`; the caller owns the returned node */ | |||
| __isl_take isl_schedule_node *band, __isl_take isl_multi_union_pw_aff *schedule, isl_bool keep_properties); | |||
| __isl_give isl_set *isl_schedule_node_band_lexmax(__isl_keep isl_schedule_node *const band); /* `band` is borrowed; the caller owns the returned set */ | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // isl_schedule_node_band transformations | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // These functions only preserve permutable/coincidence. | |||
| // AST options are not preserved. | |||
| // interchange dimensions 'first' and 'second' | |||
| isl::schedule_node isl_schedule_node_band_interchange(const isl::schedule_node &band, int first, int second); /* C++-wrapper overload; returns the transformed node */ | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_interchange(__isl_take isl_schedule_node *band, int first, /* raw-pointer overload: consumes `band`, caller owns the result */ | |||
| int second); | |||
| // strip mine only dimension 'dim' | |||
| isl::schedule_node isl_schedule_node_band_stripmine(const isl::schedule_node &band, int dim, int value); /* `value` given as a plain int */ | |||
| isl::schedule_node isl_schedule_node_band_stripmine(const isl::schedule_node &band, int dim, const isl::val &value); /* `value` given as an isl value */ | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_stripmine(__isl_take isl_schedule_node *band, int dim, int value); /* raw-pointer overload: consumes `band` */ | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_stripmine(__isl_take isl_schedule_node *band, int dim, /* raw-pointer overload: consumes `band` and `value` */ | |||
| __isl_take isl_val *value); | |||
| // collapse dimensions 'dim' and 'dim + 1' | |||
| isl::schedule_node isl_schedule_node_band_collapse(const isl::schedule_node &band, int dim); /* C++-wrapper overload */ | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_collapse(__isl_take isl_schedule_node *band, int dim); /* raw-pointer overload: consumes `band`, caller owns the result */ | |||
| // modulo/scale/scale down a given dimension of a target statement | |||
| isl::schedule_node isl_schedule_node_band_fine_mod(const isl::schedule_node &band, const std::string &name, /* C++-wrapper overloads, int `value`; `name` selects the target statement */ | |||
| int dimension, int value); | |||
| isl::schedule_node isl_schedule_node_band_fine_scale(const isl::schedule_node &band, const std::string &name, | |||
| int dimension, int value); | |||
| isl::schedule_node isl_schedule_node_band_fine_scale_down(const isl::schedule_node &band, const std::string &name, | |||
| int dimension, int value); | |||
| isl::schedule_node isl_schedule_node_band_fine_mod(const isl::schedule_node &band, const std::string &name, /* C++-wrapper overloads, isl::val `value` */ | |||
| int dimension, const isl::val &value); | |||
| isl::schedule_node isl_schedule_node_band_fine_scale(const isl::schedule_node &band, const std::string &name, | |||
| int dimension, const isl::val &value); | |||
| isl::schedule_node isl_schedule_node_band_fine_scale_down(const isl::schedule_node &band, const std::string &name, | |||
| int dimension, const isl::val &value); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fine_mod(__isl_take isl_schedule_node *band, const char *name, /* raw-pointer overloads, int `value`: consume `band` */ | |||
| int dimension, int value); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fine_scale(__isl_take isl_schedule_node *band, const char *name, | |||
| int dimension, int value); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fine_scale_down(__isl_take isl_schedule_node *band, | |||
| const char *name, int dimension, int value); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fine_mod(__isl_take isl_schedule_node *band, const char *name, /* raw-pointer overloads, isl_val `value`: consume `band` and `value` */ | |||
| int dimension, __isl_take isl_val *value); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fine_scale(__isl_take isl_schedule_node *band, const char *name, | |||
| int dimension, __isl_take isl_val *value); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fine_scale_down(__isl_take isl_schedule_node *band, | |||
| const char *name, int dimension, | |||
| __isl_take isl_val *value); | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // schedule tree transformations (on a schedule_node) | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Merge two nested isl_schedule_node_band. | |||
| // Assuming the input schedule_node_band has only one child and this child is an isl_schedule_node_band. | |||
| isl::schedule_node isl_schedule_node_band_unsplit(const isl::schedule_node &band); /* C++-wrapper overload */ | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_unsplit(__isl_take isl_schedule_node *band); /* raw-pointer overload: consumes `band`, caller owns the result */ | |||
| // Call isl_schedule_node_band_unsplit() until it is not possible | |||
| isl::schedule_node isl_schedule_node_band_fully_unsplit(const isl::schedule_node &band); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fully_unsplit(__isl_take isl_schedule_node *band); /* consumes `band` */ | |||
| // Call isl_schedule_node_band_split() until it is not possible | |||
| isl::schedule_node isl_schedule_node_band_fully_split(const isl::schedule_node &band); | |||
| __isl_give isl_schedule_node *isl_schedule_node_band_fully_split(__isl_take isl_schedule_node *band); /* consumes `band` */ | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // isl_schedule transformations | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| isl::schedule isl_schedule_collapse(const isl::schedule &schedule, int first, int last); /* NOTE(review): presumably collapses schedule dimensions `first` through `last`; definition not visible here -- confirm. */ | |||
| __isl_give isl_schedule *isl_schedule_collapse(__isl_take isl_schedule *schedule, int first, int last); /* raw-pointer overload: consumes `schedule`, caller owns the result */ | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // "Readable" strings conversions | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Simple C code | |||
| std::string to_c_code_string(const isl::schedule &sch); /* Each converter comes as a C++-wrapper overload and a raw-pointer overload that only borrows its argument (__isl_keep). */ | |||
| std::string to_c_code_string(__isl_keep isl_schedule *const schedule); | |||
| std::string to_c_code_string(const isl::schedule_constraints &c); | |||
| std::string to_c_code_string(__isl_keep isl_schedule_constraints *const constraints); | |||
| // isl_*_to_str functions return isl formatted strings! | |||
| std::string to_block_string(const isl::schedule &s); /* NOTE(review): presumably a multi-line ("block") rendering, as opposed to the isl_*_to_str output mentioned above -- confirm. */ | |||
| std::string to_block_string(__isl_keep isl_schedule *const schedule); | |||
| std::string to_block_string(const isl::union_map &map); | |||
| std::string to_block_string(__isl_keep isl_union_map *const map); | |||
| std::string to_block_string(const isl::schedule_constraints &constraints); | |||
| std::string to_block_string(__isl_keep isl_schedule_constraints *const constraints); | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| #endif // POLY_ISL_UTIL_H_ | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "poly/log_util.h" | |||
| #include <dmlc/logging.h> | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| namespace log { | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Verbosity levels | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // File-local verbosity threshold used by the level-filtered logging wrappers; starts silent. | |||
| static Verbosity akg_poly_verbosity_level{Verbosity::silent}; | |||
| // Returns the current verbosity threshold. | |||
| Verbosity GetVerbosityLevel() { | |||
|   return akg_poly_verbosity_level; | |||
| } | |||
| // Replaces the current verbosity threshold with `level`. | |||
| void SetVerbosityLevel(Verbosity level) { | |||
|   akg_poly_verbosity_level = level; | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Logging functions | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Emits `message` as a warning, wrapped in the yellow/reset colour strings. | |||
| void Warn(const std::string &message) { | |||
|   LOG(WARNING) << text_yellow << message << text_reset; | |||
| } | |||
| // Emits `message` wrapped in the red/reset colour strings. It is logged at WARNING | |||
| // severity on purpose: errors reported here must not be fatal. | |||
| void Error(const std::string &message) { | |||
|   LOG(WARNING) << text_red << message << text_reset; | |||
| } | |||
| // Emits `message` at INFO severity, followed by the reset colour string. | |||
| void Info(const std::string &message) { | |||
|   LOG(INFO) << message << text_reset; | |||
| } | |||
| // Stream overloads: extract the accumulated text and forward to the string overloads. | |||
| void Warn(const std::stringstream &stream) { | |||
|   Warn(stream.str()); | |||
| } | |||
| void Error(const std::stringstream &stream) { | |||
|   Error(stream.str()); | |||
| } | |||
| void Info(const std::stringstream &stream) { | |||
|   Info(stream.str()); | |||
| } | |||
| // clang-format off | |||
| // Generates the level-filtered overloads of one logging function `func`: | |||
| //  - func(Verbosity, message/stream) forwards to func(message/stream) only when the | |||
| //    current verbosity threshold is at least `level`; | |||
| //  - func(int, message/stream) casts `level` to Verbosity and delegates to the above. | |||
| #define _define_logging_wrappers(func) \ | |||
|   void func(const Verbosity level, const std::string &message) { \ | |||
|     if (akg_poly_verbosity_level >= level) { \ | |||
|       func(message); \ | |||
|     } \ | |||
|   } \ | |||
|   void func(const Verbosity level, const std::stringstream &stream) { \ | |||
|     if (akg_poly_verbosity_level >= level) { \ | |||
|       func(stream); \ | |||
|     } \ | |||
|   } \ | |||
|   void func(const int level, const std::string &message) { \ | |||
|     func(static_cast<Verbosity>(level), message); \ | |||
|   } \ | |||
|   void func(const int level, const std::stringstream &stream) { \ | |||
|     func(static_cast<Verbosity>(level), stream); \ | |||
|   } | |||
| _define_logging_wrappers(Info) | |||
| _define_logging_wrappers(Warn) | |||
| _define_logging_wrappers(Error) | |||
| // Undefine the helper that was actually defined above (the previous #undef named a macro | |||
| // that does not exist, leaving this one defined for the rest of the translation unit). | |||
| #undef _define_logging_wrappers | |||
| // clang-format on | |||
| } // namespace log | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -0,0 +1,135 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef POLY_LOG_UTIL_H_ | |||
| #define POLY_LOG_UTIL_H_ | |||
| #include <string> | |||
| #include <sstream> | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| namespace log { | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Verbosity levels | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| enum class Verbosity { /* Ordered verbosity levels: a message tagged with level L is emitted only when the current threshold is >= L. */ | |||
| silent = 0, /* lowest value; also what the int-level logging overloads map 0 to */ | |||
| veryLow, | |||
| low, | |||
| medium, | |||
| high, | |||
| veryHigh, | |||
| }; | |||
| Verbosity GetVerbosityLevel(void); /* Returns the current verbosity threshold. */ | |||
| void SetVerbosityLevel(Verbosity level); /* Sets the verbosity threshold used by the level-filtered logging overloads. */ | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Log colors | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| #ifdef AKG_POLY_LOG_WITH_COLORS /* Colours enabled: every text_* macro expands to an ANSI escape sequence. They are string-literal macros so they can be concatenated with adjacent literals. */ | |||
| #define text_reset "\033[0m" | |||
| #define text_bold "\033[1m" | |||
| #define text_dim "\033[2m" | |||
| #define text_italic "\033[3m" | |||
| #define text_underline "\033[4m" | |||
| #define text_blink "\033[5m" | |||
| #define text_rapid_blink "\033[6m" | |||
| #define text_reverse "\033[7m" | |||
| #define text_conceal "\033[8m" | |||
| #define text_black "\033[30m" | |||
| #define text_red "\033[31m" | |||
| #define text_green "\033[32m" | |||
| #define text_yellow "\033[33m" | |||
| #define text_blue "\033[34m" | |||
| #define text_magenta "\033[35m" | |||
| #define text_cyan "\033[36m" | |||
| #define text_white "\033[37m" | |||
| #define text_bright_black "\033[90m" | |||
| #define text_bright_red "\033[91m" | |||
| #define text_bright_green "\033[92m" | |||
| #define text_bright_yellow "\033[93m" | |||
| #define text_bright_blue "\033[94m" | |||
| #define text_bright_magenta "\033[95m" | |||
| #define text_bright_cyan "\033[96m" | |||
| #define text_bright_white "\033[97m" | |||
| #else /* Colours disabled: the same macros expand to empty strings, so call sites need no conditional code. */ | |||
| #define text_reset "" | |||
| #define text_bold "" | |||
| #define text_dim "" | |||
| #define text_italic "" | |||
| #define text_underline "" | |||
| #define text_blink "" | |||
| #define text_rapid_blink "" | |||
| #define text_reverse "" | |||
| #define text_conceal "" | |||
| #define text_black "" | |||
| #define text_red "" | |||
| #define text_green "" | |||
| #define text_yellow "" | |||
| #define text_blue "" | |||
| #define text_magenta "" | |||
| #define text_cyan "" | |||
| #define text_white "" | |||
| #define text_bright_black "" | |||
| #define text_bright_red "" | |||
| #define text_bright_green "" | |||
| #define text_bright_yellow "" | |||
| #define text_bright_blue "" | |||
| #define text_bright_magenta "" | |||
| #define text_bright_cyan "" | |||
| #define text_bright_white "" | |||
| #endif | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Local logging functions | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| void Info(const std::string &message); /* Unconditional overloads (no level argument) always log. */ | |||
| void Info(const std::stringstream &stream); /* stream overloads log the text accumulated in `stream` */ | |||
| void Info(const int level, const std::string &message); /* int `level` is cast to Verbosity */ | |||
| void Info(const int level, const std::stringstream &stream); | |||
| void Info(const Verbosity level, const std::string &message); /* logs only when the current verbosity threshold is >= `level` */ | |||
| void Info(const Verbosity level, const std::stringstream &stream); | |||
| void Warn(const std::string &message); /* Warn mirrors the Info overload set. */ | |||
| void Warn(const std::stringstream &stream); | |||
| void Warn(const int level, const std::string &message); | |||
| void Warn(const int level, const std::stringstream &stream); | |||
| void Warn(const Verbosity level, const std::string &message); | |||
| void Warn(const Verbosity level, const std::stringstream &stream); | |||
| void Error(const std::string &message); /* Error mirrors the Info overload set; it is emitted at WARNING severity so that it is never fatal. */ | |||
| void Error(const std::stringstream &stream); | |||
| void Error(const int level, const std::string &message); | |||
| void Error(const int level, const std::stringstream &stream); | |||
| void Error(const Verbosity level, const std::string &message); | |||
| void Error(const Verbosity level, const std::stringstream &stream); | |||
| } // namespace log | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| #endif // POLY_LOG_UTIL_H_ | |||
| @@ -22,6 +22,7 @@ | |||
| #include "poly/schedule_pass/init_schedule.h" | |||
| #include "poly/schedule_pass/compute_schedule.h" | |||
| #include "poly/schedule_pass/constrain_schedule.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| @@ -37,6 +38,11 @@ class PassMgrStrategy { | |||
| } | |||
| // Registers the InitSchedule pass, built from the shared pass and scop information. | |||
| void RegisterNormalizationPasses() { | |||
|   RegisterPass(std::make_shared<InitSchedule>(pass_info_, scop_info_)); | |||
| } | |||
| // Registers the ComputeSchedule pass, built from the shared pass and scop information. | |||
| void RegisterSchedulingPasses() { | |||
|   RegisterPass(std::make_shared<ComputeSchedule>(pass_info_, scop_info_)); | |||
| } | |||
| // Registers the ConstrainSchedule pass, but only when mind tricks are enabled in the user config. | |||
| void RegisterConstrainedScheduling() { | |||
|   const bool mind_trick_enabled = scop_info_.user_config_.GetEnableMindTrick(); | |||
|   if (!mind_trick_enabled) { | |||
|     return; | |||
|   } | |||
|   RegisterPass(std::make_shared<ConstrainSchedule>(pass_info_, scop_info_)); | |||
| } | |||
| virtual void RegisterTilingPasses() = 0; // each backend provides its own implementation | |||
| virtual void RegisterMemPromPasses() = 0; // each backend provides its own implementation | |||
| virtual void RegisterPasses() = 0; | |||
| @@ -16,6 +16,8 @@ | |||
| #ifndef POLY_PASS_H_ | |||
| #define POLY_PASS_H_ | |||
| #include <set> | |||
| #include <ostream> | |||
| #include "poly/isl.h" | |||
| #include "poly/scop_info.h" | |||
| #include "poly/pass_info.h" | |||
| @@ -34,6 +36,8 @@ class SchedulePass { | |||
| // Returns a copy of this pass's name. | |||
| std::string GetPassName() { | |||
|   return pass_name_; | |||
| } | |||
| std::string pass_name_; /* name reported by GetPassName() */ | |||
| bool restart_{false}; // triggers restart during runtime | |||
| std::set<std::string> disabled_passes_; /* NOTE(review): presumably names of passes to be skipped; how this set is consumed is not visible here -- confirm. */ | |||
| }; | |||
| bool LoadScheduleTreeFromFile(const std::string &filename, isl::schedule &schedule); /* NOTE(review): presumably returns true on success and stores the loaded tree in `schedule`; definition not visible here -- confirm. */ | |||
| @@ -0,0 +1,454 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "poly/schedule_pass/constrain_schedule.h" | |||
| #include <unistd.h> | |||
| #include <cstdio> | |||
| #include <cstdlib> | |||
| #include <isl/ctx.h> | |||
| #include <isl/schedule.h> | |||
| // opendir(), readdir, closedir() | |||
| #include <sys/types.h> | |||
| #include <dirent.h> | |||
| // TVM | |||
| #include <tvm/node/node.h> | |||
| #include <tvm/node/container.h> | |||
| // Local headers | |||
| #include "poly/schedule_pass/scheduling_mind_trick.h" | |||
| #include "poly/isl_util.h" | |||
| #include "poly/log_util.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Constructors | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Builds the pass and loads its mind tricks. The global poly log verbosity is switched to | |||
| // this pass's own level while loading and restored afterwards. | |||
| ConstrainSchedule::ConstrainSchedule(PassInfo &pass_info, ScopInfo &scop_info) | |||
|     : pass_info_(pass_info), scop_info_(scop_info) { | |||
|   pass_name_ = __FUNCTION__; | |||
|   // verbosity_ must be initialised before it is read below. | |||
|   InitVerbosityLevel(); | |||
|   const auto pass_level = static_cast<log::Verbosity>(verbosity_); | |||
|   const log::Verbosity previous_level = log::GetVerbosityLevel(); | |||
|   log::SetVerbosityLevel(pass_level); | |||
|   LoadMindTricks(); | |||
|   log::SetVerbosityLevel(previous_level); | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Setters | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Appends `mind_trick` to the list of mind tricks held by this pass (the pointer is shared, not cloned). | |||
| void ConstrainSchedule::AddMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
|   mind_tricks_.emplace_back(mind_trick); | |||
| } | |||
| // Collects the directories that may contain mind trick files. Only the list is built here; | |||
| // whether a directory actually exists is checked by the caller. | |||
| std::vector<std::string> ConstrainSchedule::MindTricksDirectories(void) { | |||
|   std::vector<std::string> search_paths; | |||
|   // The only source so far is the directory given through the environment, if set. | |||
|   if (const char *const from_env = std::getenv(env_string_mind_tricks_dir_)) { | |||
|     search_paths.emplace_back(from_env); | |||
|   } | |||
|   // Further directories can be appended here when needed. | |||
|   return search_paths; | |||
| } | |||
| void ConstrainSchedule::LoadMindTrickFromFile(const std::string &filename) { | |||
| auto mind_trick = std::make_shared<SchedulingMindTrick>(pass_info_, scop_info_, verbosity_); | |||
| mind_trick->Load(filename); | |||
| // Alternative: | |||
| // | |||
| // std::ifstream stream(filename); | |||
| // stream >> *mind_trick; | |||
| if (*mind_trick) { | |||
| AddMindTrick(mind_trick); | |||
| } else { | |||
| Warn("something was wrong with mind_trick " + filename); | |||
| } | |||
| } | |||
| void ConstrainSchedule::LoadMindTricks(void) { | |||
| Info(log::Verbosity::low, | |||
| text_reverse text_bright_blue " ConstrainSchedule " text_reset text_bright_blue " LoadMindTricks()" text_reset); | |||
| // Try to load the user's mind_trick from CLI | |||
| const std::string user_mind_trick = scop_info_.user_config_.GetMindTrick(); | |||
| if (user_mind_trick != "") { | |||
| auto mind_trick = std::make_shared<SchedulingMindTrick>(pass_info_, scop_info_, verbosity_); | |||
| mind_trick->Parse(user_mind_trick); | |||
| if (*mind_trick) { | |||
| AddMindTrick(mind_trick); | |||
| Info(log::Verbosity::medium, text_bright_magenta "User's mind_trick:\n" + mind_trick->str()); | |||
| } else { | |||
| Warn(log::Verbosity::veryLow, "something was wrong with user's mind_trick"); | |||
| } | |||
| } | |||
| // Look for mind_tricks in several directories. | |||
| std::vector<std::string> directories = MindTricksDirectories(); | |||
| for (const std::string &directory_str : directories) { | |||
| log::Info(log::Verbosity::medium, "looking for mind tricks in " + directory_str); | |||
| DIR *const directory = opendir(directory_str.c_str()); | |||
| if (directory) { | |||
| // We first store the strings in a vecttor because there is no guarantee | |||
| // on the order fo the files | |||
| std::vector<std::string> files; | |||
| for (struct dirent *entry = readdir(directory); entry; entry = readdir(directory)) { | |||
| const std::string &filename = std::string(entry->d_name); | |||
| if (filename.length() > 5 && filename.compare(filename.length() - 5, 5, ".json") == 0) | |||
| files.push_back(filename); | |||
| } | |||
| if (!files.empty()) { | |||
| std::sort(files.begin(), files.end()); | |||
| for (const std::string &filename : files) { | |||
| const std::string &path = directory_str + "/" + filename; | |||
| LoadMindTrickFromFile(path); | |||
| } | |||
| } | |||
| closedir(directory); | |||
| } else { | |||
| log::Error(log::Verbosity::medium, "could not access directory " + directory_str); | |||
| } | |||
| } | |||
| std::stringstream summary; | |||
| summary << text_cyan << pass_name_ << " has " << mind_tricks_.size(); | |||
| summary << (mind_tricks_.size() <= 1 ? " trick" : " tricks"); | |||
| summary << "up its sleeve"; | |||
| Info(log::Verbosity::low, summary); | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Other | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| static inline void RunInfo(const std::string &stage, const std::string &kernel_name, const isl::schedule &schedule) { | |||
| log::Info(log::Verbosity::low, | |||
| text_reverse text_bright_blue " ConstrainSchedule " text_reset text_bright_blue " Run() " + stage); | |||
| log::Info(log::Verbosity::low, text_bright_blue "name: " + kernel_name); | |||
| log::Info(log::Verbosity::low, text_bright_blue "schedule:\n" + to_block_string(schedule)); | |||
| log::Info(log::Verbosity::medium, text_bright_blue "schedule (loop nest):\n" + to_c_code_string(schedule)); | |||
| } | |||
| bool ConstrainSchedule::KernelIsEligible(const isl::schedule &sch) const { | |||
| const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); | |||
| // For now, the only criterion for eligibility is the existence of a blacklist exists and | |||
| // its containing the current kernel name | |||
| const char *const blacklist_path = std::getenv(env_string_mind_tricks_operator_blacklist_); | |||
| if (!blacklist_path) { | |||
| Info(log::Verbosity::high, env_string_mind_tricks_operator_blacklist_ + std::string(" not set")); | |||
| return true; | |||
| } | |||
| std::fstream file(blacklist_path, std::ios::in); | |||
| if (!file.is_open()) { | |||
| Warn(log::Verbosity::high, "could not open operator blacklist: " + std::string(blacklist_path)); | |||
| return true; | |||
| } | |||
| // Very basic analysis of the lines (no support for extra spaces, etc.) | |||
| std::string line; | |||
| while (getline(file, line)) { | |||
| // Support for "comments" in the blacklist file | |||
| if (line[0] == '#') continue; | |||
| if (kernel_name == line) return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool ConstrainSchedule::IsEnabled(void) { | |||
| // PassMgrStrategy already checks this and does not even register the pass | |||
| // if akg::ir::poly::UserConfig::GetEnableMindTrick() returns false. | |||
| // We check once more in case it was set to true at startup and the runtime | |||
| // later decided to change the value. | |||
| if (!scop_info_.user_config_.GetEnableMindTrick()) { | |||
| Info("ConstrainSchedule was disabled via akg::ir::poly::UserConfig"); | |||
| return false; | |||
| } | |||
| const char *const env_mind_tricks = std::getenv(env_string_mind_tricks_enable_); | |||
| if (env_mind_tricks && std::string(env_mind_tricks) == "false") { | |||
| Info("ConstrainSchedule was disabled via environment variable " + std::string(env_string_mind_tricks_enable_)); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| isl::schedule ConstrainSchedule::Run(isl::schedule sch) { | |||
| if (!IsEnabled()) return sch; | |||
| const std::string &target = scop_info_.user_config_.GetTarget(); | |||
| const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); | |||
| // We will want to restore the original value | |||
| const log::Verbosity saved_verbosity = log::GetVerbosityLevel(); | |||
| log::SetVerbosityLevel(static_cast<log::Verbosity>(verbosity_)); | |||
| isl::schedule result = sch; | |||
| // Make sure the constraints are available... | |||
| // We expect this pass to be right after InitSchedule: most information is | |||
| // initialized or computed in InitSchedule. However the constraints are | |||
| // usually computed in ComputeSchedule (which we may disable afterwards!). | |||
| pass_info_.constraints_ = MakeScheduleConstraints(sch, pass_info_); | |||
| // Check whether we want to use ConstrainSchedule on this kernel | |||
| const bool eligible = KernelIsEligible(sch); | |||
| if (!eligible) { | |||
| Warn("ConstrainSchedule will ignore this operator..."); | |||
| return sch; | |||
| } | |||
| const std::size_t total = mind_tricks_.size(); | |||
| RunInfo("input", kernel_name, sch); | |||
| std::stringstream summary; | |||
| summary << pass_name_ << " has " << total << " tricks up its sleeve"; | |||
| Info(log::Verbosity::low, summary); | |||
| CreateMindTrickTemplate(sch); | |||
| if (target == TARGET_CUDA) { | |||
| GpuCompilerFlagsTempfileRemove(); | |||
| } | |||
| size_t current = 0; | |||
| for (std::shared_ptr<SchedulingMindTrick> &mind_trick : mind_tricks_) { | |||
| const std::string name = mind_trick->GetName(); | |||
| current++; | |||
| if (static_cast<log::Verbosity>(verbosity_) >= log::Verbosity::low) { | |||
| std::stringstream stream; | |||
| stream << text_reverse text_magenta " SchedulingMindTrick " text_reset text_magenta " "; | |||
| stream << "[" << current << "/" << total << "] "; | |||
| stream << name; | |||
| const std::string &str = stream.str(); | |||
| Info(str, false); | |||
| } | |||
| const bool matches = mind_trick->Matches(sch); | |||
| if (!matches) { | |||
| Info(log::Verbosity::veryHigh, text_dim text_yellow + mind_trick->str()); | |||
| continue; | |||
| } | |||
| const bool needs_check = mind_trick->NeedsScheduleCheck(); | |||
| if (!needs_check) Info(log::Verbosity::veryLow, text_bright_yellow "MindTrick requests no schedule check!"); | |||
| const isl::schedule &candidate = mind_trick->Apply(sch); | |||
| const bool has_schedule = mind_trick->HasSchedule(); | |||
| if (!has_schedule) { | |||
| Warn(log::Verbosity::low, "'" + name + "': no schedule available"); | |||
| continue; | |||
| } | |||
| const bool valid = !needs_check || CheckSchedule(candidate); | |||
| if (valid) { | |||
| if (needs_check) Info(log::Verbosity::low, text_green "schedule is valid!"); | |||
| result = candidate; | |||
| ExtractMindTrickInfo(mind_trick); | |||
| LogMindTrick(mind_trick); | |||
| if (target == TARGET_CUDA) { | |||
| GpuCompilerFlagsTempfileCreate(mind_trick); | |||
| } | |||
| break; | |||
| } else { | |||
| Info(log::Verbosity::high, text_dim text_yellow + mind_trick->str()); | |||
| } | |||
| } | |||
| RunInfo("output", kernel_name, result); | |||
| log::SetVerbosityLevel(saved_verbosity); | |||
| return result; | |||
| } | |||
| void ConstrainSchedule::ExtractMindTrickInfo(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
| const std::string &target = scop_info_.user_config_.GetTarget(); | |||
| if (target == TARGET_CUDA) { | |||
| ExtractGpuConfig(mind_trick); | |||
| } | |||
| ExtractDisabledPasses(mind_trick); | |||
| ExtractAttrs(mind_trick); | |||
| } | |||
| void ConstrainSchedule::LogMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
| const std::string kernel_name = scop_info_.user_config_.GetKernelName(); | |||
| const std::string mind_trick_name = mind_trick->GetName(); | |||
| const std::string &output = mind_trick->str(); | |||
| Info(log::Verbosity::veryLow, text_reverse text_bright_blue " ConstrainSchedule ", false); | |||
| Info(log::Verbosity::veryLow, text_bright_blue "using schedule from \'" + mind_trick_name + "\'"); | |||
| Info(log::Verbosity::medium, text_dim text_green + mind_trick->str()); | |||
| scop_info_.user_config_.SetConstrainedSchedulingOutput(output); | |||
| } | |||
| void ConstrainSchedule::ExtractGpuConfig(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
| const std::string blocks = mind_trick->GetGpuBlocks(); | |||
| const std::string threads = mind_trick->GetGpuThreads(); | |||
| if (blocks != "" && threads != "") { | |||
| scop_info_.user_config_.SetBlockConfig(blocks); | |||
| scop_info_.user_config_.SetThreadConfig(threads); | |||
| } | |||
| } | |||
| void ConstrainSchedule::ExtractDisabledPasses(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
| // We always want to disable ComputeSchedule when using a mind_trick! | |||
| disabled_passes_.insert("ComputeSchedule"); | |||
| // Then maybe disable other passes... | |||
| const std::set<std::string> &passes = mind_trick->GetDisabledPasses(); | |||
| disabled_passes_.insert(passes.begin(), passes.end()); | |||
| } | |||
| void ConstrainSchedule::CreateMindTrickTemplate(const isl::schedule &sch) { | |||
| const char *const env_templates = std::getenv(env_string_mind_tricks_templates_); | |||
| if (!env_templates || std::string(env_templates) != "true") { | |||
| return; | |||
| } | |||
| const std::string kernel_name = scop_info_.user_config_.GetKernelName(); | |||
| const std::string &filename = "mindtrick-template_" + kernel_name + ".json"; | |||
| std::ofstream output(filename); | |||
| output << SchedulingMindTrick::TemplateString(scop_info_, sch); | |||
| output.close(); | |||
| } | |||
| void ConstrainSchedule::ExtractAttrs(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
| const air::Map<std::string, air::NodeRef> &attrs = mind_trick->GetAttrs(); | |||
| scop_info_.user_config_.SetAttrs(attrs); | |||
| } | |||
| void ConstrainSchedule::InitVerbosityLevel(void) { | |||
| #ifdef AKG_CONSTRAIN_SCHEDULE_VERBOSITY | |||
| { | |||
| constexpr int preprocessor_verbosity = AKG_CONSTRAIN_SCHEDULE_VERBOSITY; | |||
| if (preprocessor_verbosity >= 0) verbosity_ = preprocessor_verbosity; | |||
| } | |||
| #endif | |||
| { | |||
| const char *const env_verbosity_string = std::getenv(env_string_mind_tricks_verbosity_); | |||
| if (env_verbosity_string) { | |||
| const int env_verbosity = std::stoi(env_verbosity_string); | |||
| if (env_verbosity >= 0) verbosity_ = env_verbosity; | |||
| } | |||
| } | |||
| { | |||
| const int attrs_verbosity = scop_info_.user_config_.GetConstrainScheduleVerbosity(); | |||
| if (attrs_verbosity >= 0) verbosity_ = attrs_verbosity; | |||
| } | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // GPU Compiler flags | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| std::string ConstrainSchedule::GpuCompilerFlagsTempfileName(void) const { | |||
| std::stringstream filename_stream; | |||
| filename_stream << ".akg_gpu_compiler_flags_" << getpid(); | |||
| const std::string &filename = filename_stream.str(); | |||
| return filename; | |||
| } | |||
| void ConstrainSchedule::GpuCompilerFlagsTempfileRemove(void) { | |||
| const std::string &filename = GpuCompilerFlagsTempfileName(); | |||
| std::remove(filename.c_str()); | |||
| } | |||
| void ConstrainSchedule::GpuCompilerFlagsTempfileCreate(const std::shared_ptr<SchedulingMindTrick> &mind_trick) { | |||
| const std::vector<std::string> &flags = mind_trick->GetGpuCompilerFlags(); | |||
| if (flags.empty()) { | |||
| return; | |||
| } | |||
| const std::string &filename = GpuCompilerFlagsTempfileName(); | |||
| std::ofstream tempfile(filename); | |||
| for (const std::string &flag : flags) { | |||
| tempfile << flag << std::endl; | |||
| } | |||
| tempfile.close(); | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Logging | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| std::string ConstrainSchedule::LogPrefixText(const bool prefix) const { | |||
| if (!prefix) { | |||
| return ""; | |||
| } | |||
| const std::string &kernel_name = scop_info_.user_config_.GetKernelName(); | |||
| const std::string &prefix_text = "'" + kernel_name + "': "; | |||
| return prefix_text; | |||
| } | |||
| // clang-format off | |||
| #define define_constrain_schedule_log_wrappers(func) \ | |||
| void ConstrainSchedule::func(const std::string &message, const bool prefix) const { \ | |||
| const std::string &prefix_text = LogPrefixText(prefix); \ | |||
| log::func(prefix_text + message); \ | |||
| } \ | |||
| \ | |||
| void ConstrainSchedule::func(const std::stringstream &stream, const bool prefix) const { \ | |||
| const std::string &message = stream.str(); \ | |||
| func(message, prefix); \ | |||
| } \ | |||
| \ | |||
| void ConstrainSchedule::func(const int level, const std::string &message, const bool prefix) const { \ | |||
| const std::string &prefix_text = LogPrefixText(prefix); \ | |||
| log::func(level, prefix_text + message); \ | |||
| } \ | |||
| void ConstrainSchedule::func(const int level, const std::stringstream &stream, const bool prefix) const { \ | |||
| const std::string &message = stream.str(); \ | |||
| func(level, message, prefix); \ | |||
| } \ | |||
| void ConstrainSchedule::func(const log::Verbosity level, const std::string &message, const bool prefix) const { \ | |||
| const std::string &prefix_text = LogPrefixText(prefix); \ | |||
| log::func(level, prefix_text + message); \ | |||
| } \ | |||
| void ConstrainSchedule::func(const log::Verbosity level, const std::stringstream &stream, const bool prefix) const { \ | |||
| const std::string &message = stream.str(); \ | |||
| func(level, message, prefix); \ | |||
| } | |||
| define_constrain_schedule_log_wrappers(Info) | |||
| define_constrain_schedule_log_wrappers(Warn) | |||
| define_constrain_schedule_log_wrappers(Error) | |||
| #undef define_constrain_schedule_log_wrappers | |||
| // clang-format on | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef POLY_CONSTRAIN_SCHEDULE_H_ | |||
| #define POLY_CONSTRAIN_SCHEDULE_H_ | |||
| // STL headers | |||
| #include <vector> | |||
| #include <memory> | |||
| // Other libraries | |||
| #include "isl/cpp.h" | |||
| // AKG headers | |||
| #include "poly/log_util.h" | |||
| #include "poly/schedule_pass.h" | |||
| #include "poly/schedule_pass/scheduling_mind_trick.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| class ConstrainSchedule : public SchedulePass { | |||
| public: | |||
| ConstrainSchedule(PassInfo &pass_info, ScopInfo &scop_info); | |||
| ~ConstrainSchedule() {} | |||
| virtual isl::schedule Run(isl::schedule sch); | |||
| bool IsEnabled(void); | |||
| private: | |||
| bool CheckSchedule(const isl::schedule &) const; | |||
| bool KernelIsEligible(const isl::schedule &sch) const; | |||
| void LoadMindTricks(void); | |||
| void LoadMindTrickFromFile(const std::string &filename); | |||
| void AddMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| void ExtractMindTrickInfo(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| void LogMindTrick(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| void ExtractGpuConfig(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| void ExtractDisabledPasses(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| void ExtractAttrs(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| void CreateMindTrickTemplate(const isl::schedule &sch); | |||
| void InitVerbosityLevel(void); | |||
| std::string GpuCompilerFlagsTempfileName(void) const; | |||
| void GpuCompilerFlagsTempfileRemove(void); | |||
| void GpuCompilerFlagsTempfileCreate(const std::shared_ptr<SchedulingMindTrick> &mind_trick); | |||
| PassInfo &pass_info_; | |||
| ScopInfo &scop_info_; | |||
| std::vector<std::shared_ptr<SchedulingMindTrick>> mind_tricks_; | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| // MindTrick paths | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| static std::vector<std::string> MindTricksDirectories(void); | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| // Supported environment variables | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| static constexpr const char *const env_string_mind_tricks_enable_ = "MS_AKG_MIND_TRICKS"; | |||
| static constexpr const char *const env_string_mind_tricks_dir_ = "MS_AKG_MIND_TRICKS_DIR"; | |||
| static constexpr const char *const env_string_mind_tricks_verbosity_ = "MS_AKG_MIND_TRICKS_VERBOSITY"; | |||
| static constexpr const char *const env_string_mind_tricks_templates_ = "MS_AKG_MIND_TRICKS_TEMPLATES"; | |||
| static constexpr const char *const env_string_mind_tricks_operator_blacklist_ = | |||
| "MS_AKG_MIND_TRICKS_OPERATOR_BLACKLIST"; | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| // Logging | |||
| /////////////////////////////////////////////////////////////////////////// | |||
| int verbosity_{0}; | |||
| std::string LogPrefixText(const bool prefix = true) const; | |||
| // clang-format off | |||
| #define declare_constrain_schedule_log_wrappers(func) \ | |||
| void func(const std::string &message, const bool prefix = true) const; \ | |||
| void func(const std::stringstream &stream, const bool prefix = true) const; \ | |||
| void func(const int level, const std::string &message, const bool prefix = true) const; \ | |||
| void func(const int level, const std::stringstream &stream, const bool prefix = true) const; \ | |||
| void func(const akg::ir::poly::log::Verbosity level, const std::string &message, const bool prefix = true) const; \ | |||
| void func(const akg::ir::poly::log::Verbosity level, const std::stringstream &stream, const bool prefix = true) const; | |||
| declare_constrain_schedule_log_wrappers(Info) declare_constrain_schedule_log_wrappers(Warn) | |||
| declare_constrain_schedule_log_wrappers(Error) | |||
| #undef declare_constrain_schedule_log_wrappers | |||
| // clang-format on | |||
| }; | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| #endif // POLY_CONSTRAIN_SCHEDULE_H_ | |||
| @@ -0,0 +1,237 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "poly/schedule_pass/constrain_schedule.h" | |||
| #include "poly/isl_util.h" | |||
| #include "poly/log_util.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // local declarations | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| #ifndef _AKG_DUMMY_DATE_STRING | |||
| #define _AKG_DUMMY_DATE_STRING "_akg_dummy_date" | |||
| #endif | |||
| static const char *const _akg_dummy_date_string = _AKG_DUMMY_DATE_STRING; | |||
| // This should be changed later. | |||
| static int _verbosity = 2; | |||
| __isl_give isl_schedule *isl_schedule_constraints_silently_compute_schedule( | |||
| __isl_take isl_schedule_constraints *constraints); | |||
| __isl_give isl_stat map_statement_to_dummy_date(__isl_take isl_map *map, void *user); | |||
| __isl_give isl_union_map *isl_map_domain_to_dummy_date(__isl_take isl_union_set *domain); | |||
| __isl_give isl_union_map *isl_schedule_get_temporal_accesses(__isl_keep isl_schedule *schedule); | |||
| __isl_give isl_union_map *isl_schedule_get_temporal_dependences(__isl_keep isl_schedule *schedule); | |||
| __isl_give isl_schedule *isl_schedule_compute_verifying_schedule( | |||
| __isl_keep isl_schedule *const schedule, __isl_keep isl_schedule_constraints *const initial_constraints); | |||
| __isl_give isl_stat isl_schedule_check(__isl_keep isl_schedule *const schedule, | |||
| __isl_keep isl_schedule_constraints *const initial_constraints); | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Schedule checking functions in ConstrainSchedule public API | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| bool ConstrainSchedule::CheckSchedule(const isl::schedule &sch) const { | |||
| isl_schedule_constraints *const constraints = pass_info_.constraints_.get(); | |||
| isl_schedule *const schedule = sch.get(); | |||
| _verbosity = verbosity_; | |||
| const isl_stat status = isl_schedule_check(schedule, constraints); | |||
| return status == isl_stat_ok; | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // local definitions | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| /** | |||
| * \brief Silent wrapper for isl_schedule_constraints_compute_schedule(). | |||
| */ | |||
| __isl_give isl_schedule *isl_schedule_constraints_silently_compute_schedule( | |||
| __isl_take isl_schedule_constraints *const constraints) { | |||
| akg::ir::poly::log::Info(log::Verbosity::high, "silent constraints:\n" + to_block_string(constraints)); | |||
| isl_ctx *const ctx = isl_schedule_constraints_get_ctx(constraints); | |||
| const int previous_behaviour = isl_options_get_on_error(ctx); | |||
| isl_options_set_on_error(ctx, ISL_ON_ERROR_CONTINUE); | |||
| isl_schedule *const result = isl_schedule_constraints_compute_schedule(constraints); | |||
| isl_options_set_on_error(ctx, previous_behaviour); | |||
| return result; | |||
| } | |||
| /** | |||
| * \brief Callback for isl_map_list_foreach(); sets the out dim to _dummy_date. | |||
| */ | |||
| __isl_give isl_stat map_statement_to_dummy_date(__isl_take isl_map *map, void *const user) { | |||
| // The output dimension in the input map is empty. | |||
| // We want it to be '_akg_dummy_date[0]': | |||
| map = isl_map_insert_dims(map, isl_dim_out, 0, 1); | |||
| map = isl_map_fix_si(map, isl_dim_out, 0, 0); | |||
| map = isl_map_set_tuple_name(map, isl_dim_out, _akg_dummy_date_string); | |||
| isl_union_map **result = (isl_union_map **)user; | |||
| if (!*result) { | |||
| *result = isl_union_map_from_map(map); | |||
| } else { | |||
| *result = isl_union_map_add_map(*result, map); | |||
| } | |||
| return isl_stat_ok; | |||
| } | |||
| /** | |||
| * \brief Map statements to dummy dates | |||
| */ | |||
| __isl_give isl_union_map *isl_map_domain_to_dummy_date(__isl_take isl_union_set *const domain) { | |||
| isl_union_map *start = isl_union_map_from_domain(domain); | |||
| isl_map_list *list = isl_union_map_get_map_list(start); | |||
| isl_union_map *result = NULL; | |||
| isl_map_list_foreach(list, map_statement_to_dummy_date, &result); | |||
| isl_union_map_free(start); | |||
| isl_map_list_free(list); | |||
| return result; | |||
| } | |||
| __isl_give isl_union_map *isl_schedule_get_temporal_accesses(__isl_keep isl_schedule *const schedule) { | |||
| isl_union_set *const domain = isl_schedule_get_domain(schedule); | |||
| return isl_map_domain_to_dummy_date(domain); | |||
| } | |||
| __isl_give isl_union_map *isl_schedule_get_temporal_dependences(__isl_keep isl_schedule *const schedule) { | |||
| /* Prepare inputs for the union_access_info. */ | |||
| isl_union_map *const sink = isl_schedule_get_temporal_accesses(schedule); | |||
| isl_union_map *const source = isl_union_map_copy(sink); | |||
| isl_schedule *const schedule_copy = isl_schedule_copy(schedule); | |||
| /* Build the union_access_info. */ | |||
| isl_union_access_info *accesses = isl_union_access_info_from_sink(sink); | |||
| accesses = isl_union_access_info_set_schedule(accesses, schedule_copy); | |||
| accesses = isl_union_access_info_set_must_source(accesses, isl_union_map_copy(source)); | |||
| accesses = isl_union_access_info_set_kill(accesses, source); | |||
| /* Compute the flow. */ | |||
| isl_union_flow *const flow = isl_union_access_info_compute_flow(accesses); | |||
| /* Extract must dependences. */ | |||
| isl_union_map *const dependences = isl_union_flow_get_may_dependence(flow); | |||
| isl_union_flow_free(flow); | |||
| return dependences; | |||
| // isl_union_map* const dependences = isl_union_flow_get_full_may_dependence(flow); | |||
| // isl_union_flow_free(flow); | |||
| // /* | |||
| // * At this point, our maps (for instance, with S0 and S1) will be in form: | |||
| // * { S0[...] -> [S1[...] -> _akg_dummy_date[0]]: ...; ... } | |||
| // * We want maps in the form: | |||
| // * { S0[...] -> S1[...]: ...; ... } | |||
| // * We need to: | |||
| // * - uncurry the nested relation of the "range" of the map | |||
| // * (turning the "domain" of the map into a nested relation) | |||
| // * - unwrap the nested relation of the "domain" of the map | |||
| // */ | |||
| // /* Uncurry, ditch the "new" range and unwrap our dummy dependences. */ | |||
| // isl_union_map* const uncurried = isl_union_map_uncurry(dependences); | |||
| // isl_union_set* const domain = isl_union_map_domain(uncurried); | |||
| // isl_union_map* const unwrapped = isl_union_set_unwrap(domain); | |||
| // return unwrapped; | |||
| } | |||
| /** | |||
| * \brief Check a new schedule against previous constraints | |||
| * | |||
| * Bonus: can also be used to add permutable/coincident metadata and "reshape" | |||
| * a schedule tree (initial_constraints may be a null pointer). | |||
| */ | |||
| __isl_give isl_schedule *isl_schedule_compute_verifying_schedule( | |||
| __isl_keep isl_schedule *const schedule, __isl_keep isl_schedule_constraints *const initial_constraints) { | |||
| /* | |||
| * Principle: | |||
| * 1. extract "date" constraints from the proposed schedule | |||
| * 2. combine the date constraints with the initial constraints | |||
| * 3. attempt to schedule: | |||
| * - a schedule can be computed from the combined constraints: | |||
| * the proposed schedule is valid | |||
| * - no schedule can be computed from the combined constraints: | |||
| * the proposed schedule violates the initial constraints | |||
| */ | |||
| isl_union_map *const dates = isl_schedule_get_temporal_dependences(schedule); | |||
| /* Some logging. */ | |||
| if (initial_constraints) | |||
| akg::ir::poly::log::Info(log::Verbosity::high, "initial constraints:\n" + to_block_string(initial_constraints)); | |||
| else | |||
| akg::ir::poly::log::Info(log::Verbosity::high, "initial constraints: none"); | |||
| akg::ir::poly::log::Info(log::Verbosity::high, "dates:\n" + to_block_string(dates)); | |||
| isl_schedule_constraints *constraints = NULL; | |||
| if (!initial_constraints) { | |||
| /* | |||
| * Hack/Easter egg: the function will be used to "reshape" the schedule tree | |||
| */ | |||
| /* Build new schedule constraints. */ | |||
| isl_union_set *const domain = isl_schedule_get_domain(schedule); | |||
| constraints = isl_schedule_constraints_on_domain(domain); | |||
| constraints = isl_schedule_constraints_set_validity(constraints, dates); | |||
| } else { | |||
| /* Combine the schedule constraints. */ | |||
| constraints = isl_schedule_constraints_copy(initial_constraints); | |||
| isl_union_map *const validity = isl_schedule_constraints_get_validity(constraints); | |||
| isl_space *const space = isl_union_map_get_space(dates); | |||
| isl_union_map *const aligned = isl_union_map_align_params(validity, space); | |||
| isl_union_map *const restricted = isl_union_map_union(aligned, dates); | |||
| constraints = isl_schedule_constraints_set_validity(constraints, restricted); | |||
| } | |||
| /* Finally attempt to schedule. */ | |||
| isl_schedule *const result = isl_schedule_constraints_silently_compute_schedule(constraints); | |||
| return result; | |||
| } | |||
| __isl_give isl_stat isl_schedule_check(__isl_keep isl_schedule *const schedule, | |||
| __isl_keep isl_schedule_constraints *const initial_constraints) { | |||
| isl_schedule *const result = isl_schedule_compute_verifying_schedule(schedule, initial_constraints); | |||
| /* Check whether we managed to schedule something. */ | |||
| if (result) { | |||
| akg::ir::poly::log::Info(log::Verbosity::high, text_blue "schedule seems valid\n" + to_block_string(result)); | |||
| isl_schedule_free(result); | |||
| return isl_stat_ok; | |||
| } else { | |||
| akg::ir::poly::log::Warn(log::Verbosity::veryLow, "schedule is invalid"); | |||
| return isl_stat_error; | |||
| } | |||
| } | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -0,0 +1,303 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef POLY_SCHEDULING_MIND_TRICK_H_ | |||
| #define POLY_SCHEDULING_MIND_TRICK_H_ | |||
| // STL includes | |||
| #include <ostream> | |||
| #include <fstream> | |||
| #include <regex> | |||
| // External libraries | |||
| #include <picojson.h> | |||
| #include <isl/cpp.h> | |||
| // TVM | |||
| #include <tvm/node/container.h> | |||
| #include <tvm/node/node.h> | |||
| // Internal headers | |||
| #include "poly/isl.h" | |||
| #include "poly/log_util.h" | |||
| #include "poly/schedule_pass.h" | |||
| #include "poly/pass_info.h" | |||
| #include "poly/scop_info.h" | |||
| #include "poly/isl_influence.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| // Implementation notes | |||
| // | |||
| // 1. The class is non copyable. | |||
| // We use raw pointers to isl objects because the C++ wrapped objects are | |||
| // not practical in cases where the isl object is optional and std::optional | |||
| // currently is not an option. | |||
| // Hence, we directly manage the isl objects. To avoid further complications | |||
| // with copy-constructors, move-constructors, isl_*_free(), isl_*_copy(), | |||
| // etc., the class is non copyable. | |||
| // data_value: tuple of <scheduling dim, coeff dim, coeff type, list of stmt name> | |||
| // using data_value = std::tuple<int, int, isl_influence_coeff_type, std::vector<std::string>>; | |||
| // single_data: tuple of <stmt name, scheduling dim, coeff dim, coeff type, value> | |||
| using single_data = std::tuple<std::string, int, int, isl_influence_coeff_type, int>; | |||
| // div_mod_data | |||
| // * if division: tuple of <stmt name, scheduling dim, divisor> | |||
| // * if modulo: tuple of <stmt name, scheduling dim, modulo value> | |||
| using div_mod_data = std::tuple<std::string, int, int>; | |||
// Plain data holder grouping GPU mapping settings; every member is public.
// NOTE(review): the meaning of each field below is inferred from its name
// only -- confirm against the code that fills and reads this structure.
class GpuConfig {
 public:
  // Presumably the sizes of the GPU blocks -- TODO confirm.
  std::vector<int> block_sizes_;
  // Presumably the sizes of the GPU threads -- TODO confirm.
  std::vector<int> thread_sizes_;
  // Presumably the schedule dimensions mapped to blocks -- TODO confirm.
  std::vector<int> block_dimensions_;
  // Presumably the schedule dimensions mapped to threads -- TODO confirm.
  std::vector<int> thread_dimensions_;
  // Presumably extra flags handed to the GPU compiler -- TODO confirm.
  std::vector<std::string> compiler_flags_;
};
// Kind of a mind trick. `none` is explicitly 0 and `manual` takes the next
// enumerator value.
enum class MindTrickType {
  none = 0,
  manual,
};
// Conversions between MindTrickType and its textual form (defined elsewhere;
// the exact spellings are not visible from this header).
std::string to_string(MindTrickType t);
MindTrickType MindTrickTypeFromString(const std::string &str);
// A "mind trick": a user-provided description (parsed from JSON) that can be
// matched against a schedule and applied to it.
//
// The class holds raw isl pointers and is therefore non copyable (see the
// implementation notes at the top of this header). Only declarations are
// visible here; behavioural details live in the corresponding source files.
class SchedulingMindTrick {
 public:
  ///////////////////////////////////////////////////////////////////////////
  // Constructors and similar methods
  ///////////////////////////////////////////////////////////////////////////
  // Highly recommended to call the SchedulingMindTrick(PassInfo&, ScopInfo&)
  // constructor in other constructors and in derived classes constructors!
  // `verbosity` defaults to -1.
  // NOTE(review): -1 presumably means "use the default verbosity" -- confirm.
  SchedulingMindTrick(PassInfo &pass_info, ScopInfo &scop_info, int verbosity = -1);
  ~SchedulingMindTrick();
  // Load a mind trick from the file `filename`.
  // NOTE(review): presumably a JSON file forwarded to Parse() -- confirm.
  void Load(const std::string &filename);
  // Parse JSON representation, given as an already parsed picojson value,
  // as a serialized string, or as an input stream (the stream is returned).
  void Parse(const picojson::value &json);
  void Parse(const std::string &serialized_json);
  std::istream &Parse(std::istream &streamed_json);
  // Non copyable
  SchedulingMindTrick(const SchedulingMindTrick &) = delete;
  SchedulingMindTrick &operator=(const SchedulingMindTrick &) = delete;
  ///////////////////////////////////////////////////////////////////////////
  // MindTrick state
  ///////////////////////////////////////////////////////////////////////////
  // Explicit boolean conversion.
  // NOTE(review): presumably reflects correctly_parsed_ -- confirm.
  explicit operator bool() const;
  ///////////////////////////////////////////////////////////////////////////
  // MindTrick use
  ///////////////////////////////////////////////////////////////////////////
  // Tell whether this mind trick is applicable to the schedule `sch`.
  bool Matches(const isl::schedule &sch) const;
  // Apply the mind trick to `sch` and return the resulting schedule.
  isl::schedule Apply(const isl::schedule &sch);
  ///////////////////////////////////////////////////////////////////////////
  // I/O
  ///////////////////////////////////////////////////////////////////////////
  std::string str(void) const;
  std::ostream &Output(std::ostream &stream) const;
  std::istream &Input(std::istream &stream);
  friend std::ostream &operator<<(std::ostream &stream, const SchedulingMindTrick &mind_trick);
  friend std::istream &operator>>(std::istream &stream, SchedulingMindTrick &mind_trick);
  friend std::string to_string(const SchedulingMindTrick &mind_trick);
  // Build a template string from a scop and a schedule; `type` defaults to
  // MindTrickType::none.
  static std::string TemplateString(ScopInfo &scop_info, const isl::schedule &schedule,
                                    MindTrickType type = MindTrickType::none);
  ///////////////////////////////////////////////////////////////////////////
  // MindTrick metadata
  ///////////////////////////////////////////////////////////////////////////
  void SetName(const std::string &name);
  std::string GetName(void) const;
  std::string GetTarget(void) const;
  ///////////////////////////////////////////////////////////////////////////
  // Misc. attributes
  ///////////////////////////////////////////////////////////////////////////
  bool HasSchedule(void) const;
  bool NeedsScheduleCheck(void) const;
  const air::Map<std::string, air::NodeRef> GetAttrs(void) const;
  ///////////////////////////////////////////////////////////////////////////
  // GPU Mapping
  ///////////////////////////////////////////////////////////////////////////
  void GuessGpuConfig(void);
  std::string GetGpuBlocks(void) const;
  std::string GetGpuThreads(void) const;
  std::vector<std::string> GetGpuCompilerFlags(void) const;
  ///////////////////////////////////////////////////////////////////////////
  // Pass toggling
  ///////////////////////////////////////////////////////////////////////////
  // Names of the passes this mind trick asks to disable.
  const std::set<std::string> &GetDisabledPasses(void) const;
  ///////////////////////////////////////////////////////////////////////////
  // Mind Trick Type
  ///////////////////////////////////////////////////////////////////////////
  MindTrickType GetType(void) const;
  void SetType(MindTrickType type);

 protected:
  ///////////////////////////////////////////////////////////////////////////
  // JSON parsing
  ///////////////////////////////////////////////////////////////////////////
  // Look up `key` in a JSON node.
  // NOTE(review): the name suggests a missing key is tolerated; the value
  // returned in that case is not visible from this header -- confirm.
  static picojson::value maybe(const picojson::value &node, const std::string &key);
  static picojson::value maybe(const picojson::object &node, const std::string &key);
  // Convert a JSON node/array into a vector of ints or strings.
  static std::vector<int> to_int_vector(const picojson::value &node);
  static std::vector<int> to_int_vector(const picojson::array &node);
  static std::vector<std::string> to_string_vector(const picojson::value &node);
  static std::vector<std::string> to_string_vector(const picojson::array &node);
  // One parser per section of the JSON description.
  void ParseName(const picojson::value &node);
  void ParseDisabledPasses(const picojson::value &node);
  void ParsePattern(const picojson::value &node);
  void ParseOperatorName(const picojson::value &node);
  void ParseDomain(const picojson::value &node);
  void ParseSchedule(const picojson::value &node);
  void ParseGpuInfo(const picojson::value &node);
  void ParseExplicitTree(const picojson::value &node);
  void ParseCheckSchedule(const picojson::value &node);
  void ParseAttrs(const picojson::value &node);
  void ParseVerbosity(const picojson::value &node);
  void ParseSoftConstraints(const picojson::value &node);
  ///////////////////////////////////////////////////////////////////////////
  // Schedule
  ///////////////////////////////////////////////////////////////////////////
  isl::schedule GetSchedule(void) const;
  ///////////////////////////////////////////////////////////////////////////
  // Schedule suggestion
  ///////////////////////////////////////////////////////////////////////////
  void BuildSuggestedSchedule(void);
  // Various helpers to build the suggested schedule
  // Both helpers hand ownership of the returned schedule to the caller
  // (__isl_give); PrepareMappingOuterBand also consumes its input (__isl_take).
  __isl_give isl_schedule *ComputeScheduleSuggestion(void);
  __isl_give isl_schedule *PrepareMappingOuterBand(__isl_take isl_schedule *schedule, GpuConfig &info);
  ///////////////////////////////////////////////////////////////////////////
  // Soft constraints
  ///////////////////////////////////////////////////////////////////////////
  void CollectSoftConstraintsData(std::string stmt_name, unsigned int sched_dim, int coeff_dim,
                                  isl_influence_coeff_type coeff_type, std::string coeff_vec_i);
  void BuildSoftConstraints();
  void BuildInfluenceList(std::vector<single_data> singles);
  void BuildInfluenceEqualList(std::map<std::string, std::vector<single_data>> linked);
  void BuildInfluencedSchedule(void);
  void IslInfluenceToggle(bool toggle);
  // Consumes `schedule` (__isl_take) and returns a schedule owned by the
  // caller (__isl_give), given modulo and division data (see div_mod_data).
  __isl_give isl_schedule *AdjustSchedule(__isl_take isl_schedule *schedule, const std::vector<div_mod_data> &modulos,
                                          const std::vector<div_mod_data> &divisions);
  // Misc helpers
  std::vector<std::string> split_string(std::string str, std::string delim);
  ///////////////////////////////////////////////////////////////////////////
  // Internal data
  ///////////////////////////////////////////////////////////////////////////
  // AKG info
  // Held by reference: the referenced objects must outlive this instance.
  PassInfo &pass_info_;
  ScopInfo &scop_info_;
  bool correctly_parsed_{false};
  // Operator name and schedule pattern used by Matches(); empty by default.
  std::string operator_{""};
  std::string pattern_{""};
  // Raw isl objects; null until set.
  // NOTE(review): presumably released in the destructor -- confirm.
  isl_schedule *explicit_tree_{nullptr};
  isl_union_set *domain_{nullptr};
  isl_schedule *suggested_schedule_{nullptr};
  std::string suggested_schedule_string_{""};
  std::vector<std::string> suggested_schedule_vector_;
  std::vector<std::tuple<std::string, std::vector<int>>> post_transformations_;
  // Soft constraints data (see single_data and div_mod_data above).
  std::vector<single_data> singles_;
  std::map<std::string, std::vector<single_data>> linked_;
  std::vector<div_mod_data> modulos_;
  std::vector<div_mod_data> divisions_;
  isl_influence_list *influence_list_{nullptr};
  isl_influence_equal_list *influence_equal_list_{nullptr};
  std::string parse_soft_constraints_log_str_{""};
  isl_schedule *influenced_schedule_{nullptr};
  // Schedule checking is requested by default.
  bool check_schedule_{true};
  std::string target_{""};
  std::string name_{"unnamed mind_trick"};
  GpuConfig gpu_info_;
  air::Map<std::string, air::NodeRef> attrs_;
  std::set<std::string> disabled_passes_;
  MindTrickType type_{MindTrickType::manual};
  int verbosity_{0};
  std::string LogPrefixText(const bool prefix = true) const;
  // The macro below declares, for each of Info, Warn and Error, six const
  // overloads taking a message (std::string or std::stringstream), an optional
  // level (int or log::Verbosity) and a `prefix` flag defaulting to true.
  // clang-format off
#define declare_scheduling_mind_trick_log_wrappers(func)                                                              \
  void func(const std::string &message, const bool prefix = true) const;                                                \
  void func(const std::stringstream &stream, const bool prefix = true) const;                                           \
  void func(const int level, const std::string &message, const bool prefix = true) const;                               \
  void func(const int level, const std::stringstream &stream, const bool prefix = true) const;                          \
  void func(const akg::ir::poly::log::Verbosity level, const std::string &message, const bool prefix = true) const;     \
  void func(const akg::ir::poly::log::Verbosity level, const std::stringstream &stream, const bool prefix = true) const;
  declare_scheduling_mind_trick_log_wrappers(Info)
  declare_scheduling_mind_trick_log_wrappers(Warn)
  declare_scheduling_mind_trick_log_wrappers(Error)
#undef declare_scheduling_mind_trick_log_wrappers
 private :
  // Default construction is private: outside code has to go through the
  // (PassInfo&, ScopInfo&) constructor declared above.
  SchedulingMindTrick();
  // clang-format on
};
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| #endif // POLY_SCHEDULING_MIND_TRICK_H_ | |||
| @@ -0,0 +1,342 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "poly/schedule_pass/scheduling_mind_trick.h" | |||
| #include "poly/log_util.h" | |||
| namespace akg { | |||
| namespace ir { | |||
| namespace poly { | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Local declarations | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| static inline isl_schedule_node_type GetScheduleNodeType(const isl::schedule_node &node); | |||
| static inline std::vector<isl::schedule_node> CollectScheduleNodes(const isl::schedule &sch, const int verbosity); | |||
| static inline bool DomainMatches(const isl::schedule &pattern, const isl::schedule &candidate, const std::string &name_, | |||
| const int verbosity); | |||
| static inline bool NodesMatch(const isl::schedule_node &pattern_node, const isl::schedule_node &candidate_node, | |||
| const std::string &name_, const int verbosity); | |||
| static inline bool PatternMatches(const isl::schedule &sch, const std::string &pattern_, const std::string &name_, | |||
| const int verbosity); | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Pattern-matching related part of the SchedulingMindTrick public API | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| bool SchedulingMindTrick::Matches(const isl::schedule &sch) const { | |||
| if (operator_ != "" && pattern_ != "") { | |||
| Info(log::Verbosity::medium, "note that this mind trick contains both a pattern and an operator name..."); | |||
| } | |||
| if (PatternMatches(sch, pattern_, name_, verbosity_)) { | |||
| Info(log::Verbosity::low, text_green "schedule pattern matches!"); | |||
| return true; | |||
| } else if (operator_ == scop_info_.user_config_.GetKernelName()) { | |||
| Info(log::Verbosity::low, text_green "operator name matches!"); | |||
| return true; | |||
| } else if (pattern_ == "" && operator_ == "") { | |||
| Warn(log::Verbosity::veryLow, "pattern-free and operator-free mind_trick: matches! (are you sure?)"); | |||
| return true; | |||
| } else { | |||
| Info(log::Verbosity::medium, "schedule does not match!"); | |||
| return false; | |||
| } | |||
| } | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| // Local definitions | |||
| //////////////////////////////////////////////////////////////////////////////// | |||
| isl_schedule_node_type GetScheduleNodeType(const isl::schedule_node &node) { | |||
| isl_schedule_node_type type = isl_schedule_node_get_type(node.get()); | |||
| switch (type) { | |||
| case isl_schedule_node_band: | |||
| return isl_schedule_node_band; | |||
| case isl_schedule_node_context: | |||
| return isl_schedule_node_context; | |||
| case isl_schedule_node_domain: | |||
| return isl_schedule_node_domain; | |||
| case isl_schedule_node_extension: | |||
| return isl_schedule_node_extension; | |||
| case isl_schedule_node_filter: | |||
| return isl_schedule_node_filter; | |||
| case isl_schedule_node_guard: | |||
| return isl_schedule_node_guard; | |||
| case isl_schedule_node_mark: | |||
| return isl_schedule_node_mark; | |||
| case isl_schedule_node_leaf: | |||
| return isl_schedule_node_leaf; | |||
| case isl_schedule_node_sequence: | |||
| return isl_schedule_node_sequence; | |||
| case isl_schedule_node_set: | |||
| return isl_schedule_node_set; | |||
| default: | |||
| assert(false && "cannot convert the node type"); | |||
| return isl_schedule_node_leaf; | |||
| } | |||
| } | |||
| std::vector<isl::schedule_node> CollectScheduleNodes(const isl::schedule &sch, const int verbosity) { | |||
| std::vector<isl::schedule_node> nodes; | |||
| isl::schedule_node root = sch.get_root(); | |||
| if (GetScheduleNodeType(root.child(0)) == isl_schedule_node_leaf) { | |||
| log::Info(log::Verbosity::low, "Root node has no children; returning empty vector of schedule nodes"); | |||
| return nodes; | |||
| } else { | |||
| auto collect = [&nodes](isl::schedule_node node) { | |||
| nodes.push_back(node); | |||
| return true; | |||
| }; | |||
| root.foreach_descendant_top_down(collect); | |||
| } | |||
| return nodes; | |||
| } | |||
| bool DomainMatches(const isl::schedule &pattern, const isl::schedule &candidate, const std::string &name_, | |||
| const int verbosity) { | |||
| log::Info(log::Verbosity::high, name_ + ": checking if candidate schedule domain matches..."); | |||
| if (pattern.get_domain().is_empty()) { | |||
| log::Warn(log::Verbosity::high, | |||
| name_ + ": pattern's domain is completely unspecified therefore any domain matches"); | |||
| return true; | |||
| } | |||
| isl::set_list dom_patt_list = pattern.get_domain().get_set_list(); | |||
| isl::set_list dom_can_list = candidate.get_domain().get_set_list(); | |||
| if (dom_patt_list.size() != dom_can_list.size()) { | |||
| log::Warn(log::Verbosity::high, | |||
| name_ + ": candidate does not have the same number of statement as pattern. Does not match"); | |||
| return false; | |||
| } | |||
| for (unsigned i = 0; i < dom_patt_list.size(); ++i) { | |||
| isl::set patt_i = dom_patt_list.get_at(i); | |||
| isl::id name_i = patt_i.get_tuple_id(); | |||
| isl::space space_i = patt_i.get_space(); | |||
| isl::set can_j; | |||
| isl::id name_j; | |||
| isl::space space_j; | |||
| bool exists = false; | |||
| for (unsigned j = 0; j < dom_can_list.size(); ++j) { | |||
| can_j = dom_can_list.get_at(j); | |||
| name_j = can_j.get_tuple_id(); | |||
| space_j = can_j.get_space(); | |||
| if ((name_j.get_name() == name_i.get_name()) && (space_i.is_equal(space_j))) { | |||
| exists = true; | |||
| log::Info(log::Verbosity::high, | |||
| name_ + ": statement " + space_i.to_str() + " exists; proceeding with further checks..."); | |||
| isl::space space_i = patt_i.get_space(); | |||
| unsigned dim_num_i = space_i.dim(isl_dim_set); | |||
| for (unsigned k = 0; k < dim_num_i; ++k) { | |||
| // There may be cases where underspecified statements appears in the candidate tree | |||
| // ex: domain : "{S_0[]}" (variable initizalization or reductions) | |||
| // We must distinguish such cases from that of underspecified patterns (for which we bypass | |||
| // isl_exceptions in order to still capture their specification. | |||
| // The previous tests on spaces should ensure that, at this point of the code, | |||
| // only underspecified patterns can be encountered. | |||
| // But to be very sure about that, the following check eliminates any possibility of attempting | |||
| // pattern matching on a candidate that contains any S[] | |||
| isl_bool dim_id_check = isl_set_has_dim_id(can_j.get(), static_cast<enum isl_dim_type>(isl_dim_set), k); | |||
| if (!dim_id_check) { | |||
| log::Warn(log::Verbosity::high, | |||
| name_ + ": statement " + name_i.get_name() + | |||
| " has no dim id. Cannot perform pattern matching on an underspecified candidate"); | |||
| return false; | |||
| } | |||
| isl::id dim_name_j = | |||
| isl::manage(isl_set_get_dim_id(can_j.get(), static_cast<enum isl_dim_type>(isl_dim_set), k)); | |||
| isl::pw_aff min_j = can_j.dim_min(k); | |||
| isl::pw_aff max_j = can_j.dim_max(k); | |||
| std::stringstream log_prefix_stream; | |||
| log_prefix_stream << name_ << ": candidate bounds for dimension " << k << " of " << name_i.get_name(); | |||
| const std::string &log_prefix = log_prefix_stream.str(); | |||
| try { | |||
| isl::id dim_name_i = | |||
| isl::manage(isl_set_get_dim_id(patt_i.get(), static_cast<enum isl_dim_type>(isl_dim_set), k)); | |||
| isl::pw_aff min_i = patt_i.dim_min(k); | |||
| isl::pw_aff max_i = patt_i.dim_max(k); | |||
| if (dim_name_i.to_str() != "inf") { | |||
| if (!(min_i.is_equal(min_j) && max_i.is_equal(max_j))) { | |||
| log::Warn(log::Verbosity::high, | |||
| log_prefix + " are not equal to pattern bounds; domain does not match\n"); | |||
| return false; | |||
| } | |||
| } | |||
| log::Info(log::Verbosity::high, log_prefix + " match"); | |||
| } catch (isl::exception &) { | |||
| log::Info(log::Verbosity::high, log_prefix + " are undefined; any bound matches"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (!exists) { | |||
| log::Warn(log::Verbosity::high, | |||
| name_ + ": statement " + space_i.to_str() + " does not exist; domain does not match"); | |||
| return false; | |||
| } | |||
| } | |||
| log::Info(log::Verbosity::high, name_ + ": candidate domain matches\n"); | |||
| return true; | |||
| } | |||
| bool NodesMatch(const isl::schedule_node &pattern_node, const isl::schedule_node &candidate_node, | |||
| const std::string &name_, const int verbosity) { | |||
| isl_schedule_node_type pnode_type = GetScheduleNodeType(pattern_node); | |||
| isl_schedule_node_type cnode_type = GetScheduleNodeType(candidate_node); | |||
| // At the first mismatch, return false | |||
| if (pnode_type != cnode_type) { | |||
| return false; | |||
| } | |||
| // If encountering schedule node bands or filters, | |||
| // ensure that their specification is equal | |||
| if (pnode_type == isl_schedule_node_band) { | |||
| isl::multi_union_pw_aff pnode_bsched = isl::manage(isl_schedule_node_band_get_partial_schedule(pattern_node.get())); | |||
| isl::multi_union_pw_aff cnode_bsched = | |||
| isl::manage(isl_schedule_node_band_get_partial_schedule(candidate_node.get())); | |||
| if (!cnode_bsched.plain_is_equal(pnode_bsched)) { | |||
| log::Warn(log::Verbosity::high, name_ + "\' : schedule bands are not equal. Schedule tree does not match"); | |||
| return false; | |||
| } | |||
| } | |||
| if (pnode_type == isl_schedule_node_filter) { | |||
| isl::union_set pfilter_info = isl::manage(isl_schedule_node_filter_get_filter(pattern_node.get())); | |||
| isl::union_set cfilter_info = isl::manage(isl_schedule_node_filter_get_filter(candidate_node.get())); | |||
| if (!pfilter_info.is_empty()) { | |||
| isl::set_list plist = pfilter_info.get_set_list(); | |||
| isl::set_list clist = cfilter_info.get_set_list(); | |||
| for (unsigned i = 0; i < plist.size(); ++i) { | |||
| try { | |||
| isl::id pid = plist.get_at(i).get_tuple_id(); | |||
| if (!pid.is_null()) { | |||
| bool exists = false; | |||
| for (unsigned j = 0; j < clist.size(); ++j) { | |||
| isl::id cid = clist.get_at(j).get_tuple_id(); | |||
| if (!cid.is_null() && (pid.get_name() == cid.get_name())) { | |||
| exists = true; | |||
| } | |||
| } | |||
| if (!exists) { | |||
| log::Warn( | |||
| log::Verbosity::high, | |||
| "\'" + name_ + "\' candidate filter tuple id does not match with pattern (" + pid.get_name() + ")"); | |||
| return false; | |||
| } | |||
| } | |||
| } catch (isl::exception_invalid &) { | |||
| log::Info(log::Verbosity::high, "\'" + name_ + "\' : no tuple id specified, therefore any tuple id matches"); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // Return true if none of the previous tests failed | |||
| return true; | |||
| } | |||
| bool PatternMatches(const isl::schedule &sch, const std::string &pattern_, const std::string &name_, | |||
| const int verbosity) { | |||
| if (pattern_ != "") { | |||
| isl::schedule pattern = isl::schedule(sch.ctx(), pattern_); | |||
| if (!DomainMatches(pattern, sch, name_, verbosity)) { | |||
| return false; | |||
| } | |||
| log::Info(log::Verbosity::high, " \'" + name_ + "\' : collecting pattern nodes and candidate nodes..."); | |||
| std::vector<isl::schedule_node> pattern_nodes = CollectScheduleNodes(pattern, verbosity); | |||
| std::vector<isl::schedule_node> candidate_nodes = CollectScheduleNodes(sch, verbosity); | |||
| // Iterate according to the number of nodes in the pattern to ensure that | |||
| // the pattern matches at least part of the candidate schedule tree | |||
| if (pattern_nodes.size() == 0) { | |||
| log::Info( | |||
| log::Verbosity::high, | |||
| " \'" + name_ + "\' : no schedule tree pattern is specified, therefore any candidate schedule tree matches"); | |||
| return true; | |||
| } | |||
| // The pattern may not be found at the root of the candidate tree. | |||
| // We therefore first search for any nodes index that matches the first node | |||
| // of the pattern to capture starting points for checking if there is a match | |||
| std::vector<unsigned> entry_indexes; | |||
| for (unsigned i = 0; i < candidate_nodes.size(); ++i) { | |||
| if (NodesMatch(pattern_nodes[0], candidate_nodes[i], name_, verbosity)) { | |||
| entry_indexes.push_back(i); | |||
| } | |||
| } | |||
| if (entry_indexes.size() == 0) { | |||
| return false; | |||
| } | |||
| bool match = false; | |||
| for (unsigned k = 0; k < entry_indexes.size(); ++k) { | |||
| unsigned i; | |||
| for (i = 1; i < pattern_nodes.size() - 1; ++i) { | |||
| unsigned j = entry_indexes[k]; | |||
| bool check = NodesMatch(pattern_nodes[i], candidate_nodes[j + i], name_, verbosity); | |||
| if (!check) break; | |||
| } | |||
| // With the way we expect patterns to be specified, for now, there | |||
| // can be only one possible match if any. so as soon as we match | |||
| // something, we quit the loop over k. | |||
| if (i == pattern_nodes.size() - 1) { | |||
| match = true; | |||
| } | |||
| } | |||
| if (match == false) { | |||
| log::Warn(log::Verbosity::high, " \'" + name_ + "\' : the candidate schedule tree does not match"); | |||
| return false; | |||
| } | |||
| // If the code made it this far, then schedules match | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -202,6 +202,12 @@ isl::schedule MappingOuterBand::DoThreadMapping(const isl::schedule &sch) { | |||
| return node; | |||
| } | |||
| if (node.has_parent() && node.parent().isa<isl::schedule_node_mark>()) { | |||
| const std::string &marker = node.parent().as<isl::schedule_node_mark>().get_id().get_name(); | |||
| if (marker == "mind_trick_swizzle_marker") | |||
| return node; | |||
| } | |||
| size_t num_mapped_desc = NumMappedDescendant(thread_record, node); | |||
| if (CanBeMappedToThread(node, thread_record)) { | |||
| @@ -32,6 +32,13 @@ isl::schedule SharedMemoryManager::Run(isl::schedule sch) { | |||
| } | |||
| schedule_ = sch; | |||
| auto root = sch.get_root(); | |||
| // Update the variable/tensor to share | |||
| if (!scop_info_.user_config_.GetSharedTensors().empty()) { | |||
| configed_tensors_ = Split(scop_info_.user_config_.GetSharedTensors(), " "); | |||
| } | |||
| // Compute the depth where the shared memory have to be generated | |||
| UpdateDepth(root); | |||
| if (scop_info_.user_config_.GetSharedDepth() >= 0) { | |||
| depth_ = scop_info_.user_config_.GetSharedDepth(); | |||
| @@ -39,6 +46,8 @@ isl::schedule SharedMemoryManager::Run(isl::schedule sch) { | |||
| } | |||
| CHECK_GE(depth_, 0) << "shared depth should be greater than or equal with zero!"; | |||
| bank_conflict_ = scop_info_.user_config_.GetEnableBankConflict(); | |||
| shared_inversed_thread_map_ = scop_info_.user_config_.GetSharedInversedThreadMap(); | |||
| shared_vector_align_ = scop_info_.user_config_.GetSharedVectorAlign(); | |||
| // collect all bands at the given depth in the schedule tree | |||
| size_t remain_memory = share_memory_size_; | |||
| @@ -171,6 +180,31 @@ isl::schedule_node SharedMemoryManager::MapCopiesToThreads(isl::schedule_node &r | |||
| has_split = true; | |||
| } | |||
| if (shared_inversed_thread_map_) { | |||
| // Pretille - To make a vectorize loop more apparent with only the information of the mapping | |||
| const auto &domain = band_node.as<isl::schedule_node_band>().get_partial_schedule().domain(); | |||
| const isl::id &current_computing_id_shared = domain.unwrap().range().set_list().get_at(0).get_tuple_id(); | |||
| std::vector<size_t> tensor_size; | |||
| for (BufferDefInfo &buffer_info : scop_info_.analysis_result_.buffer_def_infos_) { | |||
| if (current_computing_id_shared == buffer_info.dst_tensor_id) { | |||
| tensor_size = buffer_info.sizes; | |||
| } | |||
| } | |||
| // Reverse because thread is innermost map | |||
| std::reverse(tensor_size.begin(),tensor_size.end()); | |||
| auto ctx = band_node.ctx(); | |||
| const auto &space = band_node.as<isl::schedule_node_band>().get_space(); | |||
| const auto n_member = band_node.as<isl::schedule_node_band>().n_member(); | |||
| isl::multi_val tile_size = isl::multi_val::zero(space); | |||
| for (size_t i = 0; i < n_member; ++i) { | |||
| const size_t size = tensor_size[i]/thread_cfg->GetAt(i).second; | |||
| tile_size = tile_size.set_val(n_member-1 - i, isl::val(ctx, size!=0?size:1)); | |||
| } | |||
| band_node = TileBand(band_node, tile_size); | |||
| } | |||
| auto mapping_cfg = thread_cfg; | |||
| if (scop_info_.user_config_.GetVectorLoadType() || scop_info_.user_config_.GetEnableTensorCoreUsePoly()) { | |||
| scop_info_.user_config_.SetEnableOneDimThread(true); | |||
| @@ -495,9 +529,7 @@ void SharedMemoryManager::GatherBufferFootprintDefInfo(const isl::schedule_node | |||
| sizes.back() += 16; | |||
| } | |||
| if (bank_conflict_) { | |||
| sizes = OptimizeBankConflict(sizes); | |||
| } | |||
| sizes = OptimizeSharedDimension(sizes); | |||
| isl::id tensor_id = tensor_info.tensor_id; | |||
| isl::id cluster_id = tensor_info.dst_tensor_id; | |||
| @@ -552,7 +584,7 @@ isl::schedule_node SharedMemoryManager::HoistClusters(const isl::schedule_node & | |||
| LOG(FATAL) << "Can not manage a scalar tensor"; | |||
| } | |||
| box_sizes = OptimizeBankConflict(box_sizes); | |||
| box_sizes = OptimizeSharedDimension(box_sizes); | |||
| auto approximation_size = std::accumulate(box_sizes.begin(), box_sizes.end(), 1, std::multiplies<size_t>()); | |||
| size_t byte = Bytes(id); | |||
| @@ -596,7 +628,7 @@ isl::schedule_node SharedMemoryManager::HoistToBlockThreadMemory(isl::schedule_n | |||
| isl::id dst_tensor_id = GpuDstId(type, tensor_id); | |||
| auto sizes = cluster.GetFixedBoxSizes(); | |||
| if (force_last_extension_odd) { | |||
| sizes = OptimizeBankConflict(sizes); | |||
| sizes = OptimizeSharedDimension(sizes); | |||
| } | |||
| auto res_node = PlaceOuterDataCopyBelow(scop_info_, tree, cluster, tensor_id, dst_tensor_id, out_schedule, | |||
| @@ -734,6 +766,13 @@ size_t SharedMemoryManager::Bytes(const isl::id tensor_id) { | |||
| return static_cast<size_t>(type.bytes()); | |||
| } | |||
| std::vector<size_t> SharedMemoryManager::OptimizeSharedDimension(std::vector<size_t> sizes) { | |||
| std::vector<size_t> res = sizes; | |||
| res = OptimizeBankConflict(res); | |||
| res = OptimizeVectorAlign(res); | |||
| return res; | |||
| } | |||
| std::vector<size_t> SharedMemoryManager::OptimizeBankConflict(std::vector<size_t> sizes) { | |||
| std::vector<size_t> res = sizes; | |||
| if (res.back() % 2 == 0) { | |||
| @@ -748,6 +787,15 @@ std::vector<size_t> SharedMemoryManager::OptimizeBankConflict(std::vector<size_t | |||
| return res; | |||
| } | |||
| std::vector<size_t> SharedMemoryManager::OptimizeVectorAlign(std::vector<size_t> sizes) { | |||
| std::vector<size_t> res = sizes; | |||
| if (shared_vector_align_ != 0) { | |||
| int padsize = res.back() % shared_vector_align_; | |||
| res.back() += padsize?(shared_vector_align_-padsize):0; | |||
| } | |||
| return res; | |||
| } | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| @@ -72,7 +72,9 @@ class SharedMemoryManager : public SchedulePass { | |||
| void UpdateDepth(const isl::schedule_node &root); | |||
| std::vector<size_t> OptimizeSharedDimension(std::vector<size_t> sizes); | |||
| std::vector<size_t> OptimizeBankConflict(std::vector<size_t> sizes); | |||
| std::vector<size_t> OptimizeVectorAlign(std::vector<size_t> sizes); | |||
| bool UnderThreadMarker(size_t depth); | |||
| std::string InAtomicTensors(isl::schedule_node &node); | |||
| @@ -102,10 +104,12 @@ class SharedMemoryManager : public SchedulePass { | |||
| std::string tensor_c_; | |||
| std::string tensor_a_; | |||
| std::string tensor_b_; | |||
| bool shared_inversed_thread_map_{false}; | |||
| int shared_vector_align_{0}; | |||
| }; | |||
| } // namespace poly | |||
| } // namespace ir | |||
| } // namespace akg | |||
| #endif | |||
| #endif | |||
| @@ -44,7 +44,17 @@ isl::schedule SchedulePassMgr::Run(const isl::schedule &sch, const std::vector<s | |||
| auto replace_sch = sch; | |||
| need_restart_ = false; | |||
| std::set<std::string> disabled; | |||
| for (auto &pass : passes) { | |||
| const std::string &name = pass->GetPassName(); | |||
| const bool disable = disabled.find(name) != disabled.end(); | |||
| if (disable) { | |||
| LOG(INFO) << "Disabling poly pass " << name; | |||
| continue; | |||
| } else { | |||
| LOG(INFO) << "Running poly pass " << name; | |||
| } | |||
| if (LoadScheduleTreeFromFile(scop_info_.AddDumpDir(pass->GetPassName() + ".txt"), replace_sch)) { | |||
| if (!replace_sch.plain_is_equal(final_sch)) { | |||
| final_sch = replace_sch; | |||
| @@ -67,6 +77,10 @@ isl::schedule SchedulePassMgr::Run(const isl::schedule &sch, const std::vector<s | |||
| need_restart_ = true; | |||
| break; | |||
| } | |||
| if (!pass->disabled_passes_.empty()) { | |||
| disabled.insert(pass->disabled_passes_.begin(), pass->disabled_passes_.end()); | |||
| LOG(INFO) << name << " requests to disable some subsequent passes"; | |||
| } | |||
| } | |||
| return final_sch; | |||
| } | |||
| @@ -179,6 +179,11 @@ class UserConfig { | |||
| ParseBoolAttr(attrs, "pragma_tilesize_is_var", &tile_size_is_var_); | |||
| ParseBoolAttr(attrs, "pragma_outerband_need_split", &outer_band_need_split_); | |||
| // Mind-trick pass | |||
| ParseIntAttr(attrs, "constrain_schedule_verbosity", &constrain_schedule_verbosity_); | |||
| ParseBoolAttr(attrs, "enable_mind_trick", &enable_mind_trick_); | |||
| ParseStringAttr(attrs, "mind_trick", &mind_trick_); | |||
| ParseStringAttr(attrs, "dim", &b_dim_); | |||
| ParseStringAttr(attrs, "bind_elem_per_thread", &elem_per_thread_); | |||
| ParseMappingCfgAttr(attrs, "bind_block", &block_cfg_); | |||
| @@ -241,6 +246,8 @@ class UserConfig { | |||
| ParseBoolAttr(attrs, "use_shared_memory", &use_shared_memory_); | |||
| ParseBoolAttr(attrs, "enable_bank_conflict_opt", &enable_bank_conflict_); | |||
| ParseBoolAttr(attrs, "enable_one_dim_thread", &enable_one_dim_thread_); | |||
| ParseBoolAttr(attrs, "shared_inversed_thread_map", &shared_inversed_thread_map_); | |||
| ParseIntAttr(attrs, "shared_vector_align", &shared_vector_align_); | |||
| ParseIntAttr(attrs, "register_memory_depth", &register_depth_); | |||
| ParseIntAttr(attrs, "shared_memory_depth", &shared_depth_); | |||
| ParseStringAttr(attrs, "shared_memory_tensors", &shared_tensors_); | |||
| @@ -305,6 +312,12 @@ class UserConfig { | |||
| void SetUnrollShared(const bool unroll_shared) { this->unroll_shared_ = unroll_shared; } | |||
| void SetDisableLoopFusion(const bool disable_loop_fusion) { this->disable_loop_fusion_ = disable_loop_fusion; } | |||
| // getter/setter for schedule_pass config | |||
| int GetConstrainScheduleVerbosity() const { return constrain_schedule_verbosity_; } | |||
| bool GetEnableMindTrick() const { return enable_mind_trick_; } | |||
| std::string GetMindTrick() const { return mind_trick_; } | |||
| void SetConstrainedSchedulingOutput(const std::string &output) { constrained_scheduling_output_ = output; } | |||
| // getter for schedule tree transform config | |||
| bool GetRemoveSelfDependence() const { return remove_self_dependence_; } | |||
| bool GetForceRemoveSelfDependence() const { return force_remove_self_dependence_; } | |||
| @@ -418,6 +431,12 @@ class UserConfig { | |||
| void SetEnableBankConflict(bool enable_bank_conflict) { enable_bank_conflict_ = enable_bank_conflict; } | |||
| bool GetEnableBankConflict() { return enable_bank_conflict_; } | |||
| int GetVectorLoadType() { return vector_load_type_; } | |||
| void SetSharedInversedThreadMap(bool shared_inversed_thread_map) { | |||
| shared_inversed_thread_map_ = shared_inversed_thread_map; | |||
| } | |||
| bool GetSharedInversedThreadMap() { return shared_inversed_thread_map_; } | |||
| void SetSharedVectorAlign(int shared_vector_align) { shared_vector_align_ = shared_vector_align; } | |||
| int GetSharedVectorAlign() { return shared_vector_align_; } | |||
| private: | |||
| // tools for parsing user config | |||
| @@ -576,6 +595,14 @@ class UserConfig { | |||
| int max_unroll_loop_{1}; | |||
| bool unroll_shared_{false}; | |||
| bool enable_bank_conflict_{false}; | |||
| bool shared_inversed_thread_map_{false}; | |||
| int shared_vector_align_{0}; | |||
| // schedule_pass/mind_trick config | |||
| int constrain_schedule_verbosity_{-1}; | |||
| bool enable_mind_trick_{true}; | |||
| std::string mind_trick_{""}; | |||
| std::string constrained_scheduling_output_{""}; | |||
| // schedule tree transform config | |||
| bool remove_self_dependence_{true}; | |||
| @@ -0,0 +1,3 @@ | |||
| opstest_*.log | |||
| *.cu | |||
| cuda_meta_* | |||
| @@ -64,7 +64,7 @@ from tests.operators.gpu.test_fused_bn_update_grad import test_fused_bn_update_g | |||
| from tests.operators.gpu.test_fused_mul_div_rsqrt_mul_isfinite_red import test_fused_mul_div_rsqrt_mul_isfinite_red | |||
| def add(poly_sch, fuzz_shape=None): | |||
| def add(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| if fuzz_shape: | |||
| test_ms_add(fuzz_shape, fuzz_shape, 'float32', poly_sch=poly_sch) | |||
| return | |||
| @@ -74,12 +74,12 @@ def add(poly_sch, fuzz_shape=None): | |||
| 'float32', poly_sch=poly_sch) | |||
| def addn(poly_sch, fuzz_shape=None): | |||
| def addn(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_addn((1, 1024, 1024), "float32", 2, poly_sch=poly_sch) | |||
| test_ms_addn((1, 1024, 1024), "float16", 2, poly_sch=poly_sch) | |||
| def bmm(poly_sch, fuzz_shape=None): | |||
| def bmm(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_bmm((768, 768), (768, 768), 'float16', 'float16', layout1='NHDT', layout2='NHDT', layout_out='NHDT', | |||
| shape_bias=(1, ), add_bias=False, tensor_core=True, poly_sch=poly_sch, | |||
| dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="6 6", bind_thread="32 4") | |||
| @@ -158,19 +158,19 @@ def bmm(poly_sch, fuzz_shape=None): | |||
| dim="0 0 128 128 0 1 128 128 0 2 64 4", bind_block="24 32", bind_thread="32 4") | |||
| """ | |||
| def cast(poly_sch, fuzz_shape=None): | |||
| def cast(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_cast((32, 32, 14, 14, 16), "float16", "float32", poly_sch=poly_sch) | |||
| test_ms_cast((32, 32, 14, 14, 16), "float32", "float16", poly_sch=poly_sch) | |||
| def exp(poly_sch, fuzz_shape=None): | |||
| def exp(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_exp((1024, 4096), 'float32', poly_sch=poly_sch) | |||
| test_ms_exp((1024, 4096), 'float16', poly_sch=poly_sch) | |||
| test_ms_exp((1024, 4095), 'float16', poly_sch=poly_sch) | |||
| test_ms_exp((1024, 799), 'float16', poly_sch=poly_sch) | |||
| def maximum(poly_sch, fuzz_shape=None): | |||
| def maximum(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_maximum((32, 1024, 1024), (32, 1024, 1024), | |||
| 'float32', poly_sch=poly_sch) | |||
| test_ms_maximum((32, 1024, 1024), (1, 1024, 1024), | |||
| @@ -179,7 +179,7 @@ def maximum(poly_sch, fuzz_shape=None): | |||
| 'float16', poly_sch=poly_sch) | |||
| def minimum(poly_sch, fuzz_shape=None): | |||
| def minimum(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_minimum((32, 1024, 1024), (32, 1024, 1024), | |||
| 'float32', poly_sch=poly_sch) | |||
| test_ms_minimum((32, 1024, 1024), (1, 1024, 1024), | |||
| @@ -188,33 +188,33 @@ def minimum(poly_sch, fuzz_shape=None): | |||
| 'float16', poly_sch=poly_sch) | |||
| def mul(poly_sch, fuzz_shape=None): | |||
| def mul(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_mul((1024, 4096), 'float32', poly_sch=poly_sch) | |||
| def divide(poly_sch, fuzz_shape=None): | |||
| def divide(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_divide((1024, 1024), 'float32', poly_sch=poly_sch) | |||
| test_ms_divide((1024, 1024), 'float16', poly_sch=poly_sch) | |||
| def reshape(poly_sch, fuzz_shape=None): | |||
| def reshape(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_reshape("float32", (64, 128, 1024), | |||
| (8192, 1024), poly_sch=poly_sch) | |||
| test_ms_reshape("float16", (64, 128, 1024), | |||
| (8192, 1024), poly_sch=poly_sch) | |||
| def rsqrt(poly_sch, fuzz_shape=None): | |||
| def rsqrt(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_rsqrt((32, 1024, 1024), 'float32', poly_sch=poly_sch) | |||
| test_ms_rsqrt((32, 1024, 1024), 'float16', poly_sch=poly_sch) | |||
| def sqrt(poly_sch, fuzz_shape=None): | |||
| def sqrt(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_sqrt((1024, 1024), "float32", poly_sch=poly_sch) | |||
| test_ms_sqrt((1024, 1024), "float16", poly_sch=poly_sch) | |||
| def sub(poly_sch, fuzz_shape=None): | |||
| def sub(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_sub((32, 1024, 1024), (32, 1024, 1024), | |||
| 'float32', poly_sch=poly_sch) | |||
| test_ms_sub((32, 1024, 1024), (32, 1024, 1024), | |||
| @@ -224,36 +224,36 @@ def sub(poly_sch, fuzz_shape=None): | |||
| test_ms_sub((4, 4, 4), (1, 4, 4), 'float32', poly_sch=poly_sch) | |||
| def tile(poly_sch, fuzz_shape=None): | |||
| def tile(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_tile((1024, 4096), (3,), 'float32', poly_sch=poly_sch) | |||
| test_ms_tile((1024, 4096), (3,), 'float16', poly_sch=poly_sch) | |||
| def one_hot(poly_sch, fuzz_shape=None): | |||
| def one_hot(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_one_hot((1024,), 16, "int32", 1, 0, 0, poly_sch=poly_sch) | |||
| test_ms_one_hot((1024,), 16, "float32", 1, 0, 0, poly_sch=poly_sch) | |||
| test_ms_one_hot((32,), 16, "int32", 1, 0, 0, poly_sch=poly_sch) | |||
| test_ms_one_hot((32,), 16, "float32", 1, 0, 0, poly_sch=poly_sch) | |||
| def expand_dims(poly_sch, fuzz_shape=None): | |||
| def expand_dims(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_expand_dims((32, 1024, 1024), 1, 'float32', poly_sch=poly_sch) | |||
| test_expand_dims((32, 1024, 1024), 2, 'float16', poly_sch=poly_sch) | |||
| def trans_data(poly_sch, fuzz_shape=None): | |||
| def trans_data(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_trans_data((8, 24, 38, 38), (0, 2, 1, 3), | |||
| 'float32', poly_sch=poly_sch) | |||
| test_ms_trans_data((8, 24, 38, 38), (0, 2, 1, 3), | |||
| 'float16', poly_sch=poly_sch) | |||
| def log(poly_sch, fuzz_shape=None): | |||
| def log(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_log((9, 1024, 1024), 'float16', poly_sch=poly_sch) | |||
| test_ms_log((9, 1024, 1024), 'float32', poly_sch=poly_sch) | |||
| def pow(poly_sch, fuzz_shape=None): | |||
| def pow(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_pow((9, 1024, 1024), (9, 1024, 1024), 'float32', poly_sch=poly_sch) | |||
| test_ms_pow((9, 1024, 1024), (9, 1024, 1), 'float32', poly_sch=poly_sch) | |||
| test_ms_pow((9, 1024, 1024), (9, 1, 1), 'float32', poly_sch=poly_sch) | |||
| @@ -264,7 +264,7 @@ def pow(poly_sch, fuzz_shape=None): | |||
| test_ms_pow((9, 1024, 1024), (1, 1, 1), 'float16', poly_sch=poly_sch) | |||
| def abs(poly_sch, fuzz_shape=None): | |||
| def abs(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_abs((1024, 1024), "float32", poly_sch=poly_sch) | |||
| test_ms_abs((1024, 1024), "float16", poly_sch=poly_sch) | |||
| test_ms_abs((1, ), "float32", poly_sch=poly_sch) | |||
| @@ -273,7 +273,7 @@ def abs(poly_sch, fuzz_shape=None): | |||
| test_ms_abs((1, 1), "float16", poly_sch=poly_sch) | |||
| def neg(poly_sch, fuzz_shape=None): | |||
| def neg(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_neg((1024, 1024), "float32", poly_sch=poly_sch) | |||
| test_ms_neg((1024, 1024), "float16", poly_sch=poly_sch) | |||
| test_ms_neg((1, ), "float32", poly_sch=poly_sch) | |||
| @@ -282,7 +282,7 @@ def neg(poly_sch, fuzz_shape=None): | |||
| test_ms_neg((1, 1), "float16", poly_sch=poly_sch) | |||
| def round(poly_sch, fuzz_shape=None): | |||
| def round(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_round((1024, 1024), "float32", poly_sch=poly_sch) | |||
| test_ms_round((1024, 1024), "float16", poly_sch=poly_sch) | |||
| test_ms_round((1, ), "float32", poly_sch=poly_sch) | |||
| @@ -291,7 +291,7 @@ def round(poly_sch, fuzz_shape=None): | |||
| test_ms_round((1, 1), "float16", poly_sch=poly_sch) | |||
| def reduce_sum(poly_sch, fuzz_shape=None): | |||
| def reduce_sum(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_reduce_sum((256, 256), 'float32', axis=(1,), | |||
| keepdims=True, poly_sch=poly_sch) | |||
| test_ms_reduce_sum((9, 1024, 1024), 'float32', axis=None, | |||
| @@ -304,31 +304,31 @@ def reduce_sum(poly_sch, fuzz_shape=None): | |||
| keepdims=True, poly_sch=poly_sch) | |||
| def select(poly_sch, fuzz_shape=None): | |||
| def select(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_select((2, ), (2, 2, 2), "int8", "float16", poly_sch=poly_sch) | |||
| def equal(poly_sch, fuzz_shape=None): | |||
| def equal(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_equal(((1, 1024), (1, 1024)), 'float16', poly_sch=poly_sch) | |||
| test_ms_equal(((1, 1024), (1, 1024)), 'float32', poly_sch=poly_sch) | |||
| def less_equal(poly_sch, fuzz_shape=None): | |||
| def less_equal(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_less_equal((1, 1024), (1, 1024), 'float16', poly_sch=poly_sch) | |||
| test_ms_less_equal((1, 1024), (1, 1024), 'float32', poly_sch=poly_sch) | |||
| def greater_equal(poly_sch, fuzz_shape=None): | |||
| def greater_equal(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_greater_equal((1, 1024), (1, 1024), 'float16', poly_sch=poly_sch) | |||
| test_ms_greater_equal((1, 1024), (1, 1024), 'float32', poly_sch=poly_sch) | |||
| def reciprocal(poly_sch, fuzz_shape=None): | |||
| def reciprocal(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_reciprocal((1, 1024), 'float16', poly_sch=poly_sch) | |||
| test_ms_reciprocal((1, 1024), 'float32', poly_sch=poly_sch) | |||
| def reduce_min(poly_sch, fuzz_shape=None): | |||
| def reduce_min(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_reduce_min((9, 1024, 1024), 'float32', axis=None, | |||
| keepdims=False, poly_sch=poly_sch) | |||
| test_ms_reduce_min((9, 1024, 1024), 'float16', axis=None, | |||
| @@ -339,7 +339,7 @@ def reduce_min(poly_sch, fuzz_shape=None): | |||
| keepdims=False, poly_sch=poly_sch) | |||
| def reduce_max(poly_sch, fuzz_shape=None): | |||
| def reduce_max(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_ms_reduce_max((9, 1024, 1024), 'float32', axis=None, | |||
| keepdims=False, poly_sch=poly_sch) | |||
| test_ms_reduce_max((9, 1024, 1024), 'float16', axis=None, | |||
| @@ -350,77 +350,77 @@ def reduce_max(poly_sch, fuzz_shape=None): | |||
| keepdims=False, poly_sch=poly_sch) | |||
| def fused_pad(poly_sch, fuzz_shape=None): | |||
| def fused_pad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_pad((7, 7, 3, 64), (0, 0, 0, 0), (0, 0, 1, 0), | |||
| layout='NHWC', pad_value=0.0, poly_sch=poly_sch) | |||
| def fused_bn_reduce(poly_sch, fuzz_shape=None): | |||
| def fused_bn_reduce(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_reduce((256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_bn_update(poly_sch, fuzz_shape=None): | |||
| test_fused_bn_update((2048,), poly_sch=poly_sch) | |||
| def fused_bn_update(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_update((2048,), poly_sch=poly_sch, mind_trick=mind_trick_str) | |||
| def fused_bn_follow_relu(poly_sch, fuzz_shape=None): | |||
| def fused_bn_follow_relu(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_follow_relu( | |||
| (256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_bn_follow_relu_avgpool(poly_sch, fuzz_shape=None): | |||
| def fused_bn_follow_relu_avgpool(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_follow_relu_avgpool( | |||
| (256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_bn_double_follow_relu(poly_sch, fuzz_shape=None): | |||
| def fused_bn_double_follow_relu(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_double_follow_relu( | |||
| (256, 7, 7, 2048), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_bn_reduce_grad(poly_sch, fuzz_shape=None): | |||
| def fused_bn_reduce_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_reduce_grad( | |||
| (256, 56, 56, 256), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_relu_grad_bn_reduce_grad(poly_sch, fuzz_shape=None): | |||
| def fused_relu_grad_bn_reduce_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_relu_grad_bn_reduce_grad( | |||
| (64, ), (256, 112, 112, 64), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_relu_grad_bn_double_reduce_grad(poly_sch, fuzz_shape=None): | |||
| def fused_relu_grad_bn_double_reduce_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_relu_grad_bn_double_reduce_grad( | |||
| (256,), (256, 56, 56, 256), layout="NHWC", poly_sch=poly_sch) | |||
| def fused_l2loss_grad(poly_sch, fuzz_shape=None): | |||
| test_fused_l2loss_grad((1, 1, 256, 1024), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_l2loss_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_l2loss_grad((1, 1, 256, 1024), layout='NHWC', poly_sch=poly_sch, mind_trick=mind_trick_str) | |||
| def fused_is_finite(poly_sch, fuzz_shape=None): | |||
| def fused_is_finite(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_is_finite((1, 1, 256, 1024), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_relu_grad_bn_update_grad(poly_sch, fuzz_shape=None): | |||
| def fused_relu_grad_bn_update_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_relu_grad_bn_update_grad( | |||
| (256, 112, 112, 64), (64,), layout="NHWC", poly_sch=poly_sch) | |||
| def fused_relu_grad_bn_double_update_grad(poly_sch, fuzz_shape=None): | |||
| def fused_relu_grad_bn_double_update_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_relu_grad_bn_double_update_grad( | |||
| (256, 56, 56, 256), (256, ), layout='NHWC', poly_sch=poly_sch) | |||
| def fused_relu_grad(poly_sch, fuzz_shape=None): | |||
| test_fused_relu_grad((256, 56, 56, 256), poly_sch=poly_sch) | |||
| def fused_relu_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_relu_grad((256, 56, 56, 256), poly_sch=poly_sch, mind_trick=mind_trick_str) | |||
| def fused_bn_update_grad(poly_sch, fuzz_shape=None): | |||
| def fused_bn_update_grad(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_bn_update_grad( | |||
| (256, 56, 56, 256), (256,), layout="NHWC", poly_sch=poly_sch) | |||
| def fused_mul_div_rsqrt_mul_isfinite_red(poly_sch, fuzz_shape=None): | |||
| def fused_mul_div_rsqrt_mul_isfinite_red(poly_sch, fuzz_shape=None, mind_trick_str=''): | |||
| test_fused_mul_div_rsqrt_mul_isfinite_red((64,), poly_sch=poly_sch) | |||
| @@ -490,11 +490,17 @@ if __name__ == '__main__': | |||
| sys.exit() | |||
| options, args = getopt.getopt( | |||
| sys.argv[1:], "f:", ["fuzz="]) | |||
| sys.argv[1:], "f:", ["fuzz=", "mind-trick-string=", "mind-trick-file="]) | |||
| mind_trick_str = '' | |||
| fuzz_dim = 0 | |||
| for name, value in options: | |||
| if name in ("-f", "--fuzz"): | |||
| fuzz_dim = int(value) | |||
| if name == "--mind-trick-string": | |||
| mind_trick_str = value | |||
| if name == "--mind-trick-file": | |||
| with open(value, 'r') as f: | |||
| mind_trick_str = f.read() | |||
| fail_op_list = dict() | |||
| run_op_list = list() | |||
| @@ -520,7 +526,7 @@ if __name__ == '__main__': | |||
| print("Fuzz shape: {}".format(fuzz_shape)) | |||
| try: | |||
| print("Time of auto schedule:") | |||
| op(poly_sch=True, fuzz_shape=fuzz_shape) | |||
| op(poly_sch=True, fuzz_shape=fuzz_shape, mind_trick_str=mind_trick_str) | |||
| except: | |||
| if op.__name__ in fail_op_list: | |||
| fail_op_list[op.__name__].extend( | |||
| @@ -535,4 +541,4 @@ if __name__ == '__main__': | |||
| for op, error_info in fail_op_list.items(): | |||
| print("Run op %s error" % op) | |||
| for e in error_info: | |||
| print(e) | |||
| print(e) | |||
| @@ -43,14 +43,17 @@ def compute_expect(input, c1, c2, c3, c4): | |||
| return [out1, out2, out3] | |||
| def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.001e-05, c3=1.00007975, c4=0.100000024, poly_sch=False): | |||
| def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.001e-05, c3=1.00007975, c4=0.100000024, poly_sch=False, mind_trick=''): | |||
| input = gen_data(shape, dtype) | |||
| expect = compute_expect(input, c1, c2, c3, c4) | |||
| attrs = [dtype, c1, c2, c3, c4] | |||
| shapes = [input[0].shape] * 4 | |||
| dtypes = [dtype] * 4 | |||
| if poly_sch: | |||
| mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, attrs={"target": "cuda"}) | |||
| if mind_trick: | |||
| mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, attrs={"target": "cuda", "mind_trick": mind_trick}) | |||
| else: | |||
| mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, attrs={"target": "cuda"}) | |||
| outputs = [np.full(shape, np.nan, dtype)] * 3 | |||
| attrs_list = input + outputs | |||
| @@ -64,4 +67,4 @@ def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.00 | |||
| data = to_tvm_nd_array(input) | |||
| expect = to_tvm_nd_array(expect) | |||
| gpu_profiling(mod, *data, *expect, 400) | |||
| gpu_profiling(mod, *data, *expect, 400) | |||
| @@ -36,7 +36,7 @@ def compute_py(data_f16, data_f32, layout, fill_data): | |||
| output = np.full(np.shape(expect), np.nan, 'float32') | |||
| return expect, output | |||
| def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False): | |||
| def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False, mind_trick=''): | |||
| data_1 = gen_data(shape, 'float16') | |||
| data_2 = gen_data(shape, 'float32') | |||
| @@ -45,7 +45,10 @@ def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False): | |||
| dtype_list = ['float16', 'float32'] | |||
| op_attrs = [layout, fill_data] | |||
| if poly_sch: | |||
| mod = utils.op_build_test(fused_l2loss_grad, input_list, dtype_list, kernel_name="fused_l2loss_grad", op_attrs=op_attrs, attrs={"target": "cuda"}) | |||
| if mind_trick: | |||
| mod = utils.op_build_test(fused_l2loss_grad, input_list, dtype_list, kernel_name="fused_l2loss_grad", op_attrs=op_attrs, attrs={"target": "cuda", "mind_trick": mind_trick}) | |||
| else: | |||
| mod = utils.op_build_test(fused_l2loss_grad, input_list, dtype_list, kernel_name="fused_l2loss_grad", op_attrs=op_attrs, attrs={"target": "cuda"}) | |||
| args = [data_1, data_2, output] | |||
| output = utils.mod_launch(mod, args, expect = expect) | |||
| @@ -58,4 +61,4 @@ def test_fused_l2loss_grad(shape, layout, fill_data=4e-05, poly_sch=False): | |||
| data = to_tvm_nd_array([data_1, data_2]) | |||
| expect = to_tvm_nd_array(expect) | |||
| gpu_profiling(mod, *data, expect, 400) | |||
| gpu_profiling(mod, *data, expect, 400) | |||
| @@ -38,7 +38,7 @@ def compute_expect(input, c1): | |||
| return np.where(cmp_zero, data_add, data_zero) | |||
| def test_fused_relu_grad(shape, c1=0, poly_sch=False): | |||
| def test_fused_relu_grad(shape, c1=0, poly_sch=False, mind_trick=''): | |||
| dtype='float16' | |||
| input = gen_data(shape, dtype) | |||
| expect = compute_expect(input, c1) | |||
| @@ -46,7 +46,10 @@ def test_fused_relu_grad(shape, c1=0, poly_sch=False): | |||
| dtypes = [dtype] * 3 | |||
| attrs = [c1] | |||
| if poly_sch: | |||
| mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", attrs={"target": "cuda"}) | |||
| if mind_trick: | |||
| mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", attrs={"target": "cuda", "mind_trick": mind_trick}) | |||
| else: | |||
| mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", attrs={"target": "cuda"}) | |||
| output = np.full(shape, np.nan, dtype) | |||
| output = utils.mod_launch(mod, (*input, output), expect=expect) | |||
| @@ -54,6 +54,12 @@ def print_usage(): | |||
| logging.info(template.format("", "Set attribute of 'bind_block' when use '-f' command.")) | |||
| logging.info(template.format("--bind_thread=<args>", "")) | |||
| logging.info(template.format("", "Set attribute of 'bind_thread' when use '-f' command.")) | |||
| logging.info(template.format("--mind-trick-enable=<0|1>", "")) | |||
| logging.info(template.format("", "explicitly enable (--mind-trick-enable=1) or disable (--mind-trick-enable=0) mind tricks")) | |||
| logging.info(template.format("--mind-trick-file", "")) | |||
| logging.info(template.format("", "json mind trick file")) | |||
| logging.info(template.format("--mind-trick-string", "")) | |||
| logging.info(template.format("", "json mind-trick string")) | |||
| logging.info("\n") | |||
| @@ -212,6 +218,7 @@ def main(argv): | |||
| try: | |||
| options, args = getopt.getopt(argv, "adcf:mh", ["auto", "manual", "ci", "profile", | |||
| "enable_atomic_add=", "dim=", "bind_block=", "bind_thread=", | |||
| "mind-trick-enable=", "mind-trick-file=", "mind-trick-string=", | |||
| "help"]) | |||
| poly = True | |||
| single_file = False | |||
| @@ -247,6 +254,13 @@ def main(argv): | |||
| attrs_list["bind_block"] = value | |||
| elif option == "--bind_thread": | |||
| attrs_list["bind_thread"] = value | |||
| elif option == "--mind-trick-enable": | |||
| attrs_list['enable_mind_trick'] = int(value) | |||
| elif option == "--mind-trick-file": | |||
| with open(value, 'r') as f: | |||
| attrs_list['mind_trick'] = f.read() | |||
| elif option == "--mind-trick-string": | |||
| attrs_list['mind_trick'] = value | |||
| except: | |||
| print_usage() | |||
| return | |||
| @@ -0,0 +1 @@ | |||
| {"composite":true,"composite_graph":"11380.22873","id":1054,"input_desc":[[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_0"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_1"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_2"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_3"}]],"op":"Fused_AddN_fusion_9584919353229493170","op_desc":[{"attr":[{"data_type":"listInt","name":"dyn_input_sizes","value":[4]},{"data_type":"int","name":"n","value":4}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_0"},{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_1"},{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_2"},{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_3"}]],"name":"AddN","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_0"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"output_0_0"}],"platform":"AKG","process":"cuda"} | |||
| @@ -0,0 +1 @@ | |||
| {"composite":true,"composite_graph":"11380.21701","id":929,"input_desc":[[{"data_type":"float32","format":"DefaultFormat","shape":[768],"tensor_name":"input_4"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_5"}],[{"data_type":"float32","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_0"}],[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"input_11"}]],"op":"Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231","op_desc":[{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_0"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_0"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"output_0_0"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[1],"tensor_name":"input_2","value":0.89990234375}]],"name":"LessEqual","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_1"}]},{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"output_0_1"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_2"}]},{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","
name":"input_0","shape":[768],"tensor_name":"input_4"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[768],"tensor_name":"output_0_3"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_5"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[768],"tensor_name":"output_0_3"}]],"name":"TensorAdd","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_4"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[1],"tensor_name":"input_7","value":1.111328125}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[4096,768],"tensor_name":"output_0_4"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_5"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"output_0_5"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[4096,768],"tensor_name":"output_0_2"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_6"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[4096,768],"tensor_name":"input_11"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[4096,768],"tensor_name":"output_0_6"}]],"name":"TensorAdd","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[4096,768],"tensor_name":"output_0_7"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[4096,768],"tensor_name":"output_0_2"},{"data_type":"float16","format":"De
faultFormat","shape":[4096,768],"tensor_name":"output_0_7"}],"platform":"AKG","process":"cuda"} | |||
| @@ -0,0 +1 @@ | |||
| {"composite":true,"composite_graph":"akg_decode_3954_20083.20131","id":925,"input_desc":[[{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_5"}],[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_0"}]],"op":"Fused_GkDropout_2353362030752466006","op_desc":[{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"input_0"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_0"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"output_0_0"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[1],"tensor_name":"input_2","value":0.89990234375}]],"name":"LessEqual","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_1"}]},{"attr":[{"data_type":"str","name":"dst_type","value":"float16"},{"data_type":"bool","name":"is_backed_cast","value":false}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"output_0_1"}]],"name":"Cast","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_2"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[1],"tensor_name":"input_4","value":1.111328125}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"input_5"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","s
hape":[32,12,128,128],"tensor_name":"output_0_3"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"output_0_3"}],[{"data_type":"float16","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"output_0_2"}]],"name":"Mul","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_4"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"output_0_4"},{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"output_0_2"}],"platform":"AKG","process":"cuda"} | |||
| @@ -0,0 +1 @@ | |||
| {"composite":true,"composite_graph":"25320.25320","id":696,"input_desc":[[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,1],"tensor_name":"input_1"}],[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_2"}],[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"input_0"}]],"op":"Fused_Sub_Mul_Mul_split_9258187064108063363","op_desc":[{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"input_0"}],[{"data_type":"float32","format":"DefaultFormat","name":"input_1","shape":[32,12,128,1],"tensor_name":"input_1"}]],"name":"Sub","output_desc":[{"data_type":"float32","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_0"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[32,12,128,128],"tensor_name":"input_2"}],[{"data_type":"float32","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"output_0_0"}]],"name":"Mul","output_desc":[{"data_type":"float32","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_1"}]},{"attr":null,"impl_path":"","input_desc":[[{"data_type":"float32","format":"DefaultFormat","name":"input_0","shape":[1],"tensor_name":"input_4","value":0.125}],[{"data_type":"float32","format":"DefaultFormat","name":"input_1","shape":[32,12,128,128],"tensor_name":"output_0_1"}]],"name":"Mul","output_desc":[{"data_type":"float32","format":"DefaultFormat","name":"output","shape":[32,12,128,128],"tensor_name":"output_0_2"}]}],"output_desc":[{"data_type":"float32","format":"DefaultFormat","shape":[32,12,128,128],"tensor_name":"output_0_2"}],"platform":"AKG","process":"cuda"} | |||
| @@ -0,0 +1 @@ | |||
| {"composite":true,"composite_graph":"27184.27184","id":926,"input_desc":[[{"data_type":"float16","format":"DefaultFormat","shape":[32,12,128,64],"tensor_name":"input_0"}]],"op":"Fused_Transpose_split_18185609042134105765","op_desc":[{"attr":[{"data_type":"listInt","name":"perm","value":[0,2,1,3]}],"impl_path":"","input_desc":[[{"data_type":"float16","format":"DefaultFormat","name":"input_0","shape":[32,12,128,64],"tensor_name":"input_0"}]],"name":"Transpose","output_desc":[{"data_type":"float16","format":"DefaultFormat","name":"output","shape":[32,128,12,64],"tensor_name":"output_0_0"}]}],"output_desc":[{"data_type":"float16","format":"DefaultFormat","shape":[32,128,12,64],"tensor_name":"output_0_0"}],"platform":"AKG","process":"cuda"} | |||
| @@ -0,0 +1,20 @@ | |||
| { | |||
| "name": "Fused_AddN_fusion_9584919353229493170_0", | |||
| "operator": "Fused_AddN_fusion_9584919353229493170_0", | |||
| "domain": "{ S_0[cc0] : 0 <= cc0 <= 3145727 }", | |||
| "schedule": [ | |||
| "{ S_0[cc0] -> [cc0/4096] }", | |||
| "{ S_0[cc0] -> [(cc0 mod 4096)/4] }", | |||
| "{ S_0[cc0] -> [(cc0 mod 4096) mod 4] }", | |||
| "{ S_0[cc0] -> [0] }" | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1], | |||
| "compiler flags": ["--use_fast_math"] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| { | |||
| "name": "Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806_0", | |||
| "operator": "Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806_0", | |||
| "domain": [ | |||
| "{ S_0[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 12582911 }", | |||
| "{ S_1[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 12582911 }" | |||
| ], | |||
| "schedule": [ | |||
| "{ S_0[ax0_ax1_fused] -> [ax0_ax1_fused/3072]; S_1[ax0_ax1_fused] -> [ax0_ax1_fused/3072] }", | |||
| "{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072)/4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072)/4] }", | |||
| "{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072) mod 4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 3072) mod 4] }", | |||
| "{ S_0[ax0_ax1_fused] -> [0]; S_1[ax0_ax1_fused] -> [1] }" | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| { | |||
| "name": "Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231_0", | |||
| "operator": "Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231_0", | |||
| "domain": [ | |||
| "{ S_0[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 3145727 }", | |||
| "{ S_1[ax0_ax1_fused] : 0 <= ax0_ax1_fused <= 3145727 }" | |||
| ], | |||
| "schedule": [ | |||
| "{ S_0[ax0_ax1_fused] -> [ax0_ax1_fused/768]; S_1[ax0_ax1_fused] -> [ax0_ax1_fused/768] }", | |||
| "{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768)/4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768)/4] }", | |||
| "{ S_0[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768) mod 4]; S_1[ax0_ax1_fused] -> [(ax0_ax1_fused mod 768) mod 4] }", | |||
| "{ S_0[ax0_ax1_fused] -> [0]; S_1[ax0_ax1_fused] -> [1] }" | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,57 @@ | |||
| { | |||
| "name": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692_0", | |||
| "operator": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692_0", | |||
| "domain": [ | |||
| "{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }", | |||
| "{ S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }", | |||
| "{ S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }" | |||
| ], | |||
| "pattern": "{ domain: \"{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071; S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071; S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 3071 }\", child: { sequence: [ { filter: \"{ S_0[ax0, ax1] }\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_1[ax0, ax1] }\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_2[ax0, ax1] }\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax1)] }]\" } } } ] } }", | |||
| "schedule": [ | |||
| "{ S_0[ax0, ax1] -> [ax0]; S_1[ax0, ax1] -> [ax0]; S_2[ax0, ax1] -> [ax0] }", | |||
| "{ S_0[ax0, ax1] -> [ax1/4]; S_1[ax0, ax1] -> [ax1/4]; S_2[ax0, ax1] -> [ax1/4] }", | |||
| "{ S_0[ax0, ax1] -> [ax1 mod 4]; S_1[ax0, ax1] -> [ax1 mod 4]; S_2[ax0, ax1] -> [ax1 mod 4] }", | |||
| "{ S_0[ax0, ax1] -> [0]; S_1[ax0, ax1] -> [1]; S_2[ax0, ax1] -> [2] }" | |||
| ], | |||
| "soft constraints": [ | |||
| { | |||
| "statement": "S_0", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 0]" | |||
| ] | |||
| }, | |||
| { | |||
| "statement": "S_1", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 1]" | |||
| ] | |||
| }, | |||
| { | |||
| "statement": "S_2", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 2]" | |||
| ] | |||
| } | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1], | |||
| "compiler flags": ["--use_fast_math"] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,57 @@ | |||
| { | |||
| "name": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484_0", | |||
| "operator": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484_0", | |||
| "domain": [ | |||
| "{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }", | |||
| "{ S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }", | |||
| "{ S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }" | |||
| ], | |||
| "pattern": "{ domain: \"{ S_0[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767; S_2[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767; S_1[ax0, ax1] : 0 <= ax0 <= 767 and 0 <= ax1 <= 767 }\", child: { sequence: [ { filter: \"{ S_0[ax0, ax1] }\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_1[ax0, ax1] }\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_2[ax0, ax1] }\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax1)] }]\" } } } ] } }", | |||
| "schedule": [ | |||
| "{ S_0[ax0, ax1] -> [ax0]; S_1[ax0, ax1] -> [ax0]; S_2[ax0, ax1] -> [ax0] }", | |||
| "{ S_0[ax0, ax1] -> [ax1/4]; S_1[ax0, ax1] -> [ax1/4]; S_2[ax0, ax1] -> [ax1/4] }", | |||
| "{ S_0[ax0, ax1] -> [ax1 mod 4]; S_1[ax0, ax1] -> [ax1 mod 4]; S_2[ax0, ax1] -> [ax1 mod 4] }", | |||
| "{ S_0[ax0, ax1] -> [0]; S_1[ax0, ax1] -> [1]; S_2[ax0, ax1] -> [2] }" | |||
| ], | |||
| "soft constraints": [ | |||
| { | |||
| "statement": "S_0", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 0]" | |||
| ] | |||
| }, | |||
| { | |||
| "statement": "S_1", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 1]" | |||
| ] | |||
| }, | |||
| { | |||
| "statement": "S_2", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 2]" | |||
| ] | |||
| } | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1], | |||
| "compiler flags": ["--use_fast_math"] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,57 @@ | |||
| { | |||
| "name": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880_0", | |||
| "operator": "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880_0", | |||
| "domain": [ | |||
| "{ S_0[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }", | |||
| "{ S_1[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }", | |||
| "{ S_2[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }" | |||
| ], | |||
| "pattern": "{ domain: \"{ S_0[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767; S_2[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767; S_1[ax0, ax1] : 0 <= ax0 <= 3071 and 0 <= ax1 <= 767 }\", child: { sequence: [ { filter: \"{ S_0[ax0, ax1] }\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_1[ax0, ax1] }\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_1[ax0, ax1] -> [(ax1)] }]\" } } }, { filter: \"{ S_2[ax0, ax1] }\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax0)] }]\", child: { schedule: \"[{ S_2[ax0, ax1] -> [(ax1)] }]\" } } } ] } }", | |||
| "schedule": [ | |||
| "{ S_0[ax0, ax1] -> [ax0]; S_1[ax0, ax1] -> [ax0]; S_2[ax0, ax1] -> [ax0] }", | |||
| "{ S_0[ax0, ax1] -> [ax1/4]; S_1[ax0, ax1] -> [ax1/4]; S_2[ax0, ax1] -> [ax1/4] }", | |||
| "{ S_0[ax0, ax1] -> [ax1 mod 4]; S_1[ax0, ax1] -> [ax1 mod 4]; S_2[ax0, ax1] -> [ax1 mod 4] }", | |||
| "{ S_0[ax0, ax1] -> [0]; S_1[ax0, ax1] -> [1]; S_2[ax0, ax1] -> [2] }" | |||
| ], | |||
| "soft constraints": [ | |||
| { | |||
| "statement": "S_0", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 0]" | |||
| ] | |||
| }, | |||
| { | |||
| "statement": "S_1", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 1]" | |||
| ] | |||
| }, | |||
| { | |||
| "statement": "S_2", | |||
| "meta": [2, 0], | |||
| "coefficients": [ | |||
| "[?1, 0, ?3]", | |||
| "[?1, 1, ?3] (/4)", | |||
| "[?1, 1, ?3] (%4)", | |||
| "[0, 0, 2]" | |||
| ] | |||
| } | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1], | |||
| "compiler flags": ["--use_fast_math"] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,19 @@ | |||
| { | |||
| "name": "Fused_GkDropout_2353362030752466006_0", | |||
| "operator": "Fused_GkDropout_2353362030752466006_0", | |||
| "domain": "{ S_0[cc0] : 0 <= cc0 <= 6291455; S_1[cc0] : 0 <= cc0 <= 6291455 }", | |||
| "schedule": [ | |||
| "{ S_0[cc0] -> [cc0/3072]; S_1[cc0] -> [cc0/3072] }", | |||
| "{ S_0[cc0] -> [(cc0 mod 3072)/4]; S_1[cc0] -> [(cc0 mod 3072)/4] }", | |||
| "{ S_0[cc0] -> [(cc0 mod 3072) mod 4]; S_1[cc0] -> [(cc0 mod 3072) mod 4] }", | |||
| "{ S_0[cc0] -> [0]; S_1[cc0] -> [1] }" | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,37 @@ | |||
| { | |||
| "name": "Fused_Transpose_split_18185609042134105765_0", | |||
| "operator": "Fused_Transpose_split_18185609042134105765_0", | |||
| "domain": "{ S_0[ax0, ax1, ax2, ax3] : 0 <= ax0 <= 31 and 0 <= ax1 <= 127 and 0 <= ax2 <= 11 and 0 <= ax3 <= 63 }", | |||
| "pattern": "{ domain: \"{ S_0[ax0, ax1, ax2, ax3] : 0 <= ax0 <= 31 and 0 <= ax1 <= 127 and 0 <= ax2 <= 11 and 0 <= ax3 <= 63 }\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax1)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax2)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax3)] }]\" } } } } }", | |||
| "schedule": [ | |||
| "{ S_0[ax0, ax1, ax2, ax3] -> [ax0] }", | |||
| "{ S_0[ax0, ax1, ax2, ax3] -> [ax1] }", | |||
| "{ S_0[ax0, ax1, ax2, ax3] -> [ax2] }", | |||
| "{ S_0[ax0, ax1, ax2, ax3] -> [ax3/4] }", | |||
| "{ S_0[ax0, ax1, ax2, ax3] -> [ax3 mod 4] }", | |||
| "{ S_0[ax0, ax1, ax2, ax3] -> [0] }" | |||
| ], | |||
| "soft constraints": [ | |||
| { | |||
| "statement": "S_0", | |||
| "meta": [4, 0], | |||
| "coefficients": [ | |||
| "[?, ?, ?, 0, ?]", | |||
| "[?, ?, ?, 0, ?]", | |||
| "[?, ?, ?, 0, ?]", | |||
| "[?, ?, ?, 1, ?] (/4)", | |||
| "[?, ?, ?, 1, ?] (%4)", | |||
| "[0, 0, 0, 0, 0]" | |||
| ] | |||
| } | |||
| ], | |||
| "gpu": { | |||
| "blocks": [0, 1], | |||
| "threads": [2, 3], | |||
| "compiler flags": ["--use_fast_math"] | |||
| }, | |||
| "disable": [ | |||
| "GpuDmaAnalysis", | |||
| "TileOuterBand" | |||
| ] | |||
| } | |||
| @@ -0,0 +1,7 @@ | |||
| { | |||
| "name": "fused_bn_update", | |||
| "operator": "fused_bn_update_auto_float32_2048_float32_2048_float32_2048_float32_2048_float32_7_971938775510203e_05_1_001e_05_1_00007975_0_100000024", | |||
| "pattern": "{ domain: \"{ S_2[ax0] : 0 <= ax0 <= 2047; S_0[i0] : 0 <= i0 <= 2047; S_1[ax0] : 0 <= ax0 <= 2047 }\", child: { sequence: [ { filter: \"{ S_0[i0] }\", child: { schedule: \"[{ S_0[i0] -> [(i0)] }]\" } }, { filter: \"{ S_1[ax0] }\", child: { schedule: \"[{ S_1[ax0] -> [(ax0)] }]\"} }, { filter: \"{ S_2[ax0] }\", child: { schedule: \"[{ S_2[ax0] -> [(ax0)] }]\" } } ] } }", | |||
| "tree": "{ \"domain\": \"{ S_2[ax0] : 0 <= ax0 <= 2047; S_0[i0] : 0 <= i0 <= 2047; S_1[ax0] : 0 <= ax0 <= 2047 }\", \"child\": { \"context\": \"[b0, t0] -> {[]: 0 <= b0 < 2 and 0 <= t0 < 256 }\", \"child\": { \"mark\": \"block_marker\", \"child\": { \"filter\": \"[b0] -> { S_0[i]: b0 = floor(i / 1024); S_1[i]: b0 = floor(i / 1024); S_2[i]: b0 = floor(i / 1024) }\", \"child\": { \"schedule\": \"[{ S_0[i] -> [i/1024]; S_1[i] -> [i/1024]; S_2[i] -> [i/1024] }]\", \"child\": { \"mark\": \"thread_marker\", \"child\": { \"filter\": \"[t0] -> { S_0[i]: t0 = floor((i / 4) mod 256); S_1[i]: t0 = floor((i / 4) mod 256); S_2[i]: t0 = floor((i / 4) mod 256) }\", \"child\": { \"schedule\": \"[{ S_0[i] -> [(i/4) mod 256]; S_1[i] -> [(i/4) mod 256]; S_2[i] -> [(i/4) mod 256] }]\", \"child\": { \"schedule\": \"[{S_0[i] -> [i mod 4]; S_1[i] -> [i mod 4]; S_2[i] -> [i mod 4] }]\", \"child\": { \"schedule\": \"[{ S_0[i] -> [2]; S_1[i] -> [1]; S_2[i] -> [0] }]\" } } } } } } } } } }", | |||
| "disable": [ "GpuDmaAnalysis", "TileOuterBand", "MappingOuterBand" ] | |||
| } | |||
| @@ -0,0 +1,34 @@ | |||
| { | |||
| "name": "fused_relu_grad", | |||
| "operator": "fused_relu_grad_auto_float16_256_56_56_256_float16_256_56_56_256_float16_256_56_56_256_0", | |||
| "pattern": "{ domain: \"{ S_0[ax0, ax1, ax2, ax3] : 0 <= ax0 <= 255 and 0 <= ax1 <= 55 and 0 <= ax2 <= 55 and 0 <= ax3 <= 255 }\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax0)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax1)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax2)] }]\", child: { schedule: \"[{ S_0[ax0, ax1, ax2, ax3] -> [(ax3)] }]\" } } } } }", | |||
| "disable": [ "GpuDmaAnalysis", "TileOuterBand" ], | |||
| "gpu": { | |||
| "blocks": [0], | |||
| "threads": [1] | |||
| }, | |||
| "domain": "{ S_0[i0, i1, i2, i3] : 0 <= i0 <= 255 and 0 <= i1 <= 55 and 0 <= i2 <= 55 and 0 <= i3 <= 255 }", | |||
| "schedule": [ | |||
| "{ S_0[i0, i1, i2, i3] -> [(i0 * 56 + i1) * 14 + i2 / 4] }", | |||
| "{ S_0[i0, i1, i2, i3] -> [(i2 mod 4) * 64 + (i3 / 4)] }", | |||
| "{ S_0[i0, i1, i2, i3] -> [i3 mod 4] }" | |||
| ], | |||
| "soft constraints": [ | |||
| { | |||
| "statement": "S_0", | |||
| "meta": [4, 0], | |||
| "coefficients": [ | |||
| "[1, 0, 0, 0, 0]", | |||
| "[0, 1, 0, 0, 0]", | |||
| "[0, 0, 1, 0, 0] (/4)", | |||
| "[0, 0, 1, 0, 0] (%4)", | |||
| "[0, 0, 0, 1, 0] (/4)", | |||
| "[0, 0, 0, 1, 0] (%4)" | |||
| ] | |||
| } | |||
| ], | |||
| "post": [ | |||
| { "collapse": [3, 4] }, | |||
| { "collapse": [0, 2] } | |||
| ] | |||
| } | |||
| @@ -62,6 +62,7 @@ from tests.st.ops.gpu.test_fused_relu_grad import test_fused_relu_grad | |||
| from tests.st.ops.gpu.test_fused_bn_update_grad import test_fused_bn_update_grad | |||
| from tests.st.ops.gpu.test_fused_mul_div_rsqrt_mul_isfinite_red import test_fused_mul_div_rsqrt_mul_isfinite_red | |||
| from tests.st.ops.gpu.test_ms_composite_stitch import test_composite_stitch | |||
| from tests.st.ops.gpu.test_ms_mindtricks import test_mindtricks | |||
| @pytest.mark.level0 | |||
| @@ -505,6 +506,14 @@ def test_ms_composite_buffer_stitch(): | |||
| return True | |||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ms_mindtricks():
    """Pytest entry point that runs the mind-trick test suite.

    Delegates to ``test_mindtricks`` with its default case selection; a failure
    surfaces as an exception raised by the delegate.

    Returns:
        True once the delegate has returned.
    """
    test_mindtricks()
    return True
| class Logger(object): | |||
| def __init__(self, filename, stream): | |||
| self.terminal = stream | |||
| @@ -44,15 +44,18 @@ def compute_expect(input, c1, c2, c3, c4): | |||
| def test_fused_bn_update(shape, dtype="float32", c1=(1 / (256 * 7 * 7)), c2=1.001e-05, c3=1.00007975, c4=0.100000024, | |||
| poly_sch=False): | |||
| poly_sch=False, mind_trick=''): | |||
| input = gen_data(shape, dtype) | |||
| expect = compute_expect(input, c1, c2, c3, c4) | |||
| attrs = [dtype, c1, c2, c3, c4] | |||
| shapes = [input[0].shape] * 4 | |||
| dtypes = [dtype] * 4 | |||
| if poly_sch: | |||
| mod_attrs = { "target": "cuda" } | |||
| if mind_trick: | |||
| mod_attrs["mind_trick"] = mind_trick | |||
| mod = utils.op_build_test(fused_bn_update, shapes, dtypes, kernel_name="fused_bn_update", op_attrs=attrs, | |||
| attrs={"target": "cuda"}) | |||
| attrs=mod_attrs) | |||
| outputs = [np.full(shape, np.nan, dtype)] * 3 | |||
| attrs_list = input + outputs | |||
| @@ -38,7 +38,7 @@ def compute_expect(input, c1): | |||
| return np.where(cmp_zero, data_add, data_zero) | |||
| def test_fused_relu_grad(shape, c1=0, poly_sch=False): | |||
| def test_fused_relu_grad(shape, c1=0, poly_sch=False, mind_trick=''): | |||
| dtype = 'float16' | |||
| input = gen_data(shape, dtype) | |||
| expect = compute_expect(input, c1) | |||
| @@ -46,8 +46,11 @@ def test_fused_relu_grad(shape, c1=0, poly_sch=False): | |||
| dtypes = [dtype] * 3 | |||
| attrs = [c1] | |||
| if poly_sch: | |||
| mod_attrs = { "target": "cuda" } | |||
| if mind_trick: | |||
| mod_attrs["mind_trick"] = mind_trick | |||
| mod = utils.op_build_test(fused_relu_grad, shapes, dtypes, op_attrs=attrs, kernel_name="fused_relu_grad", | |||
| attrs={"target": "cuda"}) | |||
| attrs=mod_attrs) | |||
| output = np.full(shape, np.nan, dtype) | |||
| output = utils.mod_launch(mod, (*input, output), expect=expect) | |||
| @@ -0,0 +1,159 @@ | |||
| #!/usr/bin/env python3 | |||
| # coding: utf-8 | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License | |||
| ######################################################################################################################## | |||
| # Basic libraries | |||
| import os | |||
| import sys | |||
| import json | |||
| import pytest | |||
| import logging | |||
| # For composite cases | |||
| from akg import composite | |||
| from akg.utils import kernel_exec as utils | |||
| from akg.utils.result_analysis import gpu_profiling | |||
| from akg.utils.format_transform import to_tvm_nd_array | |||
| from tests.common.gen_json_data import gen_json_data | |||
| from tests.common.base import get_rtol_atol | |||
| from tests.common.tensorio import compare_tensor | |||
| # Specific GPU tests | |||
| from tests.st.ops.gpu.test_fused_bn_update import test_fused_bn_update | |||
| from tests.st.ops.gpu.test_fused_relu_grad import test_fused_relu_grad | |||
| ######################################################################################################################## | |||
# Root directory of the mind-trick test material (relative path)
mind_trick_cases_dir = "./mind-trick_cases"

# Trick files are looked up in "<root>/tricks".
# Note: a composite operator only needs an entry in tricks_filenames when its trick file
# is not named "<operator>.json".
tricks_dir = "/".join([mind_trick_cases_dir, "tricks"])
tricks_filenames = dict(
    fused_bn_update="fused_bn_update.json",
    fused_relu_grad="fused_relu_grad.json",
)

# Composite operator descriptions are looked up in "<root>/operators"
composite_operators_dir = "/".join([mind_trick_cases_dir, "operators"])
composite_operators = [
    "Fused_AddN_fusion_9584919353229493170",
    "Fused_Cast_BiasAdd_Gelu_fusion_7719078727474100806",
    "Fused_Cast_BiasAdd_GkDropout_tuple_getitem_TensorAdd_fusion_13282325956852925231",
    "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1039082044534023692",
    "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1545859458890067484",
    "Fused_Cast_RealDiv_Reshape_FusedAdamWeightDecay_fusion_1976850843332086880",
    "Fused_GkDropout_2353362030752466006",
    "Fused_Transpose_split_18185609042134105765",
]
| ######################################################################################################################## | |||
| def _get_json_dict(desc): | |||
| return json.loads(desc) if isinstance(desc, str) else desc | |||
def _get_backend(desc):
    """Read the "process" entry of an operator description.

    Args:
        desc: operator description, as a JSON string or an already parsed dict.

    Returns:
        The value stored under "process", or None (after logging a message) when that key is missing.
    """
    description = _get_json_dict(desc)
    if "process" in description.keys():
        return description["process"]
    logging.info("Can't identify the backend.")
    return None
def _compare_func(output, expect):
    """Compare one output tensor with its expected value.

    The tolerances are obtained from ``get_rtol_atol`` for the "FUSED" category and the dtype of
    ``output``; the comparison itself is delegated to ``compare_tensor``.

    Returns:
        Whatever ``compare_tensor`` returns for the pair.
    """
    relative_tol, absolute_tol = get_rtol_atol("FUSED", str(output.dtype))
    return compare_tensor(output, expect, rtol=relative_tol, atol=absolute_tol)
def get_result(desc, poly, attrs=None):
    """Build a composite operator, run it and check its outputs against the expected data.

    Args:
        desc: operator description passed to ``composite.build`` and ``gen_json_data``.
        poly: forwarded to ``composite.build`` as its ``poly`` argument.
        attrs: optional build attributes; an empty or missing dict is forwarded as None.

    Returns:
        False when at least one output does not match its expected value (the source of the first
        imported module is logged in that case), True otherwise. When the description's backend
        is "cuda", the module is additionally passed to ``gpu_profiling`` before returning True.
    """
    backend = _get_backend(desc)

    # None and {} are both handed to the builder as None
    build_attrs = attrs or None
    mod = composite.build(desc, build_attrs, poly=poly)

    input_for_mod, expect, output_indexes = gen_json_data(desc)
    output = utils.mod_launch(mod, input_for_mod, output_indexes)

    # Normalise single results to one-element lists so they can be compared pairwise
    outputs = output if isinstance(output, (list, tuple)) else [output]
    expects = expect if isinstance(expect, (list, tuple)) else [expect]
    if not all(_compare_func(actual, wanted) for actual, wanted in zip(outputs, expects)):
        logging.info(mod.imported_modules[0].get_source())
        return False

    if backend == "cuda":
        inputs = to_tvm_nd_array(input_for_mod)
        expect = to_tvm_nd_array(expect)
        gpu_profiling(mod, *inputs, *expect, repeat_time=400)
    return True
def test_gpu_cases():
    """Run the GPU operator tests that have a dedicated mind-trick.

    Each trick file is read from ``tricks_dir`` and its content is passed to the operator test
    through the ``mind_trick`` argument, with ``poly_sch`` enabled.

    Returns:
        True once both operator tests have returned.
    """
    # (key in tricks_filenames, test function, input shape)
    gpu_cases = (
        ("fused_bn_update", test_fused_bn_update, (2048,)),
        ("fused_relu_grad", test_fused_relu_grad, (256, 56, 56, 256)),
    )
    for trick_key, run_case, shape in gpu_cases:
        with open(tricks_dir + "/" + tricks_filenames[trick_key], "r") as trick:
            run_case(shape, poly_sch=True, mind_trick=trick.read())
    return True
def test_single_composite_file(input_file, attrs, poly):
    """Run one composite operator description file through ``get_result``.

    Args:
        input_file: path of the file holding the operator description.
        attrs: build attributes forwarded to ``get_result``.
        poly: forwarded to ``get_result``.

    Raises:
        ValueError: when ``get_result`` reports a mismatch ("Precision Error").
    """
    with open(input_file, 'r') as desc_file:
        passed = get_result(desc_file.read(), poly, attrs)
        if not passed:
            logging.info("Precision Error")
            raise ValueError("Precision Error")
        logging.info("Run Pass!")
def test_composite_cases(operators=composite_operators):
    """Run every listed composite operator with its mind-trick applied.

    For each operator the description is searched in ``composite_operators_dir`` as
    "<operator>.info" first, then "<operator>.json". The trick is taken from ``tricks_filenames``
    when the operator has an entry there, otherwise from "<operator>.json" in ``tricks_dir``.
    An operator is run only when both files exist; otherwise it is skipped (a message is logged
    when no candidate path could be determined).

    Args:
        operators: iterable of operator names; defaults to ``composite_operators``.

    Returns:
        True once all operators have been processed.

    Raises:
        ValueError: propagated from ``test_single_composite_file`` on a precision mismatch.
    """
    for operator in operators:
        operator_path = ""
        if os.path.isfile(composite_operators_dir + "/" + operator + ".info"):
            operator_path = composite_operators_dir + "/" + operator + ".info"
        elif os.path.isfile(composite_operators_dir + "/" + operator + ".json"):
            operator_path = composite_operators_dir + "/" + operator + ".json"
        else:
            logging.info("could not find desc for operator: " + operator)

        trick_path = ""
        if operator in tricks_filenames:
            trick_path = tricks_dir + "/" + tricks_filenames[operator]
        elif os.path.isfile(tricks_dir + "/" + operator + ".json"):
            trick_path = tricks_dir + "/" + operator + ".json"
        else:
            logging.info("could not find trick for operator: " + operator)

        if os.path.isfile(operator_path) and os.path.isfile(trick_path):
            # Read the trick inside a context manager so the file handle is closed right away
            # instead of leaking one open handle per operator.
            with open(trick_path, "r") as trick:
                mind_trick = trick.read()
            attrs = {
                "target": "cuda",
                "mind_trick": mind_trick,
            }
            test_single_composite_file(operator_path, attrs, poly=True)
    return True
| ######################################################################################################################## | |||
| def test_mindtricks(cases=["gpu", "composite"]): | |||
| if "gpu" in cases: | |||
| test_gpu_cases() | |||
| if "composite" in cases: | |||
| test_composite_cases(composite_operators) | |||
| return True | |||
| ######################################################################################################################## | |||
| if __name__ == '__main__': | |||
| test_mindtricks() | |||
| @@ -20,8 +20,8 @@ if [ $# -eq 1 ] && [ $1 = "gpu_ci" ]; then | |||
| echo "Argument gpu_ci is used in CI and will be deprecated." | |||
| else | |||
| CUR_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" | |||
| AKG_DIR="${CUR_DIR}/.." | |||
| AKG_BUILD_DIR="${AKG_DIR}/build" | |||
| AKG_DIR="${AKG_DIR:-${CUR_DIR}/..}" | |||
| AKG_BUILD_DIR="${AKG_BUILD_DIR:-${AKG_DIR}/build}" | |||
| TVM_ROOT="${AKG_DIR}/third_party/incubator-tvm" | |||
| export LD_LIBRARY_PATH=${AKG_BUILD_DIR}:${LD_LIBRARY_PATH} | |||
| @@ -583,6 +583,8 @@ class Call : public ExprNode { | |||
| static constexpr const char* glsl_texture_store = "glsl_texture_store"; | |||
| static constexpr const char* prefetch = "prefetch"; | |||
| static constexpr const char* isnan = "isnan"; | |||
| static constexpr const char* reinterpret_cast_op = "reinterpret_cast_op"; | |||
| static constexpr const char* ldg = "ldg"; | |||
| /*! \brief Vectorizable intrinsic list. */ | |||
| static const char* vectorizable_intrinsics[]; | |||
| @@ -1107,7 +1109,9 @@ enum class ForType : int { | |||
| /*! \brief Vector SIMD loop annotation. */ | |||
| Vectorized = 2, | |||
| /*! \brief Unroll annotation. */ | |||
| Unrolled = 3 | |||
| Unrolled = 3, | |||
| /*! \brief Swizzle loop on GPU. */ | |||
| Swizzled = 4 | |||
| }; | |||
| // Device api of for loop | |||
| @@ -92,6 +92,21 @@ def compile_cuda(code, | |||
| file_target = path_target if path_target else temp_target | |||
| cmd = ["nvcc"] | |||
| cmd += ["--%s" % target, "-O3"] | |||
| mind_trick_gpu_flags = ".akg_gpu_compiler_flags_{}".format(os.getpid()) | |||
| if os.path.isfile(mind_trick_gpu_flags): | |||
| with open(mind_trick_gpu_flags, 'r') as f: | |||
| # This file should contain one flag per line | |||
| for line in f: | |||
| # Ensure we only use flags we explicitly allow and avoid arbitrary flags users may try to inject | |||
| flag = line.rstrip() | |||
| if flag == "-use_fast_math": | |||
| cmd += flag | |||
| try: | |||
| os.remove(mind_trick_gpu_flags) | |||
| except: | |||
| warnings.warn("Could not delete file {}".format(mind_trick_gpu_flags)) | |||
| if isinstance(arch, list): | |||
| cmd += arch | |||
| else: | |||
| @@ -60,6 +60,7 @@ | |||
| #include <cmath> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <tvm/ir_pass.h> | |||
| #include "literal/cuda_half_t.h" | |||
| #include "codegen_cuda.h" | |||
| @@ -130,6 +131,33 @@ std::string CodeGenCUDA::Finish() { | |||
| } | |||
| } | |||
| // TODO add condition | |||
| decl_stream << "// built-in for half swizzle\n" | |||
| "#include <cuda_fp16.h>\n" | |||
| "struct __device_builtin__ __align__(8) half4 { half x, y, z, w; };\n" | |||
| "\n" | |||
| "#if defined(__CUDACC_RTC__)\n" | |||
| "#define __CUDA_FP16_DECL__ __host__ __device__\n" | |||
| "#else\n" | |||
| "#define __CUDA_FP16_DECL__ static __device__ __inline__\n" | |||
| "#endif\n" | |||
| "\n" | |||
| "// half4 ldg function support\n" | |||
| "#if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)\n" | |||
| "#define __LDG_PTR \"l\"\n" | |||
| "#else\n" | |||
| "// not sure about this one, it was copied from the half2 ldg() function\n" | |||
| "#define __LDG_PTR \"r\"\n" | |||
| "#endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/\n" | |||
| "\n" | |||
| "#define __HALF4_TO_UI(var) *(reinterpret_cast<unsigned long *>(&(var)))\n" | |||
| "__CUDA_FP16_DECL__ half4 __ldg(const half4 *ptr)\n" | |||
| "{\n" | |||
| " half4 ret;\n" | |||
| " asm (\"ld.global.nc.b64 %0, [%1];\" : \"=l\"(__HALF4_TO_UI(ret)) : __LDG_PTR(ptr));\n" | |||
| " return ret;\n" | |||
| "}\n\n"; | |||
| return CodeGenC::Finish(); | |||
| } | |||
| @@ -138,6 +166,11 @@ void CodeGenCUDA::VisitStmt_(const ir::For* op) { | |||
| PrintIndent(); | |||
| stream << "#pragma unroll\n"; | |||
| } | |||
| else if (op->for_type == ir::ForType::Swizzled) { | |||
| // remove this loop | |||
| PrintStmt(op->body); | |||
| return; | |||
| } | |||
| CodeGenC::VisitStmt_(op); | |||
| } | |||
| @@ -163,7 +196,11 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) | |||
| os << "half"; | |||
| } else if (lanes <= 8) { | |||
| CHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; | |||
| os << "float" << lanes / 2; | |||
| if (lanes <= 4) { // added for swizzle | |||
| os << "half" << lanes; | |||
| } else { | |||
| os << "float" << lanes / 2; | |||
| } | |||
| } else { | |||
| fail = true; | |||
| } | |||
| @@ -291,6 +328,7 @@ void CodeGenCUDA::PrintVecElemLoad( | |||
| } | |||
| } | |||
| void CodeGenCUDA::PrintVecElemStore( | |||
| const std::string& vec, Type t, int i, const std::string& value) { | |||
| this->PrintIndent(); | |||
| @@ -516,11 +554,119 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { | |||
| return; | |||
| } | |||
| CodeGenC::VisitExpr_(op, os); | |||
| } else if (op->is_intrinsic(Call::reinterpret_cast_op)) { | |||
| os << "*(reinterpret_cast<"; | |||
| PrintType(op->args[1].type(), os); | |||
| if (op->args[0].as<IntImm>()->value > 0) os << op->args[0]; | |||
| os << "*>(&"; | |||
| auto ld = op->args[1].as<Load>(); | |||
| if (ld) { | |||
| auto var_name = ld->buffer_var->name_hint; | |||
| if (std::find_if(vec_loads.begin(), vec_loads.end(), | |||
| [var_name](const Variable *v) { return (v->name_hint == var_name); }) != vec_loads.end()) { | |||
| os << "sw_" + var_name; | |||
| } else this->PrintExpr(op->args[1], os); | |||
| } else this->PrintExpr(op->args[1], os); | |||
| os << "))"; | |||
| } else if (op->is_intrinsic(Call::ldg)) { | |||
| os << "__ldg(("; | |||
| PrintType(op->args[1].type(), os); | |||
| os << this->PrintExpr(op->args[0]) << " *)(&" << this->PrintExpr(op->args[1]) << "))"; | |||
| } else { | |||
| CodeGenC::VisitExpr_(op, os); | |||
| } | |||
| } | |||
| // Prints a LetStmt. If the no_init_value flag is set, it is cleared and only a | |||
| // bare declaration of the variable (no bound value) is emitted before the body; | |||
| // otherwise the statement is handled by CodeGenC. | |||
| void CodeGenCUDA::VisitStmt_(const LetStmt* op) { | |||
| if (!no_init_value) { | |||
| CodeGenC::VisitStmt_(op); | |||
| return; | |||
| } | |||
| // the flag applies to a single LetStmt: consume it before printing | |||
| no_init_value = false; | |||
| const Variable* let_var = op->var.get(); | |||
| PrintIndent(); | |||
| const bool typed_handle = op->var.type() == Handle() && handle_data_type_.count(let_var); | |||
| if (typed_handle) { | |||
| // handle with a recorded element type: declare it as a pointer to that type | |||
| PrintType(handle_data_type_.at(let_var), stream); | |||
| stream << "* "; | |||
| } else { | |||
| PrintType(op->var.type(), stream); | |||
| stream << ' '; | |||
| } | |||
| stream << AllocVarID(let_var) << ";\n"; | |||
| PrintStmt(op->body); | |||
| } | |||
| // Prints a variable. While replace_cce is set, occurrences of loop_var are | |||
| // printed as the constant current_index; every other variable goes to CodeGenC. | |||
| void CodeGenCUDA::VisitExpr_(const Variable* op, std::ostream& os) { | |||
| // loop_var is only compared when replace_cce is set | |||
| const bool substitute_index = replace_cce && loop_var == op; | |||
| if (!substitute_index) { | |||
| CodeGenC::VisitExpr_(op, os); | |||
| return; | |||
| } | |||
| os << current_index; | |||
| } | |||
| // Prints a Load. Outside a vector store this defers to CodeGenC. Inside one | |||
| // (vec_store set) the load is printed as a single component selected by | |||
| // current_index, or as a plain indexed access for other buffers. | |||
| void CodeGenCUDA::VisitExpr_(const Load* op, std::ostream& os) { | |||
| if (!vec_store) { | |||
| CodeGenC::VisitExpr_(op, os); | |||
| return; | |||
| } | |||
| static const char component[] = {'x', 'y', 'z', 'w', 'a', 'b', 'c', 'd'}; | |||
| const std::string& buf_name = op->buffer_var->name_hint; | |||
| const int lanes = op->type.lanes(); | |||
| if (lanes == 2 || lanes == 4) { | |||
| // 2- or 4-lane load: pick the component straight from the buffer variable | |||
| os << buf_name << "." << component[current_index]; | |||
| return; | |||
| } | |||
| // buffers registered through "vec_load" are matched by name and read via their sw_ alias | |||
| const bool is_vec_load = std::any_of(vec_loads.begin(), vec_loads.end(), | |||
| [&buf_name](const Variable* v) { return v->name_hint == buf_name; }); | |||
| if (is_vec_load) { | |||
| os << "sw_" << buf_name << "." << component[current_index]; | |||
| return; | |||
| } | |||
| // temp variable | |||
| os << buf_name << "["; | |||
| PrintExpr(op->index, os); | |||
| os << "]"; | |||
| } | |||
| void CodeGenCUDA::VisitStmt_(const Store* op) { /* Prints a Store in one of four forms chosen by the is_reinterpret, vec_store and simple_store flags: reinterpret_cast store, per-component vector store, plain "name = value", or the CodeGenC default. */ | |||
| Type t = op->value.type(); | |||
| if (is_reinterpret && t.lanes() == 1) { /* the flag is only consumed for scalar-typed values */ | |||
| is_reinterpret = false; | |||
| std::string value = this->PrintExpr(op->value); | |||
| std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); | |||
| this->PrintIndent(); | |||
| Type elem_type = t.element_of(); | |||
| stream << "*(reinterpret_cast<"; | |||
| PrintType(elem_type, stream); | |||
| stream << loop_extent << "*>(&" << ref << ")) = " << value << ";\n"; /* loop_extent is printed directly after the element type name, inside the cast's target type */ | |||
| } else if (vec_store) { | |||
| replace_cce = true; /* while set, the Variable visitor prints current_index in place of loop_var */ | |||
| static const char access[] = {'x', 'y', 'z', 'w', 'a', 'b', 'c', 'd'}; /* NOTE(review): 8 entries, but lanes below is not checked against this bound - confirm lanes <= 8 */ | |||
| int lanes = op->buffer_var.type().lanes(); | |||
| loop_extent = lanes; | |||
| for (int i = 0; i < lanes; i++){ /* one "name.component = value;" line per lane */ | |||
| this->PrintIndent(); | |||
| current_index = i; /* read by the Load and Variable visitors while op->value is printed */ | |||
| stream << op->buffer_var->name_hint << "." << access[i] << " = "; | |||
| PrintExpr(op->value, stream); | |||
| stream << ";\n"; | |||
| } | |||
| replace_cce = false; | |||
| vec_store = false; /* cleared only after all lanes are emitted */ | |||
| } else if (simple_store){ | |||
| simple_store = false; | |||
| std::string value = this->PrintExpr(op->value); | |||
| this->PrintIndent(); | |||
| stream << op->buffer_var->name_hint << " = " << value << ";\n"; /* the store index is not printed in this form */ | |||
| } else{ | |||
| CodeGenC::VisitStmt_(op); | |||
| } | |||
| } | |||
| void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { | |||
| if (op->attr_key == attr::wmma_scope) { | |||
| const StringImm* scope_str = op->value.as<StringImm>(); | |||
| @@ -548,12 +694,30 @@ void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { | |||
| const Variable* buffer = op->node.as<Variable>(); | |||
| int offset = op->value.as<IntImm>()->value; | |||
| sm_offsets[buffer] = offset; | |||
| } else if (op->attr_key == "reinterpret_store") { | |||
| loop_extent = op->value; | |||
| // mark next store statement to be a reinterpret cast | |||
| is_reinterpret = true; | |||
| } else if (op->attr_key == "vec_store") { | |||
| loop_var = op->value.as<Variable>(); | |||
| // mark next store statement to be a vector store | |||
| vec_store = true; | |||
| } else if (op->attr_key == "simple_store") { | |||
| // mark next store statement to be a basic store of type a = b | |||
| simple_store = true; | |||
| } else if (op->attr_key == "vec_load") { | |||
| vec_loads.insert(op->value.as<Variable>()); | |||
| } else if (op->attr_key == "no_init_value") { | |||
| // mark next let statement to be a simple, empty declaration | |||
| no_init_value = true; | |||
| } | |||
| CodeGenC::VisitStmt_(op); | |||
| } | |||
| void CodeGenCUDA::VisitStmt_(const Allocate* op) { | |||
| CHECK(!is_zero(op->condition)); | |||
| if (is_zero(op->condition)) { | |||
| stream << "// "; | |||
| } | |||
| std::string vid = AllocVarID(op->buffer_var.get()); | |||
| if (op->new_expr.defined()) { | |||
| // Prefer global static allocation for the program | |||
| @@ -637,6 +801,10 @@ void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) { | |||
| } | |||
| void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) | |||
| if (vec_store) { | |||
| PrintExpr(op->value, os); | |||
| return; | |||
| } | |||
| if (op->type.is_int() && op->type.bits() == 8 && op->lanes == 4) { | |||
| // make_int8x4 | |||
| const int64_t *p = as_const_int(op->value); | |||
| @@ -94,6 +94,10 @@ class CodeGenCUDA final : public CodeGenC { | |||
| void VisitStmt_(const Evaluate *op) final; | |||
| void VisitStmt_(const Allocate *op) final; | |||
| void VisitStmt_(const AttrStmt *op) final; | |||
| void VisitStmt_(const LetStmt *op) final; | |||
| void VisitExpr_(const Variable *op, std::ostream &os) final; | |||
| void VisitExpr_(const Load *op, std::ostream &os) final; | |||
| void VisitStmt_(const Store *op) final; | |||
| private: | |||
| // Handle volatile loads. | |||
| @@ -120,6 +124,27 @@ class CodeGenCUDA final : public CodeGenC { | |||
| // whether need mma.h | |||
| bool need_mma_h_{false}; | |||
| // whether next store will be a reinterpret_cast | |||
| bool is_reinterpret{false}; | |||
| // the extent of the swizzled loop | |||
| Expr loop_extent{0}; | |||
| // whether next store is not an array | |||
| bool simple_store{false}; | |||
| // whether store uses vector format | |||
| bool vec_store{false}; | |||
| // var from Loads to modify with vector load | |||
| std::set<const Variable*> vec_loads; | |||
| // whether replace cce variable with constant in vec_store | |||
| bool replace_cce{false}; | |||
| // variable to replace | |||
| const Variable* loop_var; | |||
| // index to replace cce with | |||
| int current_index; | |||
| // do not set value to next LetStmt if true | |||
| bool no_init_value{false}; | |||
| // ignore next allocate stmt if true (trick to bypass some tests) | |||
| // bool ignore_next_allocate{false}; | |||
| // add for TensorCore | |||
| // warp tile size for TensorCore interface | |||
| Expr warp_tile_m = IntImm::make(Int(32), 1); | |||
| @@ -923,6 +923,9 @@ std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) | |||
| case ForType::Vectorized: | |||
| out << "vectorized"; | |||
| break; | |||
| case ForType::Swizzled: | |||
| out << "swizzled"; | |||
| break; | |||
| } | |||
| return out; | |||
| } | |||
| @@ -0,0 +1,573 @@ | |||
| unchanged: | |||
| --- isl_0.22/include/isl/options.h 2021-01-26 13:29:02.345994411 +0100 | |||
| +++ isl/include/isl/options.h 2021-01-26 13:25:07.713990422 +0100 | |||
| @@ -49,6 +49,13 @@ int isl_options_get_coalesce_bounded_wra | |||
| isl_stat isl_options_set_coalesce_preserve_locals(isl_ctx *ctx, int val); | |||
| int isl_options_get_coalesce_preserve_locals(isl_ctx *ctx); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| +isl_stat isl_options_set_akg_print_debug(isl_ctx *ctx, int val); | |||
| +int isl_options_get_akg_print_debug(isl_ctx *ctx); | |||
| +isl_stat isl_options_set_akg_influence_scheduler(isl_ctx *ctx, int val); | |||
| +int isl_options_get_akg_influence_scheduler(isl_ctx *ctx); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| + | |||
| #if defined(__cplusplus) | |||
| } | |||
| #endif | |||
| unchanged: | |||
| --- isl_0.22/include/isl/schedule.h 2021-01-26 13:29:02.345994411 +0100 | |||
| +++ isl/include/isl/schedule.h 2021-01-26 13:25:07.713990422 +0100 | |||
| @@ -9,7 +9,7 @@ | |||
| #include <isl/set_type.h> | |||
| #include <isl/list.h> | |||
| #include <isl/printer_type.h> | |||
| - | |||
| +#include <isl/vec.h> | |||
| #if defined(__cplusplus) | |||
| extern "C" { | |||
| #endif | |||
| @@ -209,6 +209,57 @@ __isl_give isl_printer *isl_printer_prin | |||
| void isl_schedule_dump(__isl_keep isl_schedule *schedule); | |||
| __isl_give char *isl_schedule_to_str(__isl_keep isl_schedule *schedule); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + | |||
| +/* --- Exports --- */ | |||
| + | |||
| +/* export hidden types */ | |||
| +struct isl_sched_node; | |||
| +typedef struct isl_sched_node isl_sched_node; | |||
| +struct isl_sched_edge; | |||
| +typedef struct isl_sched_edge isl_sched_edge; | |||
| +struct isl_sched_graph; | |||
| +typedef struct isl_sched_graph isl_sched_graph; | |||
| + | |||
| +/* isl_sched_node functions */ | |||
| +int isl_sched_node_par_coef_offset(struct isl_sched_node *node); | |||
| +int isl_sched_node_cst_coef_offset(struct isl_sched_node *node); | |||
| +__isl_give isl_map *isl_sched_node_extract_schedule(struct isl_sched_node *node); | |||
| + | |||
| +/* isl vec functions */ | |||
| +int isl_inf_vec_get_size( isl_vec* vec); | |||
| +/* export isl_sched_edge functions */ | |||
| +/* export isl_sched_graph functions */ | |||
| +isl_stat isl_sched_graph_init(struct isl_sched_graph *graph, | |||
| + __isl_keep isl_schedule_constraints *sc); | |||
| +void isl_sched_graph_free(isl_ctx *ctx, struct isl_sched_graph *graph); | |||
| +__isl_give isl_schedule_node *isl_schedule_node_compute_schedule(isl_schedule_node *node, | |||
| + struct isl_sched_graph *graph); | |||
| + | |||
| +/* --- New functions --- */ | |||
| + | |||
| +/* isl_sched_node functions */ | |||
| +int isl_sched_node_get_nparam(const struct isl_sched_node *node); | |||
| +int isl_sched_node_get_nvar(const struct isl_sched_node *node); | |||
| + | |||
| +/* isl_vec functions */ | |||
| +int isl_influence_int_eq(isl_vec* v, int pos1, int pos2); | |||
| +isl_val* isl_influence_vec_get_elem(isl_vec*, int pos); | |||
| +/* isl_sched_graph functions */ | |||
| +struct isl_sched_node* isl_sched_graph_get_node(struct isl_sched_graph *graph, int i); | |||
| + | |||
| +/* --- AKG influence --- */ | |||
| +extern isl_basic_set* (*isl_influence_set_coef)( | |||
| + isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset); | |||
| +extern isl_basic_set* (*isl_influence_set_equal)( | |||
| + isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset); | |||
| +extern int (*isl_influence_maxvar)(struct isl_sched_graph* graph); | |||
| +extern int (*isl_influence_check_coincident)(struct isl_sched_graph* graph, isl_vec* sol); | |||
| +extern struct isl_sched_graph* (*isl_influence_sol_list_free)(struct isl_sched_graph* graph); | |||
| +extern struct isl_sched_graph* (*isl_influence_sol_add_elem)(isl_vec* sol, struct isl_sched_graph* graph); | |||
| +extern int(*isl_influence_sol_get_elem)(int sched, int pos, struct isl_sched_graph* graph); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| + | |||
| #if defined(__cplusplus) | |||
| } | |||
| #endif | |||
| unchanged: | |||
| --- isl_0.22/isl_options.c 2021-01-26 13:29:02.313994410 +0100 | |||
| +++ isl/isl_options.c 2021-01-26 13:25:07.713990422 +0100 | |||
| @@ -228,6 +228,12 @@ ISL_ARG_BOOL(struct isl_options, print_s | |||
| "print statistics for every isl_ctx") | |||
| ISL_ARG_ULONG(struct isl_options, max_operations, 0, | |||
| "max-operations", 0, "default number of maximal operations per isl_ctx") | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| +ISL_ARG_BOOL(struct isl_options, akg_print_debug, 0, "print-debug", 0, | |||
| + "print debug info") | |||
| +ISL_ARG_BOOL(struct isl_options, akg_influence_scheduler, 0, "influence-schedule", 0, | |||
| + "update scheduler coefficients") | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| ISL_ARG_VERSION(print_version) | |||
| ISL_ARGS_END | |||
| @@ -402,3 +408,14 @@ ISL_CTX_SET_BOOL_DEF(isl_options, struct | |||
| ast_build_allow_or) | |||
| ISL_CTX_GET_BOOL_DEF(isl_options, struct isl_options, isl_options_args, | |||
| ast_build_allow_or) | |||
| + | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| +ISL_CTX_SET_BOOL_DEF(isl_options, struct isl_options, isl_options_args, | |||
| + akg_print_debug) | |||
| +ISL_CTX_GET_BOOL_DEF(isl_options, struct isl_options, isl_options_args, | |||
| + akg_print_debug) | |||
| +ISL_CTX_SET_BOOL_DEF(isl_options, struct isl_options, isl_options_args, | |||
| + akg_influence_scheduler) | |||
| +ISL_CTX_GET_BOOL_DEF(isl_options, struct isl_options, isl_options_args, | |||
| + akg_influence_scheduler) | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| unchanged: | |||
| --- isl_0.22/isl_options_private.h 2021-01-26 13:29:02.317994410 +0100 | |||
| +++ isl/isl_options_private.h 2021-01-26 13:25:07.713990422 +0100 | |||
| @@ -71,6 +71,10 @@ struct isl_options { | |||
| int print_stats; | |||
| unsigned long max_operations; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + int akg_print_debug; | |||
| + int akg_influence_scheduler; | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| }; | |||
| #endif | |||
| diff -u isl/isl_scheduler.c isl/isl_scheduler.c | |||
| --- isl/isl_scheduler.c 2021-02-05 09:56:15.186697283 +0100 | |||
| +++ isl/isl_scheduler.c 2021-03-31 14:49:38.568015437 +0200 | |||
| @@ -39,7 +39,6 @@ | |||
| #include <isl_morph.h> | |||
| #include <isl/ilp.h> | |||
| #include <isl_val_private.h> | |||
| - | |||
| /* | |||
| * The scheduling algorithm implemented in this file was inspired by | |||
| * Bondhugula et al., "Automatic Transformations for Communication-Minimized | |||
| @@ -303,6 +302,12 @@ | |||
| return is_condition(edge) || is_conditional_validity(edge); | |||
| } | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| +struct isl_influence_list; | |||
| +struct isl_influence_equal_list; | |||
| +struct isl_influence_sol_list; | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| + | |||
| /* Internal information about the dependence graph used during | |||
| * the construction of the schedule. | |||
| * | |||
| @@ -395,6 +400,11 @@ | |||
| int weak; | |||
| int max_weight; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + struct isl_influence_list *inf_list; | |||
| + struct isl_influence_equal_list *inf_equal_list; | |||
| + struct isl_influence_sol_list* inf_sol_list; /* spelling now matches the forward declaration struct isl_influence_sol_list */ | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| }; | |||
| /* Initialize node_table based on the list of nodes. | |||
| @@ -757,6 +767,8 @@ | |||
| isl_hash_table_free(ctx, graph->edge_table[i]); | |||
| isl_hash_table_free(ctx, graph->node_table); | |||
| isl_basic_set_free(graph->lp); | |||
| + if(isl_influence_sol_list_free) | |||
| + graph=isl_influence_sol_list_free(graph); | |||
| } | |||
| /* For each "set" on which this function is called, increment | |||
| @@ -1233,7 +1245,6 @@ | |||
| return isl_stat_error; | |||
| if (compressed && (!hull || !compress || !decompress)) | |||
| return isl_stat_error; | |||
| - | |||
| return isl_stat_ok; | |||
| error: | |||
| isl_set_free(set); | |||
| @@ -3222,6 +3233,7 @@ | |||
| * In particular, the non-triviality region enforces that at least | |||
| * one of the linear combinations in the rows of node->indep is non-zero. | |||
| */ | |||
| + | |||
| static __isl_give isl_vec *solve_lp(isl_ctx *ctx, struct isl_sched_graph *graph) | |||
| { | |||
| int i; | |||
| @@ -3237,16 +3249,68 @@ | |||
| trivial = construct_trivial(node, node->indep); | |||
| else | |||
| trivial = isl_mat_zero(ctx, 0, 0); | |||
| + | |||
| graph->region[i].trivial = trivial; | |||
| } | |||
| lp = isl_basic_set_copy(graph->lp); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + isl_basic_set *lp_backup; | |||
| + isl_basic_set *lp_backup_inf; | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| + if (akg_influence) | |||
| + { | |||
| + lp_backup = isl_basic_set_copy(graph->lp); | |||
| + lp = isl_influence_set_coef(ctx, graph, lp); | |||
| + lp = isl_influence_set_equal(ctx, graph, lp); | |||
| + lp_backup_inf = isl_basic_set_copy(lp); | |||
| + } | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| sol = isl_tab_basic_set_non_trivial_lexmin(lp, 2, graph->n, | |||
| graph->region, &check_conflict, graph); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + if (akg_influence) | |||
| + { | |||
| + if(sol->size == 0) | |||
| + { | |||
| + if (akg_debug >= 1) fprintf(stderr, "\ninfluence schedule did not find solution, relax region and try again::\n"); | |||
| + isl_vec_free(sol); | |||
| + | |||
| + for(int i=0; i< graph->n;i++) | |||
| + { | |||
| + isl_mat_free(graph->region[i].trivial); | |||
| + graph->region[i].trivial = isl_mat_zero(ctx,0,0); | |||
| + } | |||
| + | |||
| + sol = isl_tab_basic_set_non_trivial_lexmin(lp_backup_inf, 2, graph->n, graph->region, &check_conflict, graph); | |||
| + | |||
| + if (sol->size == 0) | |||
| + { | |||
| + if (akg_debug >= 1) fprintf(stderr, "\ninfluence schedule did not find solution, another try with use_coincidence=0:\n"); | |||
| + } else | |||
| + { | |||
| + isl_basic_set_free(lp_backup); | |||
| + } | |||
| + | |||
| + } | |||
| + else | |||
| + { | |||
| + isl_basic_set_free(lp_backup); | |||
| + isl_basic_set_free(lp_backup_inf); | |||
| + } | |||
| + } | |||
| + | |||
| + if (!sol->size) | |||
| + { | |||
| + if (akg_debug >= 1) | |||
| + fprintf(stderr, "solve_lp did not find solution for dimension: %d\n",graph->n_row); | |||
| + } | |||
| + | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| for (i = 0; i < graph->n; ++i) | |||
| isl_mat_free(graph->region[i].trivial); | |||
| return sol; | |||
| } | |||
| - | |||
| /* Extract the coefficients for the variables of "node" from "sol". | |||
| * | |||
| * Each schedule coefficient c_i_x is represented as the difference | |||
| @@ -3271,6 +3335,12 @@ | |||
| if (!csol) | |||
| return NULL; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + isl_ctx* const ctx = isl_vec_get_ctx(sol); | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| + | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| for (i = 0; i < node->nvar; ++i) { | |||
| pos = 1 + node_var_coef_pos(node, i); | |||
| if (node->nonneg) | |||
| @@ -3304,10 +3374,15 @@ | |||
| if (sol->size == 0) | |||
| isl_die(sol->ctx, isl_error_internal, | |||
| "no solution found", goto error); | |||
| - if (graph->n_total_row >= graph->max_row) | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + isl_ctx* const ctx = isl_vec_get_ctx(sol); | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| + | |||
| + if (!akg_influence && graph->n_total_row >= graph->max_row) | |||
| isl_die(sol->ctx, isl_error_internal, | |||
| "too many schedule rows", goto error); | |||
| - | |||
| + /* ======================= AKG influence patch -- end ======================= */ | |||
| for (i = 0; i < graph->n; ++i) { | |||
| struct isl_sched_node *node = &graph->node[i]; | |||
| int pos; | |||
| @@ -3325,11 +3400,13 @@ | |||
| goto error; | |||
| pos = node_cst_coef_offset(node); | |||
| node->sched = isl_mat_set_element(node->sched, | |||
| - row, 0, sol->el[1 + pos]); | |||
| + row, 0, sol->el[1 + pos]); | |||
| pos = node_par_coef_offset(node); | |||
| for (j = 0; j < node->nparam; ++j) | |||
| + { | |||
| node->sched = isl_mat_set_element(node->sched, | |||
| row, 1 + j, sol->el[1 + pos + j]); | |||
| + } | |||
| for (j = 0; j < node->nvar; ++j) | |||
| node->sched = isl_mat_set_element(node->sched, | |||
| row, 1 + node->nparam + j, csol->el[j]); | |||
| @@ -3340,6 +3417,14 @@ | |||
| graph->n_row++; | |||
| graph->n_total_row++; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + if (akg_influence && graph->n_row == graph->maxvar) | |||
| + { | |||
| + if(graph->n_row < isl_influence_maxvar(graph)){ | |||
| + graph->maxvar++; | |||
| + } | |||
| + } | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| return 0; | |||
| error: | |||
| @@ -3917,14 +4002,16 @@ | |||
| static int compute_maxvar(struct isl_sched_graph *graph) | |||
| { | |||
| int i; | |||
| - | |||
| graph->maxvar = 0; | |||
| for (i = 0; i < graph->n; ++i) { | |||
| struct isl_sched_node *node = &graph->node[i]; | |||
| int nvar; | |||
| - if (node_update_vmap(node) < 0) | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + if (node_update_vmap(node) < 0) { | |||
| return -1; | |||
| + } | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| nvar = node->nvar + graph->n_row - node->rank; | |||
| if (nvar > graph->maxvar) | |||
| graph->maxvar = nvar; | |||
| @@ -3968,6 +4055,10 @@ | |||
| sub->max_row = graph->max_row; | |||
| sub->n_total_row = graph->n_total_row; | |||
| sub->band_start = graph->band_start; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + sub->inf_list = graph->inf_list; | |||
| + sub->inf_equal_list = graph->inf_equal_list; | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| return isl_stat_ok; | |||
| } | |||
| @@ -3997,6 +4088,11 @@ | |||
| { | |||
| struct isl_sched_graph split = { 0 }; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| + | |||
| if (extract_sub_graph(ctx, graph, node_pred, edge_pred, data, | |||
| &split) < 0) | |||
| goto error; | |||
| @@ -5396,6 +5492,30 @@ | |||
| return NULL; | |||
| lp = isl_basic_set_copy(graph->lp); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| + if (akg_influence) | |||
| + { | |||
| + isl_basic_set *lp_backup = isl_basic_set_copy(graph->lp); | |||
| + lp = isl_influence_set_coef(ctx, graph, lp); | |||
| + lp = isl_influence_set_equal(ctx, graph, lp); | |||
| + isl_vec *sol = non_neg_lexmin(graph, lp, n_edge, want_integral); | |||
| + if (sol->size == 0) | |||
| + { | |||
| + if (akg_debug >= 1) { | |||
| + fprintf(stderr, "\ncarry_lp hack failed, restoring previous lp problem\n"); | |||
| + } | |||
| + sol = non_neg_lexmin(graph, lp_backup, n_edge, want_integral); | |||
| + } | |||
| + else | |||
| + { | |||
| + isl_basic_set_free(lp_backup); | |||
| + } | |||
| + | |||
| + return sol; | |||
| + } | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| return non_neg_lexmin(graph, lp, n_edge, want_integral); | |||
| } | |||
| @@ -5991,6 +6111,11 @@ | |||
| int use_coincidence; | |||
| int force_coincidence = 0; | |||
| int check_conditional; | |||
| + int coincidence_relaxed=0; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| if (sort_sccs(graph) < 0) | |||
| return isl_stat_error; | |||
| @@ -5998,10 +6123,17 @@ | |||
| clear_local_edges(graph); | |||
| check_conditional = need_condition_check(graph); | |||
| has_coincidence = has_any_coincidence(graph); | |||
| - | |||
| if (ctx->opt->schedule_outer_coincidence) | |||
| force_coincidence = 1; | |||
| + | |||
| + if(akg_influence){ | |||
| + int previous_coincidence = has_coincidence; | |||
| + int isl_maxvar = isl_influence_maxvar(graph); | |||
| + if(graph->maxvar > isl_maxvar) | |||
| + graph->maxvar = isl_maxvar; | |||
| + } | |||
| + | |||
| use_coincidence = has_coincidence; | |||
| while (graph->n_row < graph->maxvar) { | |||
| isl_vec *sol; | |||
| @@ -6014,19 +6146,34 @@ | |||
| if (setup_lp(ctx, graph, use_coincidence) < 0) | |||
| return isl_stat_error; | |||
| sol = solve_lp(ctx, graph); | |||
| - if (!sol) | |||
| + if(!sol) | |||
| + { | |||
| return isl_stat_error; | |||
| + } | |||
| + if(akg_influence && sol->size) /* sol is checked for NULL above before it is dereferenced; only non-empty solutions are recorded */ | |||
| + graph=isl_influence_sol_add_elem(sol,graph); | |||
| if (sol->size == 0) { | |||
| + /* ====================== AKG influence patch -- start ====================== */ | |||
| + if (akg_influence && akg_debug >= 1) { | |||
| + fprintf(stderr, "No solution!\n"); | |||
| + } | |||
| + /* ======================= AKG influence patch -- end ======================= */ | |||
| int empty = graph->n_total_row == graph->band_start; | |||
| isl_vec_free(sol); | |||
| if (use_coincidence && (!force_coincidence || !empty)) { | |||
| use_coincidence = 0; | |||
| + coincidence_relaxed=1; | |||
| continue; | |||
| } | |||
| return isl_stat_ok; | |||
| } | |||
| coincident = !has_coincidence || use_coincidence; | |||
| + | |||
| + if(akg_influence && graph->n > 1){ | |||
| + coincident=isl_influence_check_coincident(graph,sol); | |||
| + } | |||
| + | |||
| if (update_schedule(graph, sol, coincident) < 0) | |||
| return isl_stat_error; | |||
| @@ -7688,6 +7835,10 @@ | |||
| if (graph->scc <= 1 || isl_options_get_schedule_whole_component(ctx)) | |||
| return compute_schedule_wcc_whole(node, graph); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + else if (isl_options_get_akg_influence_scheduler(ctx)) | |||
| + return compute_schedule_wcc_whole(node, graph); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| else | |||
| return compute_schedule_wcc_clustering(node, graph); | |||
| } | |||
| @@ -7735,6 +7886,10 @@ | |||
| else | |||
| node = isl_schedule_node_insert_sequence(node, filters); | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| for (component = 0; component < graph->scc; ++component) { | |||
| node = isl_schedule_node_child(node, component); | |||
| node = isl_schedule_node_child(node, 0); | |||
| @@ -7775,7 +7930,10 @@ | |||
| return isl_schedule_node_free(node); | |||
| } | |||
| - if (graph->scc > 1) | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + if (graph->scc > 1 && akg_influence == 0) | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| return compute_component_schedule(node, graph, 1); | |||
| return compute_schedule_wcc(node, graph); | |||
| @@ -7806,6 +7964,11 @@ | |||
| isl_union_set *domain; | |||
| isl_size n; | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| + const int akg_influence = isl_options_get_akg_influence_scheduler(ctx); | |||
| + const int akg_debug = isl_options_get_akg_print_debug(ctx); | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||
| + | |||
| sc = isl_schedule_constraints_align_params(sc); | |||
| domain = isl_schedule_constraints_get_domain(sc); | |||
| @@ -7852,0 +8016,62 @@ | |||
| + | |||
| +/* ====================== AKG influence patch -- start ====================== */ | |||
| +isl_basic_set* (*isl_influence_set_coef)( /* Hook function pointers of the AKG influence patch; every one is defined here with a 0 (null) initialiser. */ | |||
| + isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) = { 0 }; /* presumably assigned by the AKG side before scheduling -- TODO confirm */ | |||
| +isl_basic_set* (*isl_influence_set_equal)( /* same signature as isl_influence_set_coef */ | |||
| + isl_ctx *ctx, struct isl_sched_graph *graph, isl_basic_set *bset) = { 0 }; | |||
| +int (*isl_influence_maxvar)(struct isl_sched_graph* graph) = { 0 }; /* takes the graph, yields an int; meaning of the value is not visible here */ | |||
| +int (*isl_influence_check_coincident)(struct isl_sched_graph *graph,isl_vec* sol) = { 0 }; /* its result overrides `coincident` when akg_influence is set and graph->n > 1; NOTE(review): that call site does not test this pointer for null */ | |||
| +struct isl_sched_graph* (*isl_influence_sol_list_free)(struct isl_sched_graph* graph) = { 0 } ; /* takes and returns a graph pointer */ | |||
| +struct isl_sched_graph* (*isl_influence_sol_add_elem)(isl_vec* sol, struct isl_sched_graph* graph) = { 0 }; /* takes a solution vector and a graph, returns a graph pointer */ | |||
| +int (*isl_influence_sol_get_elem)(int sched, int pos, struct isl_sched_graph* graph) = { 0 }; /* NOTE(review): looks like a (sched, pos) lookup into stored solutions -- confirm */ | |||
| +int isl_sched_node_par_coef_offset(struct isl_sched_node *node) { /* Non-static wrapper: forwards to node_par_coef_offset() and returns its result as-is. */ | |||
| + return node_par_coef_offset(node); | |||
| +} | |||
| + | |||
| +int isl_sched_node_cst_coef_offset(struct isl_sched_node *node) { /* Non-static wrapper: forwards to node_cst_coef_offset() and returns its result as-is. */ | |||
| + return node_cst_coef_offset(node); | |||
| +} | |||
| + | |||
| +__isl_give isl_map *isl_sched_node_extract_schedule(struct isl_sched_node *node) { /* Non-static wrapper around node_extract_schedule(); the __isl_give result is owned by the caller. */ | |||
| + return node_extract_schedule(node); | |||
| +} | |||
| + | |||
| +int isl_sched_node_get_nparam(const struct isl_sched_node *node) { /* Read-only accessor for the node's nparam field. */ | |||
| + return node->nparam; /* node is dereferenced without a null check */ | |||
| +} | |||
| + | |||
| +int isl_sched_node_get_nvar(const struct isl_sched_node *node) { /* Read-only accessor for the node's nvar field. */ | |||
| + return node->nvar; /* node is dereferenced without a null check */ | |||
| +} | |||
| + | |||
| +int isl_influence_int_eq(isl_vec* v, int pos1, int pos2) /* Compares two elements of the same vector v. */ | |||
| +{ | |||
| + return isl_int_eq(v->el[pos1],v->el[pos2]); /* result of isl_int_eq on v->el[pos1] and v->el[pos2]; v, pos1 and pos2 are not validated here */ | |||
| +} | |||
| + | |||
| +isl_val* isl_influence_vec_get_elem(isl_vec* v, int pos) /* Thin alias for isl_vec_get_element_val(). */ | |||
| +{ | |||
| + return isl_vec_get_element_val(v,pos); /* returns whatever isl_vec_get_element_val yields for (v, pos) */ | |||
| +} | |||
| +isl_stat isl_sched_graph_init(struct isl_sched_graph *graph, /* Non-static wrapper around graph_init(). */ | |||
| + __isl_keep isl_schedule_constraints *sc) { /* sc is __isl_keep: it stays owned by the caller */ | |||
| + return graph_init(graph, sc); /* the isl_stat from graph_init is passed through unchanged */ | |||
| +} | |||
| + | |||
| +void isl_sched_graph_free(isl_ctx *ctx, struct isl_sched_graph *graph) { /* Non-static wrapper: forwards both arguments to graph_free(). */ | |||
| + graph_free(ctx, graph); | |||
| +} | |||
| + | |||
| +__isl_give isl_schedule_node *isl_schedule_node_compute_schedule(isl_schedule_node *node, /* Non-static wrapper around compute_schedule(); the __isl_give result is owned by the caller. */ | |||
| + struct isl_sched_graph *graph) { /* NOTE(review): node carries no __isl_take/__isl_keep annotation -- confirm whether compute_schedule consumes it */ | |||
| + return compute_schedule(node, graph); | |||
| +} | |||
| + | |||
| +struct isl_sched_node* isl_sched_graph_get_node(struct isl_sched_graph *graph, int i) { /* Returns a pointer to the i-th entry of graph->node, or NULL when graph is NULL or i is outside [0, graph->n). */ | |||
| + return (graph && i >= 0 && i < graph->n) ? &graph->node[i] : NULL; /* bounds-checked: the accessor is exported, so it must not hand out pointers past the node array */ | |||
| +} | |||
| + | |||
| +int isl_inf_vec_get_size(isl_vec* vec){ /* Returns vec->size converted to int, or -1 (the value isl uses for isl_size_error) when vec is NULL. */ | |||
| + return vec ? (int) vec->size : -1; /* NULL-tolerant instead of dereferencing a null vector */ | |||
| +} | |||
| +/* ======================= AKG influence patch -- end ======================= */ | |||