@@ -29,6 +29,18 @@
 namespace mindspore {
 namespace parallel {
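+// Ops that merely forward a value (or fuse the loss computation); the end-node
+// search should look through them instead of stopping on them.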
+const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem,
+                                                    prim::kPrimSoftmaxCrossEntropyWithLogits};
+
+static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
+  for (auto &prim : END_NODE_BLACK_LIST) {
+    if (IsPrimitiveCNode(cnode, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
   auto pre_node = cnode->input(1);
   while (true) {
@@ -392,7 +404,7 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con
 AnfNodePtr GetPreNode(const AnfNodePtr &node) {
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+  if (IsInEndNodeBlackList(cnode)) {
     return GetPreNode(cnode->input(1));
   }
   return cnode;
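For intuition, the same walk in a minimal Python sketch (hypothetical node API, not
MindSpore's real one):

    TRANSPARENT_OPS = {"Depend", "TupleGetItem", "SoftmaxCrossEntropyWithLogits"}

    def get_pre_node(node):
        # Walk through ops that only forward their first data input
        # until a "real" computation node is found.
        while node.op_name in TRANSPARENT_OPS:
            node = node.inputs[1]
        return node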
@@ -134,6 +134,11 @@ void PipelineTransformer::LabelMicroBatch() {
     for (auto &node_user : node_users) {
       if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
         auto data_users = manager_->node_users()[node_user.first];
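+        // If the first user is not a StridedSlice, the micro-batch split happens one
+        // level further down the graph; count that node's users instead.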
+        auto node_first = data_users.front().first;
+        if (!IsPrimitiveCNode(node_first, prim::kPrimStridedSlice)) {
+          data_users.clear();
+          data_users = node_user_map[node_first];
+        }
         auto micro_size = int64_t(data_users.size());
         micro_size_ = micro_size;
         MS_LOG(INFO) << "Micro Size is: " << micro_size;
@@ -690,7 +695,10 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
   auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
   std::vector<AnfNodePtr> receive_ops;
   std::vector<AnfNodePtr> send_ops;
-  auto all_nodes = graph->nodes();
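+  // graph->nodes() has no deterministic order; gather the nodes reachable from the
+  // return node and reverse them so Send/Receive are inserted in execution order.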
+  auto ret = graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  std::reverse(all_nodes.begin(), all_nodes.end());
   auto stage_num = g_device_manager->stage_num();
   if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't be less than stage num: " << stage_num;
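In practice this means the number of micro-batches must be at least the pipeline
stage count, e.g. (sketch; `net_with_loss` is a placeholder and imports are as in the
full example at the end):

    # 4 stages -> micro_size must be >= 4, otherwise the check above raises at compile time.
    context.set_auto_parallel_context(pipeline_stages=4)
    net = nn.PipelineCell(net_with_loss, micro_size=8)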
@@ -3045,6 +3045,9 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes, const FuncGrap
     }
     std::string cloned_param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
     auto cloned_param_layout = cloned_parameter_node->user_data<TensorLayout>();
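+    // Cloned parameters (such as the accu_grads copies) may carry no tensor layout
+    // on this stage; skip them rather than saving a null layout.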
+    if (cloned_param_layout == nullptr) {
+      continue;
+    }
     tensor_info_map[cloned_param_name] = cloned_param_layout;
   }
   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
@@ -161,6 +161,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
   }
   for (auto &node_tensor_info : tensor_info_map) {
     TensorLayoutPtr tensor_layout = node_tensor_info.second;
+    MS_EXCEPTION_IF_NULL(tensor_layout);
     straspb::ParallelLayoutItem *parallel_layout_item = parallel_strategy_map.add_parallel_layout_item();
     MS_EXCEPTION_IF_NULL(parallel_layout_item);
     parallel_layout_item->set_param_name(node_tensor_info.first);
@@ -18,7 +18,7 @@ Wrap cells for networks.
 Use the Wrapper to combine the loss or build the training steps.
 """
 from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
-    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple
+    ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, PipelineCell
 from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
 from .grad_reducer import DistributedGradReducer
 from ..layer.timedistributed import TimeDistributed
@@ -29,6 +29,7 @@ __all__ = [
     "TrainOneStepCell",
     "WithLossCell",
     "WithGradCell",
+    "PipelineCell",
    "WithEvalCell",
     "GetNextSingleOp",
     "TrainOneStepWithLossScaleCell",
@@ -15,6 +15,7 @@
 """Loss scale cell for loss scale training."""
 import mindspore.context as context
 from mindspore.context import ParallelMode
+from mindspore.parallel._utils import _get_enable_parallel_optimizer
 from .cell_wrapper import TrainOneStepCell
 from ..cell import Cell
 from ...common import Tensor, RowTensor
@@ -430,3 +431,100 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
         if self.loss_scaling_manager is not None:
             return self.loss_scaling_manager(self.scale_sense, overflow)
         return overflow
+
+
+grad_scale = C.MultitypeFuncGraph("grad_scale")
+shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
+reciprocal = P.Reciprocal()
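+
+
+# Non-sharded path: gradients for the whole parameter were accumulated into
+# accu_grad across micro-batches; unscale the accumulated value, then zero the
+# buffer for the next accumulation round (F.depend only orders the side effects).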
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, zeros))
+    return new_grad
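+
+
+# Sharded (parallel-optimizer) path: each rank already holds its own gradient
+# slice, so unscale the incoming grad directly and just reset the buffer.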
+@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
+    new_grad = grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    new_grad = F.depend(new_grad, F.assign(accu_grad, F.zeros_like(accu_grad)))
+    return new_grad
+
+
+class _TrainPipelineWithLossScaleCell(TrainOneStepCell):
+    """
+    Append an optimizer to the training network, so that the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that the loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        scale_sense (Cell): Cell to do the loss scale.
+    """
+
+    def __init__(self, network, optimizer, scale_sense):
+        super(_TrainPipelineWithLossScaleCell, self).__init__(network, optimizer, sens=None)
+        self.network = network
+        self.network.add_flags(defer_inline=True)
+        self.weights = optimizer.parameters
+        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
+        self.optimizer = optimizer
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.grad_reducer = F.identity
+        self.degree = 1
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.hyper_map = C.HyperMap()
+        self.reshape = P.Reshape()
+        self.loss_scaling_manager = None
+        if isinstance(scale_sense, Cell):
+            self.loss_scaling_manager = scale_sense
+            self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
+                                         name="scale_sense")
+        elif isinstance(scale_sense, Tensor):
+            if scale_sense.shape == (1,) or scale_sense.shape == ():
+                self.scale_sense = Parameter(scale_sense, name='scale_sense')
+            else:
+                raise ValueError("The shape of scale_sense must be (1,) or (), but got {}".format(scale_sense.shape))
+        else:
+            raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense)))
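+        # Whether the parallel optimizer (optimizer weight sharding) is enabled;
+        # this selects between the two gradient-scaling paths in construct().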
+        self.opt_shard = _get_enable_parallel_optimizer()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
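+        # Ascend overflow detection: allocate and clear the float status register
+        # before backprop, then read it back once the gradients exist.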
+        init = self.alloc_status()
+        status_clear = self.clear_before_grad(init)
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        init = F.depend(init, grads)
+        get_status = self.get_status(init)
+        init = F.depend(init, get_status)
+        flag_sum = self.reduce_sum(init, (0,))
+        loss = F.depend(loss, status_clear)
+        if self.opt_shard:
+            grads = self.grad_reducer(grads)
+            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
+        else:
+            accu_grads = self.grad_reducer(self.accu_grads)
+            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
+        cond = self.less_equal(self.base, flag_sum)
+        overflow = cond
+        if self.loss_scaling_manager is not None:
+            overflow = self.loss_scaling_manager(self.scale_sense, cond)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+        ret = (loss, overflow, scaling_sens)
+        return F.depend(ret, succ)
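As a quick sanity check of the unscaling math (illustrative numbers): with
scale_sense = 1024.0 and an accumulated gradient value of 0.5, the optimizer
receives 0.5 * (1/1024) ≈ 4.88e-4, and accu_grads is reset to zero before the
next round of micro-batches.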
@@ -19,6 +19,7 @@ from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn import acc
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
+from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
 from ..ops import functional as F
 from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
@@ -184,8 +185,12 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
                 raise ValueError("Only `loss_scale_manager=None` or "
                                  "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`"
                                  "are supported on device `CPU`. ")
-            network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
-                                                       scale_sense=update_cell).set_train()
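+            # With more than one pipeline stage, use the pipeline-aware loss-scale cell
+            # so gradients are accumulated across micro-batches before the update.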
+            if _get_pipeline_stages() > 1:
+                network = _TrainPipelineWithLossScaleCell(network, optimizer,
+                                                          scale_sense=update_cell).set_train()
+            else:
+                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
+                                                           scale_sense=update_cell).set_train()
             return network
     if _get_pipeline_stages() > 1:
         network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
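Taken together, a hypothetical end-to-end setup (placeholder names `net_with_loss`
and `opt`; exact context flags may differ by MindSpore version):

    from mindspore import context, nn
    from mindspore.train.amp import build_train_network
    from mindspore.train.loss_scale_manager import DynamicLossScaleManager

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", pipeline_stages=2)
    pipeline_net = nn.PipelineCell(net_with_loss, micro_size=4)
    train_net = build_train_network(pipeline_net, opt, level="O2",
                                    loss_scale_manager=DynamicLossScaleManager())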