| @@ -13,11 +13,12 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # =========================================================================== | # =========================================================================== | ||||
| """Cost model splitter""" | """Cost model splitter""" | ||||
| from functools import reduce | |||||
| from .model import PrimLib, Graph, Tensor | from .model import PrimLib, Graph, Tensor | ||||
| use_poly_reduce = True | use_poly_reduce = True | ||||
| class GraphSplitByPattern: | class GraphSplitByPattern: | ||||
| """Graph splitter""" | """Graph splitter""" | ||||
| class Area: | class Area: | ||||
| @@ -33,6 +34,8 @@ class GraphSplitByPattern: | |||||
| self.mode = self.MODE_BASIC | self.mode = self.MODE_BASIC | ||||
| if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): | if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): | ||||
| self.mode = self.MODE_COMPOSITE | self.mode = self.MODE_COMPOSITE | ||||
| if init_op.prim == "AddN": | |||||
| self.mode = self.MODE_COMPOSITE | |||||
| self.is_output = is_output | self.is_output = is_output | ||||
| self.output_excluded = set() | self.output_excluded = set() | ||||
| if self.pattern == PrimLib.REDUCE: | if self.pattern == PrimLib.REDUCE: | ||||
| @@ -196,7 +199,7 @@ class GraphSplitByPattern: | |||||
| min_area, forward_fuse = None, False | min_area, forward_fuse = None, False | ||||
| for a, _ in dom.out_relations.items(): | for a, _ in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | ||||
| (min_area is None or a.pattern < min_area.pattern): | |||||
| (min_area is None or a.pattern < min_area.pattern): | |||||
| min_area = a | min_area = a | ||||
| for a, _ in dom.in_relations.items(): | for a, _ in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | ||||
| @@ -210,7 +213,7 @@ class GraphSplitByPattern: | |||||
| return None | return None | ||||
| a, r = list(dom.in_relations.items())[0] | a, r = list(dom.in_relations.items())[0] | ||||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ | if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ | ||||
| a.dom_op().output.shape != dom.dom_op().output.shape: | |||||
| a.dom_op().output.shape != dom.dom_op().output.shape: | |||||
| return None | return None | ||||
| return [a], True | return [a], True | ||||
| @@ -220,7 +223,7 @@ class GraphSplitByPattern: | |||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | ||||
| a.dom_op().output.shape == dom.dom_op().output.shape: | |||||
| a.dom_op().output.shape == dom.dom_op().output.shape: | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, True | return fused, True | ||||
| @@ -231,7 +234,7 @@ class GraphSplitByPattern: | |||||
| def _broadcast_depth(dom): | def _broadcast_depth(dom): | ||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | ||||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| return None | return None | ||||
| a, r = list(dom.out_relations.items())[0] | a, r = list(dom.out_relations.items())[0] | ||||
| if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | ||||
| @@ -240,12 +243,12 @@ class GraphSplitByPattern: | |||||
| def _broadcast_width(dom): | def _broadcast_width(dom): | ||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | ||||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | ||||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | |||||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | |||||
| return None | return None | ||||
| fused.append(a) | fused.append(a) | ||||
| return fused, False | return fused, False | ||||
| @@ -301,8 +304,19 @@ class GraphSplitByPattern: | |||||
| return size | return size | ||||
| def _reduce_output(dom): | def _reduce_output(dom): | ||||
| def _is_atomic_add_available(dom): | |||||
| if any(["Reduce" in x.prim for x in dom.ops[1:]]): | |||||
| return False | |||||
| op = dom.ops[0] | |||||
| reduce_axis = op.attrs["reduce_axis"] | |||||
| if len(op.inputs[0].shape) - 1 in reduce_axis: | |||||
| reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) | |||||
| return reduce_size >= 1024 | |||||
| return True | |||||
| if dom.pattern != PrimLib.REDUCE: | if dom.pattern != PrimLib.REDUCE: | ||||
| return None | return None | ||||
| if _is_atomic_add_available(dom): | |||||
| return None | |||||
| is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | ||||
| # excluded large size all reduce | # excluded large size all reduce | ||||
| if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | ||||
| @@ -310,7 +324,7 @@ class GraphSplitByPattern: | |||||
| fused = [] | fused = [] | ||||
| for a, r in dom.out_relations.items(): | for a, r in dom.out_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | ||||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, False | return fused, False | ||||
| @@ -207,7 +207,7 @@ class CompositeGraph: | |||||
| def _get_axis_while_none(input_shape, output_shape): | def _get_axis_while_none(input_shape, output_shape): | ||||
| red_axis = [] | red_axis = [] | ||||
| if len(output_shape) == len(input_shape): | if len(output_shape) == len(input_shape): | ||||
| for s, i in enumerate(output_shape): | |||||
| for i, s in enumerate(output_shape): | |||||
| if s == 1 and input_shape[i] > 1: | if s == 1 and input_shape[i] > 1: | ||||
| red_axis.append(i) | red_axis.append(i) | ||||
| else: | else: | ||||
| @@ -158,7 +158,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||||
| } | } | ||||
| auto fuse_nodes = FindFuseCNodes(node, depend_prior); | auto fuse_nodes = FindFuseCNodes(node, depend_prior); | ||||
| if (fuse_nodes.size() <= 1) { | |||||
| if (fuse_nodes.empty() || (fuse_nodes.size() == 1 && AnfAlgo::IsGraphKernel(fuse_nodes[0]))) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| changed = true; | changed = true; | ||||
| @@ -173,17 +173,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| bool FuseBasicOps(const FuncGraphPtr &kernel_graph) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| auto mng = kernel_graph->manager(); | |||||
| if (mng == nullptr) { | |||||
| mng = Manage(kernel_graph, true); | |||||
| kernel_graph->set_manager(mng); | |||||
| } | |||||
| bool FuseBasicOps(const FuncGraphPtr &func_graph) { | |||||
| std::unordered_set<AnfNodePtr> fused_ops; | std::unordered_set<AnfNodePtr> fused_ops; | ||||
| auto todos = TopoSort(kernel_graph->get_return()); | |||||
| auto todos = TopoSort(func_graph->get_return()); | |||||
| std::reverse(todos.begin(), todos.end()); | std::reverse(todos.begin(), todos.end()); | ||||
| return FuseBasicOps(kernel_graph, todos, &fused_ops); | |||||
| return FuseBasicOps(func_graph, todos, &fused_ops); | |||||
| } | } | ||||
| void EliminateGetitem(const FuncGraphPtr &func_graph) { | void EliminateGetitem(const FuncGraphPtr &func_graph) { | ||||
| @@ -197,9 +191,16 @@ void EliminateGetitem(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { | bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) { | ||||
| auto mng = func_graph->manager(); | |||||
| if (mng == nullptr) { | |||||
| mng = Manage(func_graph, true); | |||||
| func_graph->set_manager(mng); | |||||
| } | |||||
| bool changed = FuseBasicOps(func_graph); | bool changed = FuseBasicOps(func_graph); | ||||
| if (changed) { | if (changed) { | ||||
| EliminateGetitem(func_graph); | EliminateGetitem(func_graph); | ||||
| mng->RemoveRoots(); | |||||
| mng->KeepRoots({func_graph}); | |||||
| } | } | ||||
| return changed; | return changed; | ||||
| } | } | ||||
| @@ -192,7 +192,7 @@ class EliminateGetitemForControlDepend : public Pass { | |||||
| MS_EXCEPTION_IF_NULL(maketuple); | MS_EXCEPTION_IF_NULL(maketuple); | ||||
| std::vector<size_t> result; | std::vector<size_t> result; | ||||
| for (auto i : indexes_) { | for (auto i : indexes_) { | ||||
| auto real_output = maketuple->input(i); | |||||
| auto real_output = maketuple->input(i + 1); | |||||
| if (users[real_output].size() > 1) { | if (users[real_output].size() > 1) { | ||||
| result.push_back(i); | result.push_back(i); | ||||
| } | } | ||||
| @@ -708,11 +708,11 @@ std::unordered_set<PrimitivePtr> GetExpandOps() { | |||||
| prim::kPrimGeluGrad, | prim::kPrimGeluGrad, | ||||
| prim::kPrimFusedAdam, | prim::kPrimFusedAdam, | ||||
| prim::kPrimFusedAdamWeightDecay, | prim::kPrimFusedAdamWeightDecay, | ||||
| prim::kPrimTanhGrad, | |||||
| prim::kPrimReduceMean, | prim::kPrimReduceMean, | ||||
| prim::kPrimMaximumGrad, | prim::kPrimMaximumGrad, | ||||
| prim::kPrimMinimumGrad, | prim::kPrimMinimumGrad, | ||||
| prim::kPrimGkDropout | |||||
| prim::kPrimGkDropout, | |||||
| prim::kPrimDropoutGrad, | |||||
| #endif | #endif | ||||
| }; | }; | ||||
| return expand_ops; | return expand_ops; | ||||
| @@ -544,7 +544,8 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||||
| } | } | ||||
| func_graph_ = func_graph; | func_graph_ = func_graph; | ||||
| this->Run(); | this->Run(); | ||||
| return split_plan_.size() > 1; | |||||
| if (split_plan_.empty()) return false; | |||||
| return split_plan_.size() > 1 || NeedInline(0); | |||||
| } | } | ||||
| bool NeedInline(size_t group_id) const override { | bool NeedInline(size_t group_id) const override { | ||||
| @@ -629,7 +630,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { | |||||
| } | } | ||||
| GetValidKernelNodes(); | GetValidKernelNodes(); | ||||
| // call CostModel to get a split plan. | // call CostModel to get a split plan. | ||||
| if (!SplitByCostModel() || split_plan_.size() <= 1) { | |||||
| if (!SplitByCostModel()) { | |||||
| split_plan_.clear(); | split_plan_.clear(); | ||||
| need_inline_.clear(); | need_inline_.clear(); | ||||
| return; | return; | ||||
| @@ -77,8 +77,8 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons | |||||
| ShapeVector shape_i64; | ShapeVector shape_i64; | ||||
| std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); }); | std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); }); | ||||
| // Create new tensor | |||||
| AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal)}; | |||||
| // The primitive should use a clone, otherwise the attr seed will be overridden. | |||||
| AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())}; | |||||
| auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())), | auto tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())), | ||||
| static_cast<void *>(&shape[0]), kNumberTypeInt64); | static_cast<void *>(&shape[0]), kNumberTypeInt64); | ||||
| uniform_input.push_back(NewValueNode(tensor)); | uniform_input.push_back(NewValueNode(tensor)); | ||||
| @@ -98,8 +98,8 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons | |||||
| // create new uniform_real_node | // create new uniform_real_node | ||||
| auto uniform_real_node = func_graph->NewCNode(uniform_input); | auto uniform_real_node = func_graph->NewCNode(uniform_input); | ||||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(rand_r(&seed_)))); | |||||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(rand_r(&seed_)))); | |||||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed", MakeValue(SizeToLong(seed_++))); | |||||
| AnfAlgo::GetCNodePrimitive(uniform_real_node)->set_attr("seed2", MakeValue(SizeToLong(seed_++))); | |||||
| auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64); | auto uniform_abstract = std::make_shared<abstract::AbstractTensor>(std::make_shared<Float>(32), shape_i64); | ||||
| uniform_real_node->set_abstract(uniform_abstract); | uniform_real_node->set_abstract(uniform_abstract); | ||||
| uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | uniform_real_node->set_kernel_info(std::make_shared<device::KernelInfo>()); | ||||