Browse Source

!29844 Fix bugs in graphkernel split model

Merge pull request !29844 from DeshiChen/0209_costmodel
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
d3f4ad23a0
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 42 additions and 14 deletions
  1. +29
    -9
      mindspore/ccsrc/backend/optimizer/graph_kernel/split_model/split_model.cc
  2. +2
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/split_model/split_model.h
  3. +11
    -5
      mindspore/ccsrc/backend/optimizer/graph_kernel/split_model/split_model_cpu.cc

+ 29
- 9
mindspore/ccsrc/backend/optimizer/graph_kernel/split_model/split_model.cc View File

@@ -90,20 +90,38 @@ AreaPtr SplitModel::NewArea(const PrimOpPtr &op, bool is_output) {
return new_area;
}

void SplitModel::InitGraph(const LiteGraphPtr &litegraph) {
// Push "1" to empty shape to facilitate pattern matching,
// the shapes should be changed before initializing areas.
void SplitModel::AlignShape(const LiteGraphPtr &litegraph) {
for (auto &inp : litegraph->inputs()) {
if (inp->shape.empty()) {
inp->shape.push_back(1LL);
}
}
auto check_pattern = [](const NodePtr &op) {
auto pn = op->As<PrimOp>()->compute_type();
return pn == NodePattern::ELEMWISE || pn == NodePattern::BROADCAST || pn == NodePattern::REDUCE;
};
for (auto &op : litegraph->ops()) {
if (op->shape.empty()) {
op->shape.push_back(1LL);
if (!check_pattern(op)) {
if (op->shape.empty()) {
op->shape.push_back(1LL);
}
continue;
}
auto cur_shape_size = op->shape.size();
for (auto &inp : op->inputs()) {
if (inp->shape.size() > cur_shape_size) {
cur_shape_size = inp->shape.size();
}
}
if (cur_shape_size > op->shape.size()) {
auto num = cur_shape_size - op->shape.size();
op->shape.insert(op->shape.begin(), num, 1LL);
}
}
}

void SplitModel::InitGraph(const LiteGraphPtr &litegraph) {
AlignShape(litegraph);
auto &outputs = litegraph->GetOutputs();
HashSet<NodePtr> outputs_set(outputs.begin(), outputs.end());
for (const auto &op : litegraph->ops()) {
@@ -178,11 +196,13 @@ bool SplitModel::RunOnePattern(const FusePatternPtr &pattern) {
if (pattern->Run(area)) {
MS_LOG(DEBUG) << "Area " << area->ToString() << " matches " << pattern->ToString();
LimitAreaSize(area, &pattern->fused_areas());
FuseAreas(area, pattern->fused_areas(), pattern->direction());
changed = true;
} else {
++iter;
if (!pattern->fused_areas().empty()) {
FuseAreas(area, pattern->fused_areas(), pattern->direction());
changed = true;
continue;
}
}
++iter;
}
return changed;
}


+ 2
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/split_model/split_model.h View File

@@ -55,6 +55,8 @@ class SplitModel {
protected:
// transform the litegraph to areas, and initialize inner tables.
void InitGraph(const LiteGraphPtr &litegraph);
// Push leading "1" to shapes to facilitate pattern match.
void AlignShape(const LiteGraphPtr &litegraph);
// initialize fusion pattern list.
virtual void InitFusePatterns() = 0;
bool RunOnePattern(const FusePatternPtr &pattern);


+ 11
- 5
mindspore/ccsrc/backend/optimizer/graph_kernel/split_model/split_model_cpu.cc View File

@@ -169,29 +169,35 @@ class FuseElemwiseBroadcastBwd : public FusePattern {
return dom->size() <= size_limit_;
}
bool Match(const AreaPtr &dom) override {
// this pattern is to fuse ALL users of dom area,
// since the broadcast node should not be an output when it fuse nodes in backward.
for (auto &[a, r] : dom->users_with_relation()) {
if (fuse_type_ == FuseType::kDepth && a->input_num() != 1) {
continue;
return false;
}
if (a->pattern() > NodePattern::REDUCE) {
continue;
return false;
}
if (fuse_type_ == FuseType::kWidth) {
if (!fused_areas_.empty() && fused_areas_[0]->dom()->shape != a->dom()->shape) {
continue;
return false;
}
if (HasCircle(dom, a)) {
return false;
}
if (HasCircle(dom, a)) continue;
}
if (a->pattern() == NodePattern::REDUCE) {
// elemwise + reduce
if (dom->pattern() == NodePattern::ELEMWISE && r == EdgeRelation::INJECTIVE) {
(void)fused_areas_.emplace_back(a);
} else {
return false;
}
} else { // a->pattern() < NodePattern::REDUCE
(void)fused_areas_.emplace_back(a);
}
}
return !fused_areas_.empty();
return fused_areas_.size() == dom->user_num();
}

FuseType fuse_type_;


Loading…
Cancel
Save