@@ -27,23 +27,86 @@
 namespace mindspore {
 namespace opt {
 constexpr auto kSingleInputIndex = 1;
+constexpr auto kIsolatedDependRealInputIndex = 0;
+constexpr auto kIsolatedDependVirtualInputIndex = 1;
 namespace {
+CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                             const std::vector<AnfNodePtr> &new_depend_inputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  CNodePtr new_depend = nullptr;
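+  // The two branches differ because KernelGraph provides a NewCNode(cnode) overload
+  // that clones the old node (presumably including its kernel-level info), so only the
+  // inputs need resetting; for a plain FuncGraph the node is built from the inputs and
+  // then given the old node's abstract and scope by hand.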
+  if (kernel_graph == nullptr) {
+    new_depend = func_graph->NewCNode(new_depend_inputs);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_abstract(cnode->abstract());
+    new_depend->set_scope(cnode->scope());
+  } else {
+    new_depend = kernel_graph->NewCNode(cnode);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_inputs(new_depend_inputs);
+  }
+  func_graph->manager()->Replace(cnode, new_depend);
+  return new_depend;
+}
+
+CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) {
+    return nullptr;
+  }
+  auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
+  if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) {
+    return nullptr;
+  }
+  auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
+  if (!real_input_op->isa<CNode>()) {
+    return nullptr;
+  }
+  auto real_input_cnode = real_input_op->cast<CNodePtr>();
+  return real_input_cnode;
+}
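+// Illustrative shape of the match (hypothetical node names):
+//   cast = Cast(x)
+//   u = UpdateState(u0, cast)
+//   d = Depend(cast, u)
+// CheckIsolatedVirtualNode(d) returns `cast`.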
+
+AnfNodePtr EliminateIsolatedVirtualNodeInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                             const CNodePtr &eliminate_node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(eliminate_node);
+  auto replace_node = eliminate_node->input(kSingleInputIndex);
+  std::vector<AnfNodePtr> new_depend_inputs = cnode->inputs();
+  new_depend_inputs[kIsolatedDependRealInputIndex + 1] = replace_node;
+  auto new_cnode = CreateNewDependNode(func_graph, cnode, new_depend_inputs);
+  auto new_node = new_cnode->cast<AnfNodePtr>();
+  return new_node;
+}
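+// Continuing the sketch above: eliminating `cast` rewires the Depend's real input to
+// the eliminated node's single input, i.e.  d = Depend(cast, u)  ==>  d' = Depend(x, u).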
+
 AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
   if (!node->isa<CNode>()) {
     return nullptr;
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  string op_name = AnfAlgo::GetCNodeName(cnode);
+  auto replace_cnode = cnode;
+  // Handle the isolated-node pattern: a Depend whose virtual input is an UpdateState.
+  auto isolated_cnode = CheckIsolatedVirtualNode(replace_cnode);
+  if (isolated_cnode != nullptr) {
+    replace_cnode = isolated_cnode;
+  }
+  string op_name = AnfAlgo::GetCNodeName(replace_cnode);
   // Currently we only eliminate transdata or cast nodes.
   if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) {
     return nullptr;
   }
-  if (!IsNotRealUsedByOthers(func_graph, cnode)) {
+  if (!IsNotRealUsedByOthers(func_graph, replace_cnode)) {
     return nullptr;
   }
-  CheckCNodeInputSize(cnode, kSingleInputIndex);
+  CheckCNodeInputSize(replace_cnode, kSingleInputIndex);
+  if (isolated_cnode != nullptr) {
+    auto new_depend_node = EliminateIsolatedVirtualNodeInput(func_graph, cnode, replace_cnode);
+    return new_depend_node;
+  }
   return cnode->input(kSingleInputIndex);
 }
@@ -137,20 +200,10 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con
     return nullptr;
   }
   new_depend_inputs[replace_index] = replace_node;
-  // Because depend's input has been changed, so a new depend(UpdateState) node will be created to replaced the old one.
-  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
-  CNodePtr new_depend = nullptr;
-  if (kernel_graph == nullptr) {
-    new_depend = func_graph->NewCNode(new_depend_inputs);
-    MS_EXCEPTION_IF_NULL(new_depend);
-    new_depend->set_abstract(depend_cnode->abstract());
-    new_depend->set_scope(depend_cnode->scope());
-  } else {
-    new_depend = kernel_graph->NewCNode(depend_cnode);
-    MS_EXCEPTION_IF_NULL(new_depend);
-    new_depend->set_inputs(new_depend_inputs);
-  }
-  func_graph->manager()->Replace(depend_cnode, new_depend);
+  // Because the depend's input has changed, create a new depend node to replace the old
+  // one (CreateNewDependNode also performs the replacement in the graph).
+  auto new_depend = CreateNewDependNode(func_graph, depend_cnode, new_depend_inputs);
+  if (new_depend == nullptr) {
+    return nullptr;
+  }
   return nullptr;
 }
@@ -107,6 +107,20 @@ std::vector<std::vector<size_t>> SplitGroup(const std::vector<AnfNodePtr> &topos
+// Pattern1======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, b)
+//==>
+// delete the UpdateState
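+// The rewrite is safe because b itself takes u2 as its state input, so rewiring
+// u3's users to u2 preserves the ordering that u3 enforced.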
+void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
+                               const AnfNodePtr &load) {
+  const auto &load_cnode = load->cast<CNodePtr>();
+  const auto &u = load_cnode->input(2);
+  manager->Replace(load_user, u);
+}
+
-// Pattern1======================================
+// Pattern2======================================
 // a = Load(para1, u1)
 // ...
 // b = Load(para1, u2)
 // t = make_tuple(x, b)
 // u3 = UpdateState(u2, t)
 //==>
@@ -127,7 +141,7 @@ void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr
   manager->Replace(make_tuple, other_input);
 }
-// Pattern2======================================
+// Pattern3======================================
 // a = Load(para1, u1)
 // ...
 // b = Load(para1, u2)
@@ -153,6 +167,11 @@ void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGrap
 void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
   auto load_users = manager->node_users()[load];
   for (const auto &load_user : load_users) {
+    // Pattern1
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
+      DeleteLoadUserUpdateState(manager, load_user.first, load);
+      continue;
+    }
     if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
       const auto &make_tuple = load_user.first->cast<CNodePtr>();
       auto &maketuple_users = manager->node_users()[make_tuple];
@@ -161,12 +180,12 @@ void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
       if (!maketuple_as_input_of_update) {
        continue;
       }
-      // Pattern1
+      // Pattern2
       if (make_tuple->size() == 3) {
        DeleteLoadUserMakeTuple(manager, make_tuple, load);
        continue;
       }
-      // Pattern2
+      // Pattern3
       if (make_tuple->size() > 3) {
        ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
       }
@@ -17,10 +17,11 @@ from types import FunctionType, MethodType
 from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                        _get_parallel_mode)
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_auto_parallel_context
 from mindspore._checkparam import Validator as validator
 from ...common import dtype as mstype
 from ...common.parameter import Parameter, ParameterTuple
+from ...common.tensor import Tensor
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P
@@ -62,6 +63,19 @@ def _tensors_cast_datatype(datatype, param):
     """
     return F.cast(param, datatype)
+
+
+_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
+
+
+@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
+def _cumulative_grad(accumulation_step, cumulative_grad, grad):
+    """Apply gradient accumulation to cumulative grad."""
+    return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
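+# For example, with accumulation_step = 4 the buffer accumulates g1/4 + g2/4 + g3/4 + g4/4,
+# i.e. the mean gradient over four micro-batches, so the eventual optimizer step matches a
+# single step taken on the averaged gradients.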
+
+
+_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
+
+
+@_gradient_clear_op.register("Tensor")
+def _clear_grad(cumulative_grad):
+    """Reset the cumulative gradient to zeros."""
+    zero_grad = P.ZerosLike()(cumulative_grad)
+    return F.assign(cumulative_grad, zero_grad)
+
+
 class WithLossCell(Cell):
     r"""
@@ -347,15 +361,28 @@ class TrainOneStepCell(Cell):
             self.mean = _get_gradients_mean()
             self.degree = _get_device_num()
             self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+
+        self.use_grad_accumulation = False
+        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
+            self.use_grad_accumulation = True
+        if self.use_grad_accumulation:
+            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
+            if self.max_accumulation_step <= 1:
+                self.max_accumulation_step = 1
+                self.use_grad_accumulation = False
+        if self.use_grad_accumulation:
+            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
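+        # The step count comes from the auto-parallel context, e.g. (illustrative)
+        # context.set_auto_parallel_context(grad_accumulation_step=4); a value <= 1
+        # leaves the cell in its ordinary one-step mode.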

     def construct(self, *inputs):
         weights = self.weights
         loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         sens = F.depend(sens, loss)
         grads = self.grad(self.network, weights)(*inputs, sens)
         grads = self.grad_reducer(grads)
-        loss = F.depend(loss, self.optimizer(grads))
+        if self.use_grad_accumulation:
+            loss = self.grad_accumulation(loss, grads)
+        else:
+            loss = F.depend(loss, self.optimizer(grads))
         return loss
@@ -557,3 +584,34 @@ class _BroadCastCell(Cell):
         params = self.broadcast(params)
         new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
         return new_params
+
+
+class GradientAccumulation(Cell):
+    """
+    After accumulating the gradients for multiple steps, call the optimizer to apply the update.
+
+    Args:
+        max_accumulation_step (int): Number of steps over which gradients are accumulated.
+        optimizer (Cell): Optimizer used to apply the accumulated gradients.
+    """
+
+    def __init__(self, max_accumulation_step, optimizer):
+        super(GradientAccumulation, self).__init__()
+        self._max_accumulation_step = max_accumulation_step
+        self.optimizer = optimizer
+        self.weights = optimizer.parameters
+        self.hyper_map = C.HyperMap()
+        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
+        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
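+
+    # Per-call flow (with max_accumulation_step = N): add grad / N into the buffers;
+    # on every N-th call, run the optimizer on the buffers, then zero them.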
+    def construct(self, loss, grads):
+        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
+                                             self._grad_accumulation, grads))
+        self._accumulation_step += 1
+
+        if self._accumulation_step >= self._max_accumulation_step:
+            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
+            self._accumulation_step = 0
+
+        if self._accumulation_step == 0:
+            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
+        return loss
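+
+# Walk-through with max_accumulation_step = 2 (illustrative): call 1 accumulates g1/2 and
+# leaves the counter at 1; call 2 accumulates g2/2, triggers the optimizer with the buffer
+# holding (g1 + g2) / 2, resets the counter to 0, and the final branch clears the buffers.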
@@ -319,7 +319,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         overflow = self.process_loss_scale(cond)
         # if there is no overflow, do optimize
         if not overflow:
-            loss = F.depend(loss, self.optimizer(grads))
+            if self.use_grad_accumulation:
+                loss = self.grad_accumulation(loss, grads)
+            else:
+                loss = F.depend(loss, self.optimizer(grads))
         return loss, cond, scaling_sens

     def set_sense_scale(self, sens):