Merge pull request !4744 from lichen/support_berttags/v0.7.0-beta
@@ -610,6 +610,15 @@ Status MatMulBase::CheckForTensorSliceValid() const {
  return SUCCESS;
}

std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
  CheckGlobalDeviceManager();
  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
  Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
  batch_strategy.insert(batch_strategy.begin(), SizeToLong(dev_num));
  Strategys strategy_v = {batch_strategy, batch_strategy};
  return std::make_shared<Strategys>(strategy_v);
}

Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  if (InitForCostModel(strategy) == FAILED) {
    if (is_auto_parallel_) {
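(Illustration, not part of the diff.) GenerateBatchStrategies shards only the leading batch dimension of both BatchMatMul inputs across the stage-0 devices and leaves every other dimension unsplit. A minimal Python sketch of the resulting strategy, with hypothetical names:

# Illustration only: mirrors the strategy built by GenerateBatchStrategies above.
# `rank` is the rank of the second BatchMatMul input, `dev_num` the stage-0 device count.
def generate_batch_strategy(rank, dev_num):
    strategy = [1] * (rank - 1)      # Dimensions batch_strategy(size - 1, 1)
    strategy.insert(0, dev_num)      # shard only the leading (batch) dimension
    return [strategy, strategy]      # the same strategy is used for both inputs

# Example: rank-3 inputs on 8 devices -> [[8, 1, 1], [8, 1, 1]]
print(generate_batch_strategy(rank=3, dev_num=8))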
@@ -91,6 +91,8 @@ class BatchMatMulInfo : public MatMul {
                   const PrimitiveAttrs &attrs)
      : MatMul(name, inputs_shape, outputs_shape, attrs) {}
  ~BatchMatMulInfo() override = default;

  std::shared_ptr<Strategys> GenerateBatchStrategies() override;
};
}  // namespace parallel
}  // namespace mindspore
@@ -162,6 +162,7 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog
constexpr char MATMUL[] = "MatMul";
constexpr char GELU[] = "Gelu";
constexpr char TANH[] = "Tanh";
constexpr char SHAPE_OP[] = "Shape";
constexpr char SOFTMAX[] = "Softmax";
constexpr char LOG_SOFTMAX[] = "LogSoftmax";
constexpr char ACTIVATION[] = "Activation";
@@ -1673,6 +1673,41 @@ std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
  return std::make_shared<TensorLayout>(input_tensor_layout);
}

RedistributionOpListPtr InferSensRedistribution(const AnfNodePtr &node, const TensorLayout &loss_layout) {
  MS_EXCEPTION_IF_NULL(node);
  TensorRedistribution tensor_redistribution;
  // Create a stand-alone layout: TensorMap: [all -1], dev_matrix: [dev_num].
  CheckGlobalDeviceManager();
  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
  TensorLayout stand_alone_layout;
  Shapes inputs_shape = GetNodeShape(node);
  if (inputs_shape.empty()) {
    MS_LOG(EXCEPTION) << "InferSensRedistribution failed because the input shape is empty.";
  }
  Shape input_shape_array = inputs_shape[0];
  if (input_shape_array.empty()) {
    MS_LOG(INFO) << "No need to redistribute sens.";
    return nullptr;
  }
  // TensorMap
  TensorMap stand_alone_tensor_map_array(SizeToInt(input_shape_array.size()), -1);
  // Dev_matrix
  Shape dev_matrix_array = {dev_num};
  if (stand_alone_layout.InitFromVector(dev_matrix_array, stand_alone_tensor_map_array, input_shape_array) == FAILED) {
    MS_LOG(EXCEPTION) << "Create tensor layout for Sens failed.";
  }
  // Infer the redistribution op list from the stand-alone layout to the loss layout.
  RankList dev_list = g_device_manager->GetDeviceListByStageId(0);
  if (tensor_redistribution.Init(stand_alone_layout, loss_layout, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Redistribution for Sens init failed.";
  }
  RedistributionOpListPtr sens_redistribution_list = tensor_redistribution.InferTensorRedistributionOperatorList();
  MS_EXCEPTION_IF_NULL(sens_redistribution_list);
  return sens_redistribution_list;
}

std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
  if (node->isa<Parameter>()) {
    return CreateParameterLayout(node);
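(Illustration, not part of the diff.) InferSensRedistribution redistributes sens from a fully replicated "stand-alone" layout (every tensor dimension mapped to -1 over a one-dimensional device matrix [dev_num]) to the loss layout. A minimal Python sketch of the stand-alone layout pieces, with hypothetical names:

# Illustration only: the stand-alone layout built above, as plain lists.
def stand_alone_layout(input_shape, dev_num):
    tensor_map = [-1] * len(input_shape)   # every dimension unsplit / replicated
    dev_matrix = [dev_num]                 # single-axis device arrangement
    return dev_matrix, tensor_map, input_shape

# Example: a [64, 64] sens tensor on 8 devices -> ([8], [-1, -1], [64, 64])
print(stand_alone_layout([64, 64], dev_num=8))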
@@ -1897,7 +1932,18 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
    sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
    return;
  }
  MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
  if (sens_tensor_node->isa<CNode>()) {
    auto op_list_ptr = InferSensRedistribution(sens_tensor_node, loss_grad_layout);
    if (op_list_ptr == nullptr) {
      return;
    }
    auto sens_tensor_cnode = sens_tensor_node->cast<CNodePtr>();
    auto func_graph = grad_sens_node->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    InsertRedistribution(op_list_ptr, grad_sens_node, func_graph, 1, sens_tensor_cnode);
    return;
  }
  MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter or CNode, it is unsupported now.";
}

// Use _GetTensorSlice operator to split the sens tensor
@@ -2305,6 +2351,41 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
  return root_forward_nodes;
}

void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncGraphPtr &root) {
  // The Shape op has no params and no attrs.
  OperatorParams params;
  OperatorAttrs attrs;
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(SHAPE_OP, args);
  InsertNode(op, node, 2, pre_node, root, "shape");
}

void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
  // If the root graph has a Reshape op, find the corresponding parameter:
  // the Reshape's target shape is the shape of that parameter.
  for (auto &node : all_nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (!IsValueNode<Primitive>(cnode->input(0)) || cnode->in_forward_flag()) {
      continue;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim->name() != RESHAPE) {
      continue;
    }
    auto root = node->func_graph();
    auto all_dfs_nodes = DeepLinkedGraphSearch(node);
    for (auto r_iter = all_dfs_nodes.rbegin(); r_iter != all_dfs_nodes.rend(); ++r_iter) {
      if ((*r_iter)->isa<Parameter>()) {
        InsertShapeOp(cnode, *r_iter, root);
        break;
      }
    }
  }
}

void MarkForwardCNode(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto all_nodes = root->nodes();
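(Illustration, not part of the diff.) HandleRootReshape walks the nodes linked to each non-forward Reshape in reverse DFS order, stops at the first Parameter it finds, and InsertShapeOp then inserts a Shape op fed by that parameter as the Reshape's shape input, so the Reshape target follows that parameter's shape. A minimal Python sketch of the lookup, with a hypothetical node representation:

# Illustration only: find the parameter whose shape should drive the root Reshape.
def find_shape_source(dfs_nodes, is_parameter):
    # dfs_nodes: nodes reachable from the Reshape, in DFS order.
    for node in reversed(dfs_nodes):       # mirrors the rbegin()/rend() loop above
        if is_parameter(node):
            return node                    # Shape(node) becomes the Reshape's shape input
    return None

# Example with strings standing in for graph nodes:
nodes = ["reshape", "matmul", "weight_param"]
print(find_shape_source(nodes, lambda n: n.endswith("_param")))  # -> "weight_param"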
@@ -2456,6 +2537,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
  // mark the forward cnodes; parallel only cares about these nodes
  MarkForwardCNode(root);
  HandleRootReshape(all_nodes);

  if (FindCommunicationOp(all_nodes)) {
    MS_LOG(EXCEPTION) << "The graph contain communication op";
@@ -0,0 +1,177 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)


@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
class TrainOneStepWithLossScaleCell(nn.Cell):
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")

    @C.add_flags(has_effect=True)
    def construct(self, x, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32))
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, cond, scaling_sens)
        return F.depend(ret, succ)
class DatasetLenet(MindData):
    def __init__(self, predict, label, length=3):
        super(DatasetLenet, self).__init__()
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0
class LoopLayer(nn.Cell):
    def __init__(self):
        super(LoopLayer, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()
        self.matmul_weight = Parameter(Tensor(np.ones([64, 64]), dtype=ms.float32), name="weight")

    def construct(self, x):
        out = self.matmul(x, self.matmul_weight)
        out = self.relu(out)
        return out


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.exp = P.Exp()
        self.mean = P.ReduceMean()
        layers = []
        for _ in range(3):
            layer = LoopLayer()
            layers.append(layer)
        self.layers = nn.CellList(layers)

    def construct(self, x):
        out = self.exp(x)
        for layer in self.layers:
            layer_out = layer(out)
            out = layer_out
        out = self.mean(out, -1)
        return out
def test_loss_scale():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=8)
    predict = Tensor(np.ones([64, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64,]), dtype=ms.int32)
    dataset = DatasetLenet(predict, label)
    net = Net()
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    net = TrainOneStepWithLossScaleCell(net, opt, update_cell)
    model = Model(network=net)
    model.train(2, dataset, dataset_sink_mode=False)
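(Illustration, not part of the diff.) In this test the loss scale is read from a Parameter and cast inside construct, so the sens handed to the backward pass is a CNode, which appears to be the case the new SplitSens branch and InferSensRedistribution exercise. For reference, a minimal pure-Python sketch of the usual dynamic loss-scaling rule that DynamicLossScaleUpdateCell(65536, 2, 1000) is configured with; this is an assumption about the general technique, not a transcription of the MindSpore cell:

# Illustration only: typical dynamic loss-scale update rule (hypothetical helper).
def update_loss_scale(scale, overflow, good_steps, factor=2, window=1000, min_scale=1.0):
    if overflow:
        return max(scale / factor, min_scale), 0   # shrink the scale, reset the window
    good_steps += 1
    if good_steps >= window:
        return scale * factor, 0                   # grow after a full overflow-free window
    return scale, good_steps

scale, good = 65536.0, 0
scale, good = update_loss_scale(scale, overflow=True, good_steps=good)   # -> (32768.0, 0)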