You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

0001-RetBugFix-CustomRuntime_v06.patch 9.1 kB

(line numbers 1–203)
  1. diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
  2. index d668984..3676a61 100644
  3. --- a/include/tvm/runtime/registry.h
  4. +++ b/include/tvm/runtime/registry.h
  5. @@ -319,6 +319,19 @@ class Registry {
  6. #define TVM_REGISTER_EXT_TYPE(T) \
  7. TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
  8. ::tvm::runtime::ExtTypeVTable::Register_<T>()
  9. +/*
  10. + * Translate a TVM runtime API name into a custom runtime API name via the optional
  11. + * "codegen.GetTransRTFunc" PackedFunc (string or handle result); falls back to the original name.
  12. + */
  13. +inline std::string TransRTFuncName(const char* orig_func) {
  14. + const PackedFunc* trans_func = Registry::Get("codegen.GetTransRTFunc");
  15. + if (trans_func == nullptr) return std::string(orig_func);
  16. + TVMRetValue rv = (*trans_func)(orig_func);  // keep the result alive while it is read
  17. + if (rv.type_code() == kStr) return rv.operator std::string();
  18. + const char* dst = (rv.type_code() == kHandle) ? rv.ptr<const char>() : nullptr;
  19. + return (dst != nullptr && dst[0] != '\0') ? std::string(dst) : std::string(orig_func);
  20. +}
  21. +#define TVM_RT_FUNC_TRANS(OrigFuncStr) ::tvm::runtime::TransRTFuncName(OrigFuncStr)
  22. } // namespace runtime
  23. } // namespace tvm
  24. diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc
  25. index 0ba0c58..2850ad4 100644
  26. --- a/src/codegen/llvm/codegen_cpu.cc
  27. +++ b/src/codegen/llvm/codegen_cpu.cc
  28. @@ -99,26 +99,26 @@ void CodeGenCPU::Init(const std::string& module_name,
  29. // We will need this in environment for backward registration.
  30. f_tvm_register_system_symbol_ = llvm::Function::Create(
  31. llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
  32. - llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
  33. + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendRegisterSystemLibSymbol"), module_.get());
  34. } else {
  35. f_tvm_register_system_symbol_ = nullptr;
  36. }
  37. if (dynamic_lookup || system_lib) {
  38. f_tvm_func_call_ = llvm::Function::Create(
  39. ftype_tvm_func_call_,
  40. - llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
  41. + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMFuncCall"), module_.get());  // NOTE(review): TVMBackendGetFuncFromEnv below keeps its original name; confirm that is intended
  42. f_tvm_get_func_from_env_ = llvm::Function::Create(
  43. ftype_tvm_get_func_from_env_,
  44. llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
  45. f_tvm_api_set_last_error_ = llvm::Function::Create(
  46. ftype_tvm_api_set_last_error_,
  47. - llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
  48. + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMAPISetLastError"), module_.get());
  49. f_tvm_parallel_launch_ = llvm::Function::Create(
  50. ftype_tvm_parallel_launch_,
  51. - llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
  52. + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelLaunch"), module_.get());
  53. f_tvm_parallel_barrier_ = llvm::Function::Create(
  54. ftype_tvm_parallel_barrier_,
  55. - llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
  56. + llvm::Function::ExternalLinkage, TVM_RT_FUNC_TRANS("TVMBackendParallelBarrier"), module_.get());
  57. }
  58. this->InitGlobalContext(dynamic_lookup);
  59. }
  60. @@ -461,11 +461,14 @@ void CodeGenCPU::CreateComputeScope(const AttrStmt* op) {
  61. }
  62. std::swap(function_, fcompute);
  63. std::swap(new_vmap, var_map_);
  64. + std::stack<bool*> br_ret_flg;  // the new LLVM function starts with an empty IfThenElse return-flag stack
  65. + std::swap(br_ret_flg, br_ret_flg_);
  66. BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
  67. builder_->SetInsertPoint(compute_entry);
  68. this->VisitStmt(op->body);
  69. builder_->CreateRet(ConstInt32(0));
  70. // swap the var map back, now we are back on track.
  71. + std::swap(br_ret_flg, br_ret_flg_);  // restore the caller's flag stack
  72. std::swap(new_vmap, var_map_);
  73. std::swap(function_, fcompute);
  74. builder_->SetInsertPoint(compute_call_end);
  75. @@ -542,9 +545,12 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
  76. std::swap(function_, f);
  77. std::swap(parallel_env_, par_env);
  78. std::swap(var_map_, new_vmap);
  79. + std::stack<bool*> br_ret_flg;  // the new LLVM function starts with an empty IfThenElse return-flag stack
  80. + std::swap(br_ret_flg, br_ret_flg_);
  81. this->VisitStmt(body);
  82. builder_->CreateRet(ConstInt32(0));
  83. // swap the var map back, now we are back on track.
  84. + std::swap(br_ret_flg, br_ret_flg_);  // restore the caller's flag stack
  85. std::swap(var_map_, new_vmap);
  86. std::swap(parallel_env_, par_env);
  87. std::swap(function_, f);
  88. @@ -794,7 +800,9 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
  89. } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
  90. return CreateStaticHandle();
  91. } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
  92. - builder_->CreateRet(ConstInt32(-1));
  93. + llvm::Value* ret_code = op->args.empty() ? ConstInt32(-1) : builder_->CreateIntCast(MakeValue(op->args[0]), t_int32_, true);  // optional args[0] overrides -1; the cast keeps the ret type i32 like the default
  94. + builder_->CreateRet(ret_code);
  95. + CodeGenLLVM::SetRetTrFlg(true);  // tell the enclosing IfThenElse not to append a br after this ret
  96. return ConstInt32(-1);
  97. } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
  98. CHECK_EQ(op->args.size(), 3U);
  99. diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc
  100. index 2cff88b..e26812d 100644
  101. --- a/src/codegen/llvm/codegen_llvm.cc
  102. +++ b/src/codegen/llvm/codegen_llvm.cc
  103. @@ -1110,23 +1110,36 @@ void CodeGenLLVM::VisitStmt_(const IfThenElse* op) {
  104. *ctx_, "if_then", function_);
  105. BasicBlock* end_block = BasicBlock::Create(
  106. *ctx_, "if_end", function_);
  107. + // Set to true (through SetRetTrFlg) when the branch being emitted already ends in a ret.
  108. + bool cur_br_ret_flg = false;
  109. + br_ret_flg_.push(&cur_br_ret_flg);  // popped at the end of this function
  110. if (op->else_case.defined()) {
  111. BasicBlock* else_block = BasicBlock::Create(
  112. *ctx_, "if_else", function_);
  113. builder_->CreateCondBr(cond, then_block, else_block);
  114. builder_->SetInsertPoint(then_block);
  115. + cur_br_ret_flg = false;
  116. this->VisitStmt(op->then_case);
  117. - builder_->CreateBr(end_block);
  118. + if (!cur_br_ret_flg) {  // a basic block may hold only one terminator: no br after a ret
  119. + builder_->CreateBr(end_block);
  120. + }
  121. builder_->SetInsertPoint(else_block);
  122. + cur_br_ret_flg = false;
  123. this->VisitStmt(op->else_case);
  124. - builder_->CreateBr(end_block);
  125. + if (!cur_br_ret_flg) {
  126. + builder_->CreateBr(end_block);
  127. + }
  128. } else {
  129. builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
  130. builder_->SetInsertPoint(then_block);
  131. + cur_br_ret_flg = false;
  132. this->VisitStmt(op->then_case);
  133. - builder_->CreateBr(end_block);
  134. + if (!cur_br_ret_flg) {
  135. + builder_->CreateBr(end_block);
  136. + }
  137. }
  138. builder_->SetInsertPoint(end_block);
  139. + br_ret_flg_.pop();  // NOTE(review): only IfThenElse pushes, so a ret inside e.g. a nested loop body also sets this flag; checking builder_->GetInsertBlock()->getTerminator() would be more precise
  140. }
  141. diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h
  142. index b7d091b..6fba863 100644
  143. --- a/src/codegen/llvm/codegen_llvm.h
  144. +++ b/src/codegen/llvm/codegen_llvm.h
  145. @@ -143,6 +143,11 @@ class CodeGenLLVM :
  146. void VisitStmt_(const Block* op) override;
  147. void VisitStmt_(const Evaluate* op) override;
  148. void VisitStmt_(const ProducerConsumer* op) override;
  149. + // Record whether the innermost IfThenElse branch being emitted already ends in a ret (no-op when no IfThenElse is open).
  150. + void SetRetTrFlg(bool RetFlg){
  151. + if( !br_ret_flg_.empty() )
  152. + *(br_ret_flg_.top()) = RetFlg;
  153. + }
  154. protected:
  155. /*! \brief The storage information */
  156. @@ -304,6 +309,12 @@ class CodeGenLLVM :
  157. * initializes file and compilation_unit_ to TVM defaults.
  158. */
  159. static std::unique_ptr<DebugInfo> CreateDebugInfo(llvm::Module* module);
  160. +
  161. + /*
  162. + * Stack of pointers to return flags, one per IfThenElse currently being emitted (the flags are locals of VisitStmt_(const IfThenElse*)).
  163. + * A branch that already ends in a ret must not get another br terminator. NOTE(review): confirm <stack> is included by this header.
  164. + */
  165. + std::stack<bool*> br_ret_flg_;
  166. };
  167. } // namespace codegen
  168. } // namespace tvm
  169. diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
  170. index e73956c..3a7b46c 100644
  171. --- a/src/pass/lower_tvm_builtin.cc
  172. +++ b/src/pass/lower_tvm_builtin.cc
  173. @@ -104,7 +104,7 @@ class BuiltinLower : public IRMutator {
  174. CHECK(device_type_.defined()) << "Unknown device type in current IR";
  175. CHECK(device_id_.defined()) << "Unknown device id in current IR";
  176. Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
  177. - intrinsic::tvm_throw_last_error, {},
  178. + intrinsic::tvm_throw_last_error, {make_const(Int(32), 1001)},  // i32 code the generated function returns on this error path (NOTE(review): consider a named constant)
  179. Call::Intrinsic));
  180. Stmt body = Block::make(
  181. @@ -117,7 +117,7 @@ class BuiltinLower : public IRMutator {
  182. Stmt alloca = LetStmt::make(
  183. op->buffer_var,
  184. Call::make(op->buffer_var.type(),
  185. - "TVMBackendAllocWorkspace",
  186. + TVM_RT_FUNC_TRANS("TVMBackendAllocWorkspace"),
  187. {cast(Int(32), device_type_),
  188. cast(Int(32), device_id_),
  189. cast(UInt(64), total_bytes),
  190. @@ -127,7 +127,7 @@ class BuiltinLower : public IRMutator {
  191. body);
  192. Expr free_op = Call::make(Int(32),
  193. - "TVMBackendFreeWorkspace",
  194. + TVM_RT_FUNC_TRANS("TVMBackendFreeWorkspace"),
  195. {cast(Int(32), device_type_),
  196. cast(Int(32), device_id_),
  197. op->buffer_var},