|
- diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
- index d668984..3676a61 100644
- --- a/include/tvm/runtime/registry.h
- +++ b/include/tvm/runtime/registry.h
- @@ -319,6 +319,19 @@ class Registry {
- #define TVM_REGISTER_EXT_TYPE(T) \
- TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
- ::tvm::runtime::ExtTypeVTable::Register_<T>()
- +/*!
- + * \brief Map a TVM runtime symbol name through the optional "codegen.GetTransRTFunc" hook.
- + * Yields OrigFuncStr when the hook is absent or returns null; uses a GNU statement expression.
- + */
- +#define TVM_RT_FUNC_TRANS(OrigFuncStr) ({ \
- + const char* tvm_rt_orig_name_ = (OrigFuncStr); \
- + const ::tvm::runtime::PackedFunc* trans_func = ::tvm::runtime::Registry::Get("codegen.GetTransRTFunc"); \
- + const char* dst_func_str = nullptr; \
- + if (trans_func != nullptr) { \
- + dst_func_str = ((*trans_func)(tvm_rt_orig_name_)).ptr<const char>(); \
- + } \
- + (dst_func_str != nullptr) ? dst_func_str : tvm_rt_orig_name_; \
- +})
-
- } // namespace runtime
- } // namespace tvm
- diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc
- index 0ba0c58..2850ad4 100644
- --- a/src/codegen/llvm/codegen_cpu.cc
- +++ b/src/codegen/llvm/codegen_cpu.cc
- @@ -99,26 +99,26 @@ void CodeGenCPU::Init(const std::string& module_name,
- // We will need this in environment for backward registration.
- f_tvm_register_system_symbol_ = llvm::Function::Create(
- llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
- - llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
- + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendRegisterSystemLibSymbol"), module_.get());  // symbol name may be remapped via "codegen.GetTransRTFunc"
- } else {
- f_tvm_register_system_symbol_ = nullptr;
- }
- if (dynamic_lookup || system_lib) {
- f_tvm_func_call_ = llvm::Function::Create(
- ftype_tvm_func_call_,
- - llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
- + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMFuncCall"), module_.get());  // NOTE(review): TVMBackendGetFuncFromEnv below is declared without TVM_RT_FUNC_TRANS - confirm intentional
- f_tvm_get_func_from_env_ = llvm::Function::Create(
- ftype_tvm_get_func_from_env_,
- llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
- f_tvm_api_set_last_error_ = llvm::Function::Create(
- ftype_tvm_api_set_last_error_,
- - llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
- + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMAPISetLastError"), module_.get());  // remappable symbol name
- f_tvm_parallel_launch_ = llvm::Function::Create(
- ftype_tvm_parallel_launch_,
- - llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
- + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelLaunch"), module_.get());  // remappable symbol name
- f_tvm_parallel_barrier_ = llvm::Function::Create(
- ftype_tvm_parallel_barrier_,
- - llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
- + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelBarrier"), module_.get());  // remappable symbol name
- }
- this->InitGlobalContext(dynamic_lookup);
- }
- @@ -461,11 +461,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
- }
- std::swap(function_, fcompute);
- std::swap(new_vmap, var_map_);
- + std::stack<bool*> br_ret_flg;  // empty stack used while the body below is visited
- + std::swap(br_ret_flg, br_ret_flg_);  // park the current IfThenElse return flags in the local
- BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
- builder_->SetInsertPoint(compute_entry);
- this->VisitStmt(op->body);
- builder_->CreateRet(ConstInt32(0));
- // swap the var map back, now we are back on track.
- + std::swap(br_ret_flg, br_ret_flg_);  // restore the flags parked before visiting the body
- std::swap(new_vmap, var_map_);
- std::swap(function_, fcompute);
- builder_->SetInsertPoint(compute_call_end);
- @@ -542,9 +545,12 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
- std::swap(function_, f);
- std::swap(parallel_env_, par_env);
- std::swap(var_map_, new_vmap);
- + std::stack<bool*> br_ret_flg;  // empty stack used while the body below is visited
- + std::swap(br_ret_flg, br_ret_flg_);  // park the current IfThenElse return flags in the local
- this->VisitStmt(body);
- builder_->CreateRet(ConstInt32(0));
- // swap the var map back, now we are back on track.
- + std::swap(br_ret_flg, br_ret_flg_);  // restore the flags parked before visiting the body
- std::swap(var_map_, new_vmap);
- std::swap(parallel_env_, par_env);
- std::swap(function_, f);
- @@ -794,7 +800,9 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
- } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
- return CreateStaticHandle();
- } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
- - builder_->CreateRet(ConstInt32(-1));
- + llvm::Value* pRetCode = (op->args.size() == 0) ? ConstInt32(-1) : MakeValue(op->args[0]);  // args[0], when given, replaces the default return code -1
- + builder_->CreateRet(pRetCode);  // terminates the current basic block with a ret
- + CodeGenLLVM::SetRetTrFlg(true);  // flag the innermost IfThenElse being visited (if any) that a ret was emitted
- return ConstInt32(-1);
- } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
- CHECK_EQ(op->args.size(), 3U);
- diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
- index 2cff88b..e26812d 100644
- --- a/src/codegen/llvm/codegen_llvm.cc
- +++ b/src/codegen/llvm/codegen_llvm.cc
- @@ -1110,23 +1110,36 @@ void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
- *ctx_, "if_then", function_);
- BasicBlock* end_block = BasicBlock::Create(
- *ctx_, "if_end", function_);
- + // Tracks whether the branch currently being visited has already emitted a ret terminator.
- + bool cur_br_ret_flg = false;
- + br_ret_flg_.push(&cur_br_ret_flg);  // exposed to SetRetTrFlg() while this statement is visited
- if (op->else_case.defined()) {
- BasicBlock* else_block = BasicBlock::Create(
- *ctx_, "if_else", function_);
- builder_->CreateCondBr(cond, then_block, else_block);
- builder_->SetInsertPoint(then_block);
- + cur_br_ret_flg = false;
- this->VisitStmt(op->then_case);
- - builder_->CreateBr(end_block);
- + if (!cur_br_ret_flg) {  // only branch to if_end when the then-branch did not return
- + builder_->CreateBr(end_block);
- + }
- builder_->SetInsertPoint(else_block);
- + cur_br_ret_flg = false;  // reset before visiting the else-branch
- this->VisitStmt(op->else_case);
- - builder_->CreateBr(end_block);
- + if (!cur_br_ret_flg) {  // only branch to if_end when the else-branch did not return
- + builder_->CreateBr(end_block);
- + }
- } else {
- builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
- builder_->SetInsertPoint(then_block);
- + cur_br_ret_flg = false;
- this->VisitStmt(op->then_case);
- - builder_->CreateBr(end_block);
- + if (!cur_br_ret_flg) {  // only branch to if_end when the then-branch did not return
- + builder_->CreateBr(end_block);
- + }
- }
- builder_->SetInsertPoint(end_block);
- + br_ret_flg_.pop();  // the pushed pointer refers to a local, so drop it before returning
- }
-
-
- diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
- index b7d091b..6fba863 100644
- --- a/src/codegen/llvm/codegen_llvm.h
- +++ b/src/codegen/llvm/codegen_llvm.h
- @@ -143,6 +143,11 @@ class CodeGenLLVM :
- void VisitStmt_(const Block* op) override;
- void VisitStmt_(const Evaluate* op) override;
- void VisitStmt_(const ProducerConsumer* op) override;
- + /*! \brief Store RetFlg into the flag of the innermost IfThenElse being visited, if there is one. */
- + void SetRetTrFlg(bool RetFlg) {
- + if (br_ret_flg_.empty()) return;  // no IfThenElse is currently being visited
- + *br_ret_flg_.top() = RetFlg;
- + }
-
- protected:
- /*! \brief The storage information */
- @@ -304,6 +309,12 @@ class CodeGenLLVM :
- * initializes file and compilation_unit_ to TVM defaults.
- */
- static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
- +
- + /*
- + * Stack of "branch already returned" flags, one entry per IfThenElse statement being visited.
- + * When a branch has already emitted a ret, no br terminator may be appended after it.
- + */
- + std::stack<bool*> br_ret_flg_;
- };
- } // namespace codegen
- } // namespace tvm
- diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
- index e73956c..3a7b46c 100644
- --- a/src/pass/lower_tvm_builtin.cc
- +++ b/src/pass/lower_tvm_builtin.cc
- @@ -104,7 +104,7 @@ class BuiltinLower : public IRMutator {
- CHECK(device_type_.defined()) << "Unknown device type in current IR";
- CHECK(device_id_.defined()) << "Unknown device id in current IR";
- Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
- - intrinsic::tvm_throw_last_error, {},
- + intrinsic::tvm_throw_last_error, {make_const(Int(32), 1001)},  // explicit int32 constant 1001 as args[0] (was a comma expression)
- Call::Intrinsic));
-
- Stmt body = Block::make(
- @@ -117,7 +117,7 @@ class BuiltinLower : public IRMutator {
- Stmt alloca = LetStmt::make(
- op->buffer_var,
- Call::make(op->buffer_var.type(),
- - "TVMBackendAllocWorkspace",
- + TVM_RT_FUNC_TRANS("TVMBackendAllocWorkspace"),  // callee name may be remapped via "codegen.GetTransRTFunc"
- {cast(Int(32), device_type_),
- cast(Int(32), device_id_),
- cast(UInt(64), total_bytes),
- @@ -127,7 +127,7 @@ class BuiltinLower : public IRMutator {
- body);
-
- Expr free_op = Call::make(Int(32),
- - "TVMBackendFreeWorkspace",
- + TVM_RT_FUNC_TRANS("TVMBackendFreeWorkspace"),  // callee name may be remapped via "codegen.GetTransRTFunc"
- {cast(Int(32), device_type_),
- cast(Int(32), device_id_),
- op->buffer_var},
|