Browse Source

!14893 enable stitch fusion on bert

From: @r1chardf1d0
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
pull/14893/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
cd002cb7f7
9 changed files with 27 additions and 11 deletions
  1. +8
    -5
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +2
    -2
      mindspore/_extends/graph_kernel/splitter.py
  3. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc
  4. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc
  5. +1
    -1
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
  6. +1
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc
  7. +3
    -0
      mindspore/ccsrc/utils/context/graph_kernel_flags.cc
  8. +5
    -0
      mindspore/ccsrc/utils/context/graph_kernel_flags.h
  9. +1
    -1
      model_zoo/official/nlp/bert/run_pretrain.py

+ 8
- 5
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -145,9 +145,10 @@ class GraphSplitByPattern:
return True
return False

def __init__(self, graph):
def __init__(self, graph, flags):
self.graph = graph
self.areas = []
self.flags = flags
area_map = {}
_, outputs = graph.deduce_parameters()
for op in graph.ops:
@@ -450,6 +451,7 @@ class GraphSplitGpu(GraphSplitByPattern):
fused.append(a)
return fused, True

enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False)
changed = True
while changed:
changed = self.fuse(_reshape)
@@ -461,7 +463,8 @@ class GraphSplitGpu(GraphSplitByPattern):
changed = self.fuse(_broadcast_width) or changed
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
changed = self.fuse(_reduce_stitch) or changed
if enable_stitch_fusion:
changed = self.fuse(_reduce_stitch) or changed
self.fuse(_transpose)

class GraphSplitAscend(GraphSplitByPattern):
@@ -582,11 +585,11 @@ class GraphSplitAscend(GraphSplitByPattern):
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed

def split(graph, target):
def split(graph, target, flags):
"""Split graph"""
result = None
if target == "cuda":
result = GraphSplitGpu(graph).split()
result = GraphSplitGpu(graph, flags).split()
else:
result = GraphSplitAscend(graph).split()
result = GraphSplitAscend(graph, flags).split()
return result

+ 2
- 2
mindspore/_extends/graph_kernel/splitter.py View File

@@ -30,7 +30,7 @@ def split_with_json(json_str, flags_str):
flags = json.loads(flags_str)
target = graph_desc['process']
comp = model.load_composite(graph_desc)
graph_split, graph_mode = model.split(comp.graph, target)
graph_split, graph_mode = model.split(comp.graph, target, flags)
is_multi_graph = len(graph_split) > 1
graph_list = list(map(comp.dump, graph_split))
_reset_graphmode_for_inplaceassign(graph_list, graph_mode)
@@ -61,7 +61,7 @@ def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
f.write("********** main graph: {} **********\n".format(graph_desc.name))
f.write("input json:\n{}\n".format(graph_json))
f.write("graph desc:\n{}\n".format(str(graph_desc)))
if len(subgraphs) > 1:
if len(subgraphs) > 1 or subgraphs[0].stitch_info is not None:
for i, g in enumerate(subgraphs):
f.write("-------- subgraph {}, mode: {} --------\n".format(i, graph_mode[i]))
f.write("{}\n".format(str(g)))


+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_json_decoder.cc View File

@@ -457,10 +457,12 @@ void AkgKernelJsonDecoder::SetStitchAttr(const nlohmann::json &op_desc, const St
std::string tensor_name = output_descs[0][kJsonKeyTensorName];
if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
}
if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
info.stitch_atomic_ops.end()) {
AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
}
}



+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc View File

@@ -95,7 +95,9 @@ bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
CUfunction kernel_addr;
CUresult result = kernelmanager_->GetFunction(kernel_pack_, false, &thread_info, &kernel_addr);
if (result != CUDA_SUCCESS) {
MS_LOG(ERROR) << "GetFunction failed.";
const char *msg = nullptr;
cuGetErrorName(result, &msg);
MS_LOG(ERROR) << "Get function failed, error: " << msg;
return false;
}
std::vector<void *> runtimeargs;
@@ -109,7 +111,7 @@ bool GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vect
if (result != CUDA_SUCCESS) {
const char *msg = nullptr;
cuGetErrorName(result, &msg);
MS_LOG(ERROR) << "Launch Kernel failed. error: " << msg;
MS_LOG(ERROR) << "Launch kernel failed, error: " << msg;
return false;
}
return true;


+ 1
- 1
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -122,7 +122,7 @@ bool GenJson(const AnfNodePtrList &op_nodes, const AnfNodePtrList &inputs, const
std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
(void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
});
MS_LOG(INFO) << "Collect fusion json: " << fused_name;
MS_LOG(DEBUG) << "Collect fusion json: " << fused_name;
return true;
}



+ 1
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc View File

@@ -728,6 +728,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
const auto &flags = context::GraphKernelFlags::GetInstance();
nlohmann::json flag_json;
flag_json["dump_as_text"] = flags.dump_as_text;
flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion;
return flag_json.dump();
}



+ 3
- 0
mindspore/ccsrc/utils/context/graph_kernel_flags.cc View File

@@ -159,6 +159,8 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma

reg.AddFlag("dump_as_text", &dump_as_text);

reg.AddFlag("enable_stitch_fusion", &enable_stitch_fusion);

reg.AddFlag("opt_level", &opt_level);
reg.AddFlag("auto_tune", &auto_tune);
reg.AddFlag("cluster_limit", &cluster_limit);
@@ -176,6 +178,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
std::string GraphKernelFlags::DumpAllFlags() const {
nlohmann::json json;
json["dump_as_text"] = dump_as_text;
json["enable_stitch_fusion"] = enable_stitch_fusion;

json["opt_level"] = opt_level;
json["auto_tune"] = auto_tune;


+ 5
- 0
mindspore/ccsrc/utils/context/graph_kernel_flags.h View File

@@ -54,6 +54,11 @@ class GraphKernelFlags {
*/
bool dump_as_text{false};

/**
* Enable stitch fusion in graph kernel fusion strategy.
*/
bool enable_stitch_fusion{false};

/**
* Optimization level, value from 0 to 3.
* 0: GraphKernel disabled


+ 1
- 1
model_zoo/official/nlp/bert/run_pretrain.py View File

@@ -135,7 +135,7 @@ def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
if device_target == 'GPU':
context.set_context(enable_graph_kernel=True)
context.set_context(enable_graph_kernel=True, graph_kernel_flags="--enable_stitch_fusion=true")
else:
logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')



Loading…
Cancel
Save