|
|
|
@@ -159,28 +159,11 @@ Stmt GpuIslEmitter::EmitReadCore(const isl::ast_node_user &node) { |
|
|
|
|
|
|
|
Expr GpuIslEmitter::MakeLeftCallFromProvide(const Provide *op) { |
|
|
|
std::string name = op->func->func_name(); |
|
|
|
Type type = GetTypeOfTensor(name); |
|
|
|
Type type = info_.GetDtypeOf(name); |
|
|
|
Expr dst = Call::make(type, name, op->args, Call::Halide, op->func, 0); |
|
|
|
return dst; |
|
|
|
} |
|
|
|
|
|
|
|
Type GpuIslEmitter::GetTypeOfTensor(std::string name) { |
|
|
|
auto binds = info_.user_config_.GetBind(); |
|
|
|
|
|
|
|
for (auto &i : binds) { |
|
|
|
if (!i.first.defined()) continue; |
|
|
|
if (!i.second.defined()) continue; |
|
|
|
|
|
|
|
if (name == i.first->op->name) { |
|
|
|
auto b = i.second; |
|
|
|
return b->dtype; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CHECK(false) << "Can not find type of tensor " << name; |
|
|
|
return Type(); |
|
|
|
} |
|
|
|
|
|
|
|
Stmt GpuIslEmitter::EmitWrite(const isl::ast_node_user &node) { |
|
|
|
auto node_id = node.get_annotation(); |
|
|
|
CHECK_GT(node_info_map_.count(node_id), 0); |
|
|
|
@@ -624,7 +607,7 @@ Stmt GpuIslEmitter::EmitUserStmtCoreSync(const isl::ast_node_user &node) { |
|
|
|
serial_number = MMA_FILL_STMT_SERIAL; |
|
|
|
auto op = s.as<Provide>(); |
|
|
|
auto left_expr = MakeLeftCallFromProvide(op); |
|
|
|
Type type = GetTypeOfTensor(op->func->func_name()); |
|
|
|
Type type = info_.GetDtypeOf(op->func->func_name()); |
|
|
|
auto *add = op->value.as<Add>(); |
|
|
|
CHECK(add) << "format error of bmm"; |
|
|
|
auto mul = akg::common::SplitCast(add->b, type).as<Mul>(); |
|
|
|
|