| @@ -13,11 +13,12 @@ | |||
| # limitations under the License. | |||
| # =========================================================================== | |||
| """Cost model splitter""" | |||
| from functools import reduce | |||
| from .model import PrimLib, Graph, Tensor | |||
| use_poly_reduce = True | |||
| class GraphSplitByPattern: | |||
| """Graph splitter""" | |||
| class Area: | |||
| @@ -33,6 +34,8 @@ class GraphSplitByPattern: | |||
| self.mode = self.MODE_BASIC | |||
| if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): | |||
| self.mode = self.MODE_COMPOSITE | |||
| if init_op.prim == "AddN": | |||
| self.mode = self.MODE_COMPOSITE | |||
| self.is_output = is_output | |||
| self.output_excluded = set() | |||
| if self.pattern == PrimLib.REDUCE: | |||
| @@ -196,7 +199,7 @@ class GraphSplitByPattern: | |||
| min_area, forward_fuse = None, False | |||
| for a, _ in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | |||
| (min_area is None or a.pattern < min_area.pattern): | |||
| (min_area is None or a.pattern < min_area.pattern): | |||
| min_area = a | |||
| for a, _ in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | |||
| @@ -210,7 +213,7 @@ class GraphSplitByPattern: | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ | |||
| a.dom_op().output.shape != dom.dom_op().output.shape: | |||
| a.dom_op().output.shape != dom.dom_op().output.shape: | |||
| return None | |||
| return [a], True | |||
| @@ -220,7 +223,7 @@ class GraphSplitByPattern: | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | |||
| a.dom_op().output.shape == dom.dom_op().output.shape: | |||
| a.dom_op().output.shape == dom.dom_op().output.shape: | |||
| fused.append(a) | |||
| return fused, True | |||
| @@ -231,7 +234,7 @@ class GraphSplitByPattern: | |||
| def _broadcast_depth(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| return None | |||
| a, r = list(dom.out_relations.items())[0] | |||
| if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | |||
| @@ -240,12 +243,12 @@ class GraphSplitByPattern: | |||
| def _broadcast_width(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | |||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | |||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | |||
| return None | |||
| fused.append(a) | |||
| return fused, False | |||
| @@ -301,8 +304,19 @@ class GraphSplitByPattern: | |||
| return size | |||
| def _reduce_output(dom): | |||
| def _is_atomic_add_available(dom): | |||
| if any(["Reduce" in x.prim for x in dom.ops[1:]]): | |||
| return False | |||
| op = dom.ops[0] | |||
| reduce_axis = op.attrs["reduce_axis"] | |||
| if len(op.inputs[0].shape) - 1 in reduce_axis: | |||
| reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) | |||
| return reduce_size >= 1024 | |||
| return True | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| if _is_atomic_add_available(dom): | |||
| return None | |||
| is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | |||
| # excluded large size all reduce | |||
| if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | |||
| @@ -310,7 +324,7 @@ class GraphSplitByPattern: | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | |||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||
| fused.append(a) | |||
| return fused, False | |||
| @@ -207,7 +207,7 @@ class CompositeGraph: | |||
| def _get_axis_while_none(input_shape, output_shape): | |||
| red_axis = [] | |||
| if len(output_shape) == len(input_shape): | |||
| for s, i in enumerate(output_shape): | |||
| for i, s in enumerate(output_shape): | |||
| if s == 1 and input_shape[i] > 1: | |||
| red_axis.append(i) | |||
| else: | |||
| @@ -158,7 +158,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||
| } | |||
| auto fuse_nodes = FindFuseCNodes(node, depend_prior); | |||
| if (fuse_nodes.size() <= 1) { | |||
| if (fuse_nodes.empty() || (fuse_nodes.size() == 1 && AnfAlgo::IsGraphKernel(fuse_nodes[0]))) { | |||
| continue; | |||
| } | |||
| changed = true; | |||
| @@ -173,17 +173,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||
| } | |||
| } // namespace | |||
| bool FuseBasicOps(const FuncGraphPtr &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto mng = kernel_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(kernel_graph, true); | |||
| kernel_graph->set_manager(mng); | |||
| } | |||
| bool FuseBasicOps(const FuncGraphPtr &func_graph) { | |||
| std::unordered_set<AnfNodePtr> fused_ops; | |||
| auto todos = TopoSort(kernel_graph->get_return()); | |||
| auto todos = TopoSort(func_graph->get_return()); | |||
| std::reverse(todos.begin(), todos.end()); | |||
| return FuseBasicOps(kernel_graph, todos, &fused_ops); | |||
| return FuseBasicOps(func_graph, todos, &fused_ops); | |||
| } | |||
| void EliminateGetitem(const FuncGraphPtr &func_graph) { | |||
| @@ -197,9 +191,16 @@ void EliminateGetitem(const FuncGraphPtr &func_graph) { | |||
| } | |||
| bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { | |||
| auto mng = func_graph->manager(); | |||
| if (mng == nullptr) { | |||
| mng = Manage(func_graph, true); | |||
| func_graph->set_manager(mng); | |||
| } | |||
| bool changed = FuseBasicOps(func_graph); | |||
| if (changed) { | |||
| EliminateGetitem(func_graph); | |||
| mng->RemoveRoots(); | |||
| mng->KeepRoots({func_graph}); | |||
| } | |||
| return changed; | |||
| } | |||
| @@ -192,7 +192,7 @@ class EliminateGetitemForControlDepend : public Pass { | |||
| MS_EXCEPTION_IF_NULL(maketuple); | |||
| std::vector<size_t> result; | |||
| for (auto i : indexes_) { | |||
| auto real_output = maketuple->input(i); | |||
| auto real_output = maketuple->input(i + 1); | |||
| if (users[real_output].size() > 1) { | |||
| result.push_back(i); | |||
| } | |||
| @@ -708,11 +708,11 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||
| prim::kPrimGeluGrad, | |||
| prim::kPrimFusedAdam, | |||
| prim::kPrimFusedAdamWeightDecay, | |||
| prim::kPrimTanhGrad, | |||
| prim::kPrimReduceMean, | |||
| prim::kPrimMaximumGrad, | |||
| prim::kPrimMinimumGrad, | |||
| prim::kPrimGkDropout | |||
| prim::kPrimGkDropout, | |||
| prim::kPrimDropoutGrad, | |||
| #endif | |||
| }; | |||
| return expand_ops; | |||
| @@ -544,7 +544,8 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||
| } | |||
| func_graph_ = func_graph; | |||
| this->Run(); | |||
| return split_plan_.size() > 1; | |||
| if (split_plan_.empty()) return false; | |||
| return split_plan_.size() > 1 || NeedInline(0); | |||
| } | |||
| bool NeedInline(size_t group_id) const override { | |||
| @@ -629,7 +630,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||
| } | |||
| GetValidKernelNodes(); | |||
| // call CostModel to get a split plan. | |||
| if (!SplitByCostModel() || split_plan_.size() <= 1) { | |||
| if (!SplitByCostModel()) { | |||
| split_plan_.clear(); | |||
| need_inline_.clear(); | |||
| return; | |||
| @@ -77,8 +77,8 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons | |||
| ShapeVector shape_i64; | |||
| std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); }); | |||
| // Create new tensor | |||
| AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)}; | |||
| // The primitive should use a clone, otherwise the attr seed will be overrided. | |||
| AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())}; | |||
| auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())), | |||
| static_cast<void *>(&shape[0]), kNumberTypeInt64); | |||
| uniform_input.push_back(NewValueNode(tensor)); | |||
| @@ -98,8 +98,8 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons | |||
| // create new uniform_real_node | |||
| auto uniform_real_node = func_graph->NewCNode(uniform_input); | |||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_)))); | |||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_)))); | |||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(seed_++))); | |||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(seed_++))); | |||
| auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64); | |||
| uniform_real_node->set_abstract(uniform_abstract); | |||
| uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||