Merge pull request !2222 from wangcong/master
@@ -28,8 +28,8 @@ build_in_impl_path = get_build_in_impl_path()
 # op function list
 op_build = "compile"
 op_pre_build = "pre_build"
-fusion_type_map = {'Convolution': 0, 'ElemWise': 1, 'CommReduce': 2,
-                   'Segment': 3, 'Opaque': 4}
+fusion_pattern_start_flag = "fusion_pattern_start"
+fusion_pattern_end_flag = "fusion_pattern_end"
 
 def _initialize(impl_path):
     """Initialize"""

@@ -43,7 +43,6 @@ def _initialize(impl_path):
     sys.path.insert(0, op_module_name)
 
-
 def build_op(build_type, json_str):
     """
     call op functions with function name and input args json_str

@@ -169,7 +168,5 @@ def compile_with_json(json_str):
 if __name__ == "__main__":
     in_args = sys.stdin.readline()
     result = compile_with_json(in_args)
-    if result in fusion_type_map:
-        exit(fusion_type_map[result])
-    else:
-        exit(100)
+    sys.stdout.write(fusion_pattern_start_flag + str(result) + fusion_pattern_end_flag)
+    sys.stdout.flush()

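Note: compiler.py no longer encodes the fusion pattern in its exit status; it prints the pattern on stdout, wrapped between fusion_pattern_start and fusion_pattern_end. A minimal Python sketch of how a caller could recover the pattern from that captured stdout (illustrative only; the parse_fusion_pattern helper below is not part of this change, the real consumer is the C++ PreTaskFinishProcess further down):

    def parse_fusion_pattern(stdout_text):
        """Return the text between the start/end flags, or None if either flag is missing."""
        start_flag = "fusion_pattern_start"
        end_flag = "fusion_pattern_end"
        start = stdout_text.find(start_flag)
        end = stdout_text.find(end_flag)
        if start == -1 or end == -1:
            return None
        return stdout_text[start + len(start_flag):end]

    # e.g. parse_fusion_pattern("fusion_pattern_startElemWisefusion_pattern_end") == "ElemWise"
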
@@ -88,10 +88,10 @@ def run_compiler(op_json):
     try:
         tbe_compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
         completed_object = subprocess.run([sys.executable, tbe_compiler], input=op_json, timeout=300,
-                                          text=True, capture_output=True, check=False)
+                                          text=True, capture_output=True, check=True)
         if completed_object:
-            code = completed_object.returncode
-            return "Success", str(code)
+            out = completed_object.stdout
+            return "Success", out
     except subprocess.TimeoutExpired:
         tb = traceback.format_exc()
         return "TBEException", "PreCompileTimeOut: " + tb + "\ninput_args: " + op_json

@@ -73,7 +73,8 @@ static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph
     KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
     switch (kernel_type) {
       case KernelType::TBE_KERNEL: {
-        if (AnfAlgo::GetKernelMod(anf_node) == nullptr) {
+        if (AnfAlgo::GetKernelMod(anf_node) == nullptr &&
+            AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) {
           tbe_nodes.push_back(anf_node);
         }
         break;

@@ -45,6 +45,7 @@ enum FusionType {
   COMMREDUCE,
   SEGMENT,
   OPAQUE,
+  DYNAMIC,
   UNKNOWN_FUSION_TYPE = -1,
 };
 enum OpPattern {

@@ -63,7 +63,7 @@ const std::unordered_map<std::string, size_t> type_nbyte_maps = {
 const std::unordered_map<std::string, FusionType> fusion_type_maps = {
   {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
-  {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
+  {"SEGMENT", FusionType::SEGMENT}, {"DYNAMIC", FusionType::DYNAMIC}, {"OPAQUE", FusionType::OPAQUE},
 };
 
 TypeId DtypeToTypeId(const std::string &dtypes) {

@@ -205,6 +205,20 @@ void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::stri
   if (task_iter == pre_task_map_.end()) {
     MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id;
   }
+  auto node = task_iter->second;
+  auto builder =
+    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
+  std::string start_flag = "fusion_pattern_start";
+  std::string end_flag = "fusion_pattern_end";
+  int start = pre_build_result.find(start_flag);
+  int end = pre_build_result.find(end_flag);
+  if (start != -1 && end != -1) {
+    std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size());
+    transform(result.begin(), result.end(), result.begin(), ::toupper);
+    FusionType fusion_type = tbe::GetFusionType(result);
+    builder->SetFusionType(fusion_type);
+    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
+  }
   (void)pre_task_map_.erase(task_iter);
 }

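Note: a rough Python rendering of the lookup step above (an assumption for illustration, not code from this change): the text extracted between the flags is uppercased before it is resolved through fusion_type_maps, which now also recognizes "DYNAMIC"; names outside the table presumably fall back to UNKNOWN_FUSION_TYPE.

    KNOWN_FUSION_TYPES = {"CONVLUTION", "ELEMWISE", "COMMREDUCE", "SEGMENT", "OPAQUE", "DYNAMIC"}

    def to_fusion_type(pattern):
        name = pattern.upper()
        return name if name in KNOWN_FUSION_TYPES else "UNKNOWN_FUSION_TYPE"

    print(to_fusion_type("ElemWise"))  # -> "ELEMWISE"
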
@@ -535,6 +535,7 @@ void AscendSession::InitRuntimeResource() {
 }
 
 void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  device::ascend::KernelPreBuild(kernel_graph.get());
   MS_LOG(INFO) << "HardwareOptimize start!";
   opt::AscendBackendOptimization(kernel_graph);
   MS_EXCEPTION_IF_NULL(kernel_graph);

@@ -17,7 +17,7 @@
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 
 matmul_op_info = TBERegOp("MatMul") \
-    .fusion_type("OPAQUE") \
+    .fusion_type("DYNAMIC") \
     .async_flag(False) \
     .binfile_name("matmul.so") \
     .compute_cost(10) \