Browse Source

dtype bug modify

pull/70/head
xiexingxian 5 years ago
parent
commit
b73802e47a
2 changed files with 3 additions and 20 deletions
  1. +2
    -19
      src/poly/gpu_isl_emitter.cc
  2. +1
    -1
      src/poly/scop_info.cc

+ 2
- 19
src/poly/gpu_isl_emitter.cc View File

@@ -159,28 +159,11 @@ Stmt GpuIslEmitter::EmitReadCore(const isl::ast_node_user &node) {

Expr GpuIslEmitter::MakeLeftCallFromProvide(const Provide *op) {
std::string name = op->func->func_name();
Type type = GetTypeOfTensor(name);
Type type = info_.GetDtypeOf(name);
Expr dst = Call::make(type, name, op->args, Call::Halide, op->func, 0);
return dst;
}

Type GpuIslEmitter::GetTypeOfTensor(std::string name) {
auto binds = info_.user_config_.GetBind();

for (auto &i : binds) {
if (!i.first.defined()) continue;
if (!i.second.defined()) continue;

if (name == i.first->op->name) {
auto b = i.second;
return b->dtype;
}
}

CHECK(false) << "Can not find type of tensor " << name;
return Type();
}

Stmt GpuIslEmitter::EmitWrite(const isl::ast_node_user &node) {
auto node_id = node.get_annotation();
CHECK_GT(node_info_map_.count(node_id), 0);
@@ -624,7 +607,7 @@ Stmt GpuIslEmitter::EmitUserStmtCoreSync(const isl::ast_node_user &node) {
serial_number = MMA_FILL_STMT_SERIAL;
auto op = s.as<Provide>();
auto left_expr = MakeLeftCallFromProvide(op);
Type type = GetTypeOfTensor(op->func->func_name());
Type type = info_.GetDtypeOf(op->func->func_name());
auto *add = op->value.as<Add>();
CHECK(add) << "format error of bmm";
auto mul = akg::common::SplitCast(add->b, type).as<Mul>();


+ 1
- 1
src/poly/scop_info.cc View File

@@ -407,7 +407,7 @@ Type ScopInfo::GetDtypeOf(const std::string &tensor_name) const {
return i.second->dtype;
}
}
LOG(INFO) << " no such tensor in binds: " << tensor_name;
CHECK(false) << " no such tensor in binds: " << tensor_name;
return Int(32);
}



Loading…
Cancel
Save