|
|
|
@@ -90,20 +90,38 @@ AreaPtr SplitModel::NewArea(const PrimOpPtr &op, bool is_output) { |
|
|
|
return new_area; |
|
|
|
} |
|
|
|
|
|
|
|
void SplitModel::InitGraph(const LiteGraphPtr &litegraph) { |
|
|
|
// Push "1" to empty shape to facilitate pattern matching, |
|
|
|
// the shapes should be changed before initializing areas. |
|
|
|
void SplitModel::AlignShape(const LiteGraphPtr &litegraph) { |
|
|
|
for (auto &inp : litegraph->inputs()) { |
|
|
|
if (inp->shape.empty()) { |
|
|
|
inp->shape.push_back(1LL); |
|
|
|
} |
|
|
|
} |
|
|
|
auto check_pattern = [](const NodePtr &op) { |
|
|
|
auto pn = op->As<PrimOp>()->compute_type(); |
|
|
|
return pn == NodePattern::ELEMWISE || pn == NodePattern::BROADCAST || pn == NodePattern::REDUCE; |
|
|
|
}; |
|
|
|
for (auto &op : litegraph->ops()) { |
|
|
|
if (op->shape.empty()) { |
|
|
|
op->shape.push_back(1LL); |
|
|
|
if (!check_pattern(op)) { |
|
|
|
if (op->shape.empty()) { |
|
|
|
op->shape.push_back(1LL); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cur_shape_size = op->shape.size(); |
|
|
|
for (auto &inp : op->inputs()) { |
|
|
|
if (inp->shape.size() > cur_shape_size) { |
|
|
|
cur_shape_size = inp->shape.size(); |
|
|
|
} |
|
|
|
} |
|
|
|
if (cur_shape_size > op->shape.size()) { |
|
|
|
auto num = cur_shape_size - op->shape.size(); |
|
|
|
op->shape.insert(op->shape.begin(), num, 1LL); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SplitModel::InitGraph(const LiteGraphPtr &litegraph) { |
|
|
|
AlignShape(litegraph); |
|
|
|
auto &outputs = litegraph->GetOutputs(); |
|
|
|
HashSet<NodePtr> outputs_set(outputs.begin(), outputs.end()); |
|
|
|
for (const auto &op : litegraph->ops()) { |
|
|
|
@@ -178,11 +196,13 @@ bool SplitModel::RunOnePattern(const FusePatternPtr &pattern) { |
|
|
|
if (pattern->Run(area)) { |
|
|
|
MS_LOG(DEBUG) << "Area " << area->ToString() << " matches " << pattern->ToString(); |
|
|
|
LimitAreaSize(area, &pattern->fused_areas()); |
|
|
|
FuseAreas(area, pattern->fused_areas(), pattern->direction()); |
|
|
|
changed = true; |
|
|
|
} else { |
|
|
|
++iter; |
|
|
|
if (!pattern->fused_areas().empty()) { |
|
|
|
FuseAreas(area, pattern->fused_areas(), pattern->direction()); |
|
|
|
changed = true; |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
++iter; |
|
|
|
} |
|
|
|
return changed; |
|
|
|
} |
|
|
|
|