Browse Source

!919 fix pytest failed when one case compile error

Merge pull request !919 from jjfeing/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
afe048474d
3 changed files with 23 additions and 1 deletions
  1. +7
    -0
      mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py
  2. +14
    -0
      mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc
  3. +2
    -1
      mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h

+ 7
- 0
mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py View File

@@ -161,5 +161,12 @@ class CompilerPool:
ret = task_id, "Exception: Not support return type:" + str(ret_type)
return ret

def reset_task_info(self):
    """
    Empty the running-task container after a task compile error.

    Nothing is touched when the container is already empty (or falsy).
    """
    if not self.__running_tasks:
        return
    self.__running_tasks.clear()


compile_pool = CompilerPool()

+ 14
- 0
mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc View File

@@ -40,6 +40,7 @@ constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe
// Attribute/method names used when calling into the Python side.
// kResetTaskInfo is the method ResetTaskInfo() invokes on the parallel compiler object.
constexpr const char *kCreateParallelCompiler = "create_tbe_parallel_compiler";
constexpr const char *kStartCompileOp = "start_compile_op";
constexpr const char *kWaitOne = "wait_one";
constexpr const char *kResetTaskInfo = "reset_task_info";

bool TbeOpParallelBuild(std::vector<AnfNodePtr> anf_nodes) {
auto build_manger = std::make_shared<ParallelBuildManager>();
@@ -96,6 +97,8 @@ bool TbeOpParallelBuild(std::vector<AnfNodePtr> anf_nodes) {

// Stores the object returned by TbePythonFuncs::TbeParallelCompiler() for later Python calls.
ParallelBuildManager::ParallelBuildManager()
    : tbe_parallel_compiler_(TbePythonFuncs::TbeParallelCompiler()) {}

// Runs ResetTaskInfo() on destruction so any task records still held are discarded.
ParallelBuildManager::~ParallelBuildManager() {
  this->ResetTaskInfo();
}

int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const {
PyObject *pRes = nullptr;
PyObject *pArgs = PyTuple_New(1);
@@ -234,5 +237,16 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces);
return kernel_mod_ptr;
}

void ParallelBuildManager::ResetTaskInfo() {
if (task_map_.empty()) {
MS_LOG(INFO) << "All tasks are compiled success.";
return;
}
task_map_.clear();
same_op_list_.clear();
PyObject *pArg = Py_BuildValue("()");
(void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg);
}
} // namespace kernel
} // namespace mindspore

+ 2
- 1
mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h View File

@@ -40,7 +40,7 @@ struct KernelBuildTaskInfo {
class ParallelBuildManager {
public:
ParallelBuildManager();
~ParallelBuildManager() = default;
~ParallelBuildManager();
int32_t StartCompileOp(const nlohmann::json &kernel_json) const;
void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name,
const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
@@ -58,6 +58,7 @@ class ParallelBuildManager {
KernelModPtr GenKernelMod(const string &json_name, const string &processor,
const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
const KernelPackPtr &kernel_pack) const;
void ResetTaskInfo();

private:
PyObject *tbe_parallel_compiler_;


Loading…
Cancel
Save