From 3b329959366d881a8c49ceeec673af31a7fb5be3 Mon Sep 17 00:00:00 2001
From: r1chardf1d0
Date: Fri, 9 Apr 2021 18:25:31 +0800
Subject: [PATCH] enable stitch fusion on bert

---
 .../_extends/graph_kernel/model/graph_split.py      | 13 ++++++++-----
 mindspore/_extends/graph_kernel/splitter.py         |  4 ++--
 .../kernel_compiler/akg/akg_kernel_json_decoder.cc  |  2 ++
 .../kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc   |  6 ++++--
 .../optimizer/graph_kernel/graph_kernel_helper.cc   |  2 +-
 .../optimizer/graph_kernel/graph_kernel_splitter.cc |  1 +
 mindspore/ccsrc/utils/context/graph_kernel_flags.cc |  3 +++
 mindspore/ccsrc/utils/context/graph_kernel_flags.h  |  5 +++++
 model_zoo/official/nlp/bert/run_pretrain.py         |  2 +-
 9 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index ee54838eea..05654b376f 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -145,9 +145,10 @@ class GraphSplitByPattern:
                 return True
             return False
 
-    def __init__(self, graph):
+    def __init__(self, graph, flags):
         self.graph = graph
         self.areas = []
+        self.flags = flags
         area_map = {}
         _, outputs = graph.deduce_parameters()
         for op in graph.ops:
@@ -450,6 +451,7 @@ class GraphSplitGpu(GraphSplitByPattern):
                     fused.append(a)
             return fused, True
 
+        enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
         changed = True
         while changed:
             changed = self.fuse(_reshape)
@@ -461,7 +463,8 @@ class GraphSplitGpu(GraphSplitByPattern):
             changed = self.fuse(_broadcast_width) or changed
             if use_poly_reduce:
                 changed = self.fuse(_reduce_output) or changed
-                changed = self.fuse(_reduce_stitch) or changed
+                if enable_stitch_fusion:
+                    changed = self.fuse(_reduce_stitch) or changed
         self.fuse(_transpose)
 
 class GraphSplitAscend(GraphSplitByPattern):
@@ -582,11 +585,11 @@ class GraphSplitAscend(GraphSplitByPattern):
             changed = self.fuse(_broadcast_depth) or changed
             changed = self.fuse(_broadcast_width) or changed
 
-def split(graph, target):
+def split(graph, target, flags):
     """Split graph"""
     result = None
     if target == "cuda":
-        result = GraphSplitGpu(graph).split()
+        result = GraphSplitGpu(graph, flags).split()
     else:
-        result = GraphSplitAscend(graph).split()
+        result = GraphSplitAscend(graph, flags).split()
     return result
diff --git a/mindspore/_extends/graph_kernel/splitter.py b/mindspore/_extends/graph_kernel/splitter.py
index ad612320c8..b2d2253cc7 100644
--- a/mindspore/_extends/graph_kernel/splitter.py
+++ b/mindspore/_extends/graph_kernel/splitter.py
@@ -30,7 +30,7 @@ def split_with_json(json_str, flags_str):
         flags = json.loads(flags_str)
         target = graph_desc['process']
         comp = model.load_composite(graph_desc)
-        graph_split, graph_mode = model.split(comp.graph, target)
+        graph_split, graph_mode = model.split(comp.graph, target, flags)
         is_multi_graph = len(graph_split) > 1
         graph_list = list(map(comp.dump, graph_split))
         _reset_graphmode_for_inplaceassign(graph_list, graph_mode)
@@ -61,7 +61,7 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
         f.write("********** main graph: {} **********\n".format(graph_desc.name))
         f.write("input json:\n{}\n".format(graph_json))
         f.write("graph desc:\n{}\n".format(str(graph_desc)))
-        if len(subgraphs) > 1:
+        if len(subgraphs) > 1 or subgraphs[0].stitch_info is not None:
             for i, g in enumerate(subgraphs):
                 f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
                 f.write("{}\n".format(str(g)))
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
index f470ca8a1a..7fc2d71535 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
@@ -457,10 +457,12 @@ void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const St
   std::string tensor_name = output_descs[0][kJsonKeyTensorName];
   if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
     AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
+    MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
   }
   if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
       info.stitch_atomic_ops.end()) {
     AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
+    MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
   }
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc
index fb5ac024a0..0af8ee6c82 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc
@@ -95,7 +95,9 @@ bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
   CUfunction kernel_addr;
   CUresult result = kernelmanager_->GetFunction(kernel_pack_, false, &thread_info, &kernel_addr);
   if (result != CUDA_SUCCESS) {
-    MS_LOG(ERROR) << "GetFunction failed.";
+    const char *msg = nullptr;
+    cuGetErrorName(result, &msg);
+    MS_LOG(ERROR) << "Get function failed, error: " << msg;
     return false;
   }
   std::vector<void *> runtimeargs;
@@ -109,7 +111,7 @@ bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
   if (result != CUDA_SUCCESS) {
     const char *msg = nullptr;
     cuGetErrorName(result, &msg);
error: " << msg; + MS_LOG(ERROR) << "Launch kernel failed, error: " << msg; return false; } return true; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index ffabc7364d..f363a631bf 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -122,7 +122,7 @@ bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) { (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_"); }); - MS_LOG(INFO) << "Collect fusion json: " << fused_name; + MS_LOG(DEBUG) << "Collect fusion json: " << fused_name; return true; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc index a50aebc5f3..c34cb1c79b 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc @@ -728,6 +728,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { const auto &flags = context::GraphKernelFlags::GetInstance(); nlohmann::json flag_json; flag_json["dump_as_text"] = flags.dump_as_text; + flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion; return flag_json.dump(); } diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc index 25e73002f1..b8a7ae72d2 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc @@ -159,6 +159,8 @@ void GraphKernelFlags::RegisterFlags(std::map *flag_ma reg.AddFlag("dump_as_text", &dump_as_text); + reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion); + reg.AddFlag("opt_level", &opt_level); reg.AddFlag("auto_tune", &auto_tune); reg.AddFlag("cluster_limit", &cluster_limit); @@ -176,6 +178,7 @@ void GraphKernelFlags::RegisterFlags(std::map *flag_ma std::string GraphKernelFlags::DumpAllFlags() const { nlohmann::json json; json["dump_as_text"] = dump_as_text; + json["enable_stitch_fusion"] = enable_stitch_fusion; json["opt_level"] = opt_level; json["auto_tune"] = auto_tune; diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.h b/mindspore/ccsrc/utils/context/graph_kernel_flags.h index fd989e5bfb..4b1c037a51 100644 --- a/mindspore/ccsrc/utils/context/graph_kernel_flags.h +++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.h @@ -54,6 +54,11 @@ class GraphKernelFlags { */ bool dump_as_text{false}; + /** + * Enable stitch fusion in graph kernel fusion strategy. + */ + bool enable_stitch_fusion{false}; + /** * Optimization level, value from 0 to 3. 
   * 0: GraphKernel disabled
diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index a4d7d17ff6..bb5264d127 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -135,7 +135,7 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
 
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
     if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
         if device_target == 'GPU':
-            context.set_context(enable_graph_kernel=True)
+            context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_stitch_fusion=true")
         else:
             logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
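---
Note (not part of the patch): the new flag is user-visible through
context.set_context, as the run_pretrain.py hunk above shows, so it can be
turned on from any GPU script, not only BERT pretraining. A minimal usage
sketch, assuming a MindSpore build with graph kernel support and a GPU target:

    import mindspore.context as context

    # Graph kernel fusion only takes effect in graph mode on the GPU backend.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # enable_graph_kernel switches graph kernel fusion on; the new
    # --enable_stitch_fusion flag (default false, see graph_kernel_flags.h)
    # additionally lets the GPU cost model apply the _reduce_stitch pattern
    # when splitting fused graphs.
    context.set_context(enable_graph_kernel=True,
                        graph_kernel_flags="--enable_stitch_fusion=true")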