Merge pull request !4744 from lichen/support_berttags/v0.7.0-beta
@@ -610,6 +610,15 @@ Status MatMulBase::CheckForTensorSliceValid() const {
  return SUCCESS;
}

std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
  CheckGlobalDeviceManager();
  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
  Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
  batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
  Strategys strategy_v = {batch_strategy, batch_strategy};
  return std::make_shared<Strategys>(strategy_v);
}

Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  if (InitForCostModel(strategy) == FAILED) {
    if (is_auto_parallel_) {
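In effect, the generated batch strategy shards only the leading (batch) dimension across all devices in stage 0 and leaves every other dimension whole, and BatchMatMul receives the same strategy for both of its inputs. A minimal plain-Python sketch of that computation (the rank and device count below are illustrative values, not taken from the patch):

def generate_batch_strategy(input_rank, dev_num):
    # Ones for every non-batch dimension, then the device count prepended
    # for the batch dimension, mirroring GenerateBatchStrategies above.
    strategy = [1] * (input_rank - 1)
    strategy.insert(0, dev_num)
    # Both BatchMatMul inputs get the same strategy.
    return [strategy, strategy]

# e.g. rank-3 inputs on 8 devices -> [[8, 1, 1], [8, 1, 1]]
print(generate_batch_strategy(3, 8))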
@@ -91,6 +91,8 @@ class BatchMatMulInfo : public MatMul {
                  const PrimitiveAttrs &attrs)
      : MatMul(name, inputs_shape, outputs_shape, attrs) {}
  ~BatchMatMulInfo() override = default;

  std::shared_ptr<Strategys> GenerateBatchStrategies() override;
};
} // namespace parallel
} // namespace mindspore
@@ -162,6 +162,7 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog
constexpr char MATMUL[] = "MatMul";
constexpr char GELU[] = "Gelu";
constexpr char TANH[] = "Tanh";
constexpr char SHAPE_OP[] = "Shape";
constexpr char SOFTMAX[] = "Softmax";
constexpr char LOG_SOFTMAX[] = "LogSoftmax";
constexpr char ACTIVATION[] = "Activation";
@@ -1673,6 +1673,41 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  return std::make_shared<TensorLayout>(input_tensor_layout);
}

RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
  MS_EXCEPTION_IF_NULL(node);
  TensorRedistribution tensor_redistribution;
  // create stand alone layout: TensorMap:[all -1], dev_matrix:[dev_num].
  CheckGlobalDeviceManager();
  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
  TensorLayout stand_alone_layout;
  Shapes inputs_shape = GetNodeShape(node);
  if (inputs_shape.empty()) {
    MS_LOG(EXCEPTION) << "InferSensRedistribution failed cause inputs shape is empty.";
  }
  Shape input_shape_array = inputs_shape[0];
  if (input_shape_array.empty()) {
    MS_LOG(INFO) << "No need to redistribution for sens.";
    return nullptr;
  }
  // TensorMap
  TensorMap stand_alone_tensor_map_array(SizeToInt(input_shape_array.size()), -1);
  // Dev_matrix
  Shape dev_matrix_array = {dev_num};
  if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
    MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
  }
  // Infer Redistribution op list for stand alone and loss layout.
  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
  if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
  }
  RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
  MS_EXCEPTION_IF_NULL(sens_redistribution_list);
  return sens_redistribution_list;
}

std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
  if (node->isa<Parameter>()) {
    return CreateParameterLayout(node);
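For readers unfamiliar with the layout terms: the "stand alone" source layout built above is simply a fully replicated layout, where every dimension of the sens tensor is mapped to -1 (not sharded) over a one-dimensional device matrix [dev_num]; the inferred redistribution op list then describes how to move that replicated tensor into the loss layout. A small sketch of the values involved (the shape and device count are made-up example numbers):

def stand_alone_layout(input_shape, dev_num):
    # One -1 per tensor dimension: nothing is mapped to a device axis.
    tensor_map = [-1] * len(input_shape)
    # A single device axis that holds every device in the stage.
    dev_matrix = [dev_num]
    return dev_matrix, tensor_map, input_shape

# e.g. a [32, 10] sens tensor on 8 devices -> ([8], [-1, -1], [32, 10])
print(stand_alone_layout([32, 10], 8))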
@@ -1897,7 +1932,18 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
    sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
    return;
  }
  if (sens_tensor_node->isa<CNode>()) {
    auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
    if (op_list_ptr == nullptr) {
      return;
    }
    auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
    auto func_graph = grad_sens_node->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode);
    return;
  }
  MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter or CNode, it is unsupported now.";
}

// Use _GetTensorSlice operator to split the sens tensor
@@ -2305,6 +2351,41 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
  return root_forward_nodes;
}

void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
  // shape op doesn't have params and attrs.
  OperatorParams params;
  OperatorAttrs attrs;
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(SHAPE_OP, args);
  InsertNode(op, node, 2, pre_node, root, "shape");
}

void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
  // If root graph has reshape op. Find the corresponding parameter.
  // Reshape's shape is the shape of the parameter.
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0)) || cnode->in_forward_flag()) {
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() != RESHAPE) {
      continue;
    }
    auto root = node->func_graph();
    auto all_dfs_nodes = DeepLinkedGraphSearch(node);
    for (auto r_iter = all_dfs_nodes.rbegin(); r_iter != all_dfs_nodes.rend(); ++r_iter) {
      if ((*r_iter)->isa<Parameter>()) {
        InsertShapeOp(cnode, *r_iter, root);
        break;
      }
    }
  }
}

void MarkForwardCNode(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto all_nodes = root->nodes();
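Put briefly, HandleRootReshape looks at each non-forward Reshape in the root graph, linearizes the nodes reachable from it with DeepLinkedGraphSearch, scans that list from the back, and wires a Shape op from the first Parameter it finds into the Reshape's shape input (argument position 2). A toy illustration of the lookup order only (the node names and predicate are invented for the example):

def find_param_for_reshape(dfs_nodes, is_parameter):
    # Mirrors the reverse iteration above: the last Parameter in DFS order wins.
    for node in reversed(dfs_nodes):
        if is_parameter(node):
            return node
    return None

nodes = ["reshape", "matmul", "weight_param", "input_param"]
print(find_param_for_reshape(nodes, lambda n: n.endswith("_param")))  # -> input_param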
@@ -2456,6 +2537,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
  // mark the forward cnodes, parallel only care these nodes
  MarkForwardCNode(root);
  HandleRootReshape(all_nodes);

  if (FindCommunicationOp(all_nodes)) {
    MS_LOG(EXCEPTION) << "The graph contain communication op";
@@ -0,0 +1,177 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
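# Expected behaviour of the cell above (standard dynamic loss scaling): divide the
# loss scale by scale_factor (2) whenever an overflow is detected, and multiply it
# by the same factor after scale_window (1000) consecutive overflow-free steps.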
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
class TrainOneStepWithLossScaleCell(nn.Cell):
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    @C.add_flags(has_effect=True)
    def construct(self, x, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
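        # Overflow check: get_status reads the accumulated NPU float status back into
        # `init`; a reduced sum of at least 1 (base <= flag_sum) means some gradient
        # overflowed, so the optimizer update below is skipped for this step.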
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
class DatasetLenet(MindData):
    def __init__(self, predict, label, length=3):
        super(DatasetLenet, self).__init__()
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0
class LoopLayer(nn.Cell):
    def __init__(self):
        super(LoopLayer, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")

    def construct(self, x):
        out = self.matmul(x, self.matmul_weight)
        out = self.relu(out)
        return out
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.exp = P.Exp()
        self.mean = P.ReduceMean()
        layers = []
        for _ in range(3):
            layer = LoopLayer()
            layers.append(layer)
        self.layers = nn.CellList(layers)

    def construct(self, x):
        out = self.exp(x)
        for layer in self.layers:
            layer_out = layer(out)
            out = layer_out
        out = self.mean(out, -1)
        return out
def test_loss_scale():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64,]), dtype=ms.int32)
    dataset = DatasetLenet(predict, label)
    net = Net()
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    net = TrainOneStepWithLossScaleCell(net, opt, update_cell)
    model = Model(network=net)
    model.train(2, dataset, dataset_sink_mode=False)