From: @tronzhang Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @dylangengpull/15182/MERGE
| @@ -21,6 +21,7 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/context/graph_kernel_flags.h" | |||||
| #include "backend/optimizer/graph_kernel/add_atomic_clean.h" | #include "backend/optimizer/graph_kernel/add_atomic_clean.h" | ||||
| #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | #include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h" | ||||
| #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | #include "backend/optimizer/graph_kernel/arithmetic_simplify.h" | ||||
| @@ -138,7 +139,7 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() { | |||||
| PassManagerPtr GraphKernelOptimizer::Combine() { | PassManagerPtr GraphKernelOptimizer::Combine() { | ||||
| auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine"); | auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine"); | ||||
| // Enable parallel fusion | // Enable parallel fusion | ||||
| if (is_gpu) { | |||||
| if (is_gpu && context::GraphKernelFlags::GetInstance().enable_parallel_fusion) { | |||||
| // Do parallel fusion for gpu device | // Do parallel fusion for gpu device | ||||
| pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7))); | pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7))); | ||||
| } | } | ||||
| @@ -157,14 +157,17 @@ void GraphKernelFlags::Refresh() { | |||||
| void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) { | void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) { | ||||
| FlagRegister reg(flag_map); | FlagRegister reg(flag_map); | ||||
| // Boolean flags | |||||
| reg.AddFlag("dump_as_text", &dump_as_text); | reg.AddFlag("dump_as_text", &dump_as_text); | ||||
| reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion); | reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion); | ||||
| reg.AddFlag("enable_parallel_fusion", &enable_parallel_fusion); | |||||
| // Integer flags | |||||
| reg.AddFlag("opt_level", &opt_level); | reg.AddFlag("opt_level", &opt_level); | ||||
| reg.AddFlag("auto_tune", &auto_tune); | reg.AddFlag("auto_tune", &auto_tune); | ||||
| reg.AddFlag("cluster_limit", &cluster_limit); | reg.AddFlag("cluster_limit", &cluster_limit); | ||||
| // String list flags | |||||
| reg.AddFlag("enable_expand_ops", &enable_expand_ops); | reg.AddFlag("enable_expand_ops", &enable_expand_ops); | ||||
| reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only); | reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only); | ||||
| reg.AddFlag("disable_expand_ops", &disable_expand_ops); | reg.AddFlag("disable_expand_ops", &disable_expand_ops); | ||||
| @@ -177,8 +180,10 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma | |||||
| std::string GraphKernelFlags::DumpAllFlags() const { | std::string GraphKernelFlags::DumpAllFlags() const { | ||||
| nlohmann::json json; | nlohmann::json json; | ||||
| json["dump_as_text"] = dump_as_text; | json["dump_as_text"] = dump_as_text; | ||||
| json["enable_stitch_fusion"] = enable_stitch_fusion; | json["enable_stitch_fusion"] = enable_stitch_fusion; | ||||
| json["enable_parallel_fusion"] = enable_parallel_fusion; | |||||
| json["opt_level"] = opt_level; | json["opt_level"] = opt_level; | ||||
| json["auto_tune"] = auto_tune; | json["auto_tune"] = auto_tune; | ||||
| @@ -59,6 +59,11 @@ class GraphKernelFlags { | |||||
| */ | */ | ||||
| bool enable_stitch_fusion{false}; | bool enable_stitch_fusion{false}; | ||||
| /** | |||||
| * Enable parallel fusion in graph kernel fusion strategy. | |||||
| */ | |||||
| bool enable_parallel_fusion{false}; | |||||
| /** | /** | ||||
| * Optimization level, value from 0 to 3. | * Optimization level, value from 0 to 3. | ||||
| * 0: GraphKernel disabled | * 0: GraphKernel disabled | ||||
| @@ -135,7 +135,8 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode): | |||||
| def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel): | def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel): | ||||
| if enable_graph_kernel == "true" or is_auto_enable_graph_kernel: | if enable_graph_kernel == "true" or is_auto_enable_graph_kernel: | ||||
| if device_target == 'GPU': | if device_target == 'GPU': | ||||
| context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_stitch_fusion=true") | |||||
| context.set_context(enable_graph_kernel=True, | |||||
| graph_kernel_flags="--enable_stitch_fusion=true --enable_parallel_fusion=true") | |||||
| else: | else: | ||||
| logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.') | logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.') | ||||