| @@ -649,108 +649,13 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| MS_LOG(INFO) << "Constructing edges for cost graph ends."; | |||
| } | |||
| std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> refkeys; | |||
| if (cnode->isa<CNode>()) { | |||
| auto cnode_ptr = cnode->cast<CNodePtr>(); | |||
| auto inputs = cnode_ptr->inputs(); | |||
| for (auto &one_input : inputs) { | |||
| if (IsValueNode<RefKey>(one_input)) { | |||
| refkeys.push_back(one_input); | |||
| } | |||
| } | |||
| if (refkeys.size() >= 1) { | |||
| return std::make_pair(cnode, refkeys); | |||
| } | |||
| } | |||
| return {nullptr, refkeys}; | |||
| } | |||
| void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // Step 3 | |||
| for (auto &node : all_nodes) { | |||
| auto cnode_with_refkeys = CNodeWithRefKeys(node); | |||
| if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) { | |||
| continue; | |||
| } | |||
| std::string parameter_name; | |||
| AnfNodePtr target_parameter = nullptr; | |||
| AnfNodeIndexSet target_set; | |||
| if (cnode_with_refkeys.first != nullptr) { | |||
| // Dealing with the RefKey case | |||
| auto refkeys = cnode_with_refkeys.second; | |||
| auto cnode = cnode_with_refkeys.first; | |||
| auto cnode_ptr = cnode->cast<CNodePtr>(); | |||
| if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) { | |||
| continue; | |||
| } | |||
| if (!IsAutoParallelCareNode(cnode_ptr)) { | |||
| continue; | |||
| } | |||
| if (refkeys.size() > 1) { | |||
| MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys."; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()); | |||
| auto cnode_func_graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); | |||
| // Find the RefKey being used | |||
| auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]]; | |||
| for (auto &candidate : candidate_set_by_refkey) { | |||
| auto candidate_node = candidate.first; | |||
| auto c = candidate_node->cast<CNodePtr>(); | |||
| if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) { | |||
| continue; | |||
| } | |||
| if (!IsAutoParallelCareNode(c)) { | |||
| continue; | |||
| } | |||
| target_set.add(candidate); | |||
| } | |||
| // Find the corresponding Parameter being used | |||
| std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph); | |||
| if (parameters.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; | |||
| } | |||
| parameter_name = parameters[0]->cast<ParameterPtr>()->name(); | |||
| target_parameter = parameters[0]; | |||
| auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]]; | |||
| for (auto &candidate : candidate_set_by_para) { | |||
| auto candidate_node = candidate.first; | |||
| auto c = candidate_node->cast<CNodePtr>(); | |||
| if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) { | |||
| continue; | |||
| } | |||
| if (!IsAutoParallelCareNode(c)) { | |||
| continue; | |||
| } | |||
| (void)target_set.insert(candidate); | |||
| } | |||
| } else if (node->isa<Parameter>()) { | |||
| // Dealing with the Parameter case | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()->manager()); | |||
| auto candidate_set = node->func_graph()->manager()->node_users()[node]; | |||
| for (auto &candidate : candidate_set) { | |||
| auto candidate_node = candidate.first; | |||
| auto c = candidate_node->cast<CNodePtr>(); | |||
| if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) { | |||
| continue; | |||
| } | |||
| if (!IsAutoParallelCareNode(c)) { | |||
| continue; | |||
| } | |||
| (void)target_set.insert(candidate); | |||
| } | |||
| // In this case, node is a Parameter | |||
| parameter_name = node->cast<ParameterPtr>()->name(); | |||
| target_parameter = node; | |||
| } | |||
| ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsAutoParallelCareNode); | |||
| auto parameter_name = parameter_users_info.first; | |||
| auto target_parameter = parameter_users_info.second.first; | |||
| auto target_set = parameter_users_info.second.second; | |||
| if (target_set.size() <= 1) { | |||
| continue; | |||
| } | |||
| @@ -2499,6 +2499,149 @@ void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| } | |||
| RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::vector<AnfNodePtr> refkeys; | |||
| if (cnode->isa<CNode>()) { | |||
| auto cnode_ptr = cnode->cast<CNodePtr>(); | |||
| auto inputs = cnode_ptr->inputs(); | |||
| for (auto &one_input : inputs) { | |||
| if (IsValueNode<RefKey>(one_input)) { | |||
| refkeys.push_back(one_input); | |||
| } | |||
| } | |||
| if (refkeys.size() >= 1) { | |||
| return std::make_pair(cnode, refkeys); | |||
| } | |||
| } | |||
| return {nullptr, refkeys}; | |||
| } | |||
| ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) { | |||
| // In this case, node is a Parameter | |||
| ParameterUsersInfo parameter_user_info; | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()); | |||
| MS_EXCEPTION_IF_NULL(node->func_graph()->manager()); | |||
| auto candidate_set = node->func_graph()->manager()->node_users()[node]; | |||
| for (auto &candidate : candidate_set) { | |||
| auto candidate_node = candidate.first; | |||
| auto c = candidate_node->cast<CNodePtr>(); | |||
| if (c == nullptr || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) { | |||
| continue; | |||
| } | |||
| (void)parameter_user_info.second.second.insert(candidate); | |||
| } | |||
| parameter_user_info.first = node->cast<ParameterPtr>()->name(); | |||
| parameter_user_info.second.first = node; | |||
| return parameter_user_info; | |||
| } | |||
| ParameterUsersInfo FindRefKeyNodeUsers(const RefKeyPair &ref_key_pair, bool (*IsCareNode)(const CNodePtr &)) { | |||
| // Dealing with the RefKey case | |||
| ParameterUsersInfo parameter_user_info; | |||
| auto refkeys = ref_key_pair.second; | |||
| auto cnode = ref_key_pair.first; | |||
| auto cnode_ptr = cnode->cast<CNodePtr>(); | |||
| if ((cnode_ptr == nullptr) || !IsValueNode<Primitive>(cnode_ptr->input(0)) || !IsCareNode(cnode_ptr)) { | |||
| return parameter_user_info; | |||
| } | |||
| if (refkeys.size() > 1) { | |||
| MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKeys"; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()); | |||
| auto cnode_func_graph = cnode->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); | |||
| // Find the RefKey being used | |||
| auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]]; | |||
| for (auto &candidate : candidate_set_by_refkey) { | |||
| auto candidate_node = candidate.first; | |||
| auto c = candidate_node->cast<CNodePtr>(); | |||
| if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) { | |||
| continue; | |||
| } | |||
| parameter_user_info.second.second.add(candidate); | |||
| } | |||
| // Find the corresponding Parameter being used | |||
| std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph); | |||
| if (parameters.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; | |||
| } | |||
| parameter_user_info.first = parameters[0]->cast<ParameterPtr>()->name(); | |||
| parameter_user_info.second.first = parameters[0]; | |||
| auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]]; | |||
| for (auto &candidate : candidate_set_by_para) { | |||
| auto candidate_node = candidate.first; | |||
| auto c = candidate_node->cast<CNodePtr>(); | |||
| if ((c == nullptr) || !IsValueNode<Primitive>(c->input(0)) || !IsCareNode(c)) { | |||
| continue; | |||
| } | |||
| (void)parameter_user_info.second.second.insert(candidate); | |||
| } | |||
| return parameter_user_info; | |||
| } | |||
| ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &)) { | |||
| ParameterUsersInfo parameter_users_info; | |||
| auto cnode_with_refkeys = CNodeWithRefKeys(node); | |||
| if (cnode_with_refkeys.first != nullptr) { | |||
| // the node is a ref key node | |||
| return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode); | |||
| } else if (node->isa<Parameter>()) { | |||
| // the node is a parameter node | |||
| return FindParameterNodeUsers(node, IsCareNode); | |||
| } | |||
| return parameter_users_info; | |||
| } | |||
| Shape ParameterSliceShape(const std::pair<AnfNodePtr, int> ¶m_info) { | |||
| auto user_cnode = param_info.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(user_cnode); | |||
| auto user_input_index = param_info.second; | |||
| OperatorInfoPtr op_info = user_cnode->user_data<OperatorInfo>(); | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| size_t input_tensor_info_size = op_info->inputs_tensor_info().size(); | |||
| if (SizeToInt(input_tensor_info_size) <= user_input_index - 1) { | |||
| MS_LOG(EXCEPTION) << op_info->name() << ": the size of inputs tensor info is " << input_tensor_info_size | |||
| << ", but the index is " << user_input_index - 1; | |||
| } | |||
| TensorInfo tensor_info = op_info->inputs_tensor_info()[user_input_index - 1]; | |||
| MS_LOG(DEBUG) << "The op name is " << op_info->name() << ", the parameter index is " << user_input_index - 1 | |||
| << ", the slice shape is " << ShapeToString(tensor_info.slice_shape()) << ", the origin shape is " | |||
| << ShapeToString(tensor_info.shape()); | |||
| return tensor_info.slice_shape(); | |||
| } | |||
// Verifies that every parameter (or RefKey-referenced parameter) used by more
// than one parallel-care operator is given the same slice shape by each user;
// inconsistent split strategies abort compilation with an exception.
void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  for (auto &node : all_nodes) {
    ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode);
    auto users_set = parameter_users_info.second.second;
    // A parameter with zero or one care-user cannot conflict with itself.
    if (users_set.size() <= 1) {
      continue;
    }
    auto parameter_name = parameter_users_info.first;
    MS_LOG(INFO) << "The parameter: " << parameter_name << " has " << users_set.size() << " users";
    // Use an arbitrary first user as the reference slice shape.
    // NOTE(review): this relies on pop() removing the returned element so the
    // loop below only visits the remaining users — confirm AnfNodeIndexSet::pop semantics.
    auto first_user = users_set.pop();
    Shape first_user_slice_shape = ParameterSliceShape(first_user);
    for (auto &user : users_set) {
      Shape user_slice_shape = ParameterSliceShape(user);
      if (first_user_slice_shape != user_slice_shape) {
        MS_LOG(EXCEPTION) << "The parameter: " << parameter_name
                          << " has multiple users, but the split strategies are different";
      }
    }
  }
}
| bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| @@ -2556,6 +2699,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| HandleForwardMakeTuple(all_nodes); | |||
| // if the input or parameter has multiple users, check whether its split strategies are consistent. | |||
| CheckParameterSplit(all_nodes); | |||
| // save strategy as checkpoint for multi-train | |||
| if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { | |||
| CheckpointStrategy(all_nodes); | |||
| @@ -150,6 +150,13 @@ std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node); | |||
| std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root); | |||
| bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name); | |||
// A CNode together with all of its RefKey value-node inputs.
using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
// {parameter name, {parameter node, set of (user CNode, input index) pairs}}.
using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;
RefKeyPair CNodeWithRefKeys(const AnfNodePtr &cnode);
ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -245,51 +245,3 @@ def test_reshape_auto_5(): | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| _executor.compile(net, x, y) | |||
def test_reshape_auto_6():
    """Compile (auto_parallel) a net whose weight is consumed both in its
    original 3-d shape and reshaped to 2-d across two branches."""
    class NetWithLoss6(nn.Cell):
        # Wraps the network with a virtual loss so the whole graph is trainable.
        def __init__(self, network):
            super(NetWithLoss6, self).__init__()
            self.loss = VirtualLoss()
            self.network = network
        def construct(self, x, y):
            predict = self.network(x, y)
            return self.loss(predict)
    class GradWrap6(nn.Cell):
        # Produces gradients of the wrapped network w.r.t. all inputs.
        def __init__(self, network):
            super(GradWrap6, self).__init__()
            self.network = network
        def construct(self, x, y):
            return grad_all(self.network)(x, y)
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
        def construct(self, x, y):
            # wide_w is used twice: broadcast-added in 3-d, and reshaped to 2-d.
            out1 = x + self.wide_w
            w = self.reshape(self.wide_w, (4, 1024))
            out1 = self.reduce_mean(out1, 1)
            out1 = out1 - w
            out2 = self.mul(y, w)
            out = out1 + out2
            return out
    size = 8
    context.set_auto_parallel_context(device_num=size, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
    net = GradWrap6(NetWithLoss6(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    _executor.compile(net, x, y)
| @@ -0,0 +1,94 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| from mindspore import context, Tensor, Parameter | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import Cell, TrainOneStepCell, Momentum | |||
| from mindspore.ops import operations as P | |||
class Net(Cell):
    """Two chained Mul ops that both consume the same weight parameter."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.mul2 = P.Mul().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # mul_weight is shared: it feeds both mul and mul2.
        return self.mul2(self.mul(x, self.mul_weight), self.mul_weight)
class Net2(Cell):
    """Two Mul ops that both consume the same *input* tensor x."""

    def __init__(self, mul_weight, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().set_strategy(strategy1)
        self.mul2 = P.Mul().set_strategy(strategy2)
        self.mul_weight = Parameter(mul_weight, "w1")

    def construct(self, x, b):
        # x is shared: it feeds both mul and mul2 directly.
        return self.mul2(x, self.mul(x, self.mul_weight))
# Shared module-level fixtures: input, weight, and second operand for compile_net.
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile_net(net):
    """Wrap `net` in a TrainOneStepCell, compile it with the shared fixtures,
    then reset the auto-parallel context for the next test."""
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, opt)
    train_net.set_auto_parallel()
    _executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()
def test_parameter_same_split():
    """A weight shared by two ops with identical strategies compiles cleanly."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy = ((16, 1, 1), (16, 1, 1))
    compile_net(Net(_w, strategy, strategy))
def test_parameter_different_split():
    """A weight shared by two ops with mismatching strategies must fail to compile."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    net = Net(_w, ((16, 1, 1), (16, 1, 1)), ((4, 4, 1), (4, 4, 1)))
    with pytest.raises(RuntimeError):
        compile_net(net)
def test_input_same_split():
    """A network *input* shared by two ops with identical strategies compiles cleanly."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((16, 1, 1), (16, 1, 1))
    strategy2 = ((16, 1, 1), (16, 1, 1))
    # Bug fix: this "input" test previously built Net (shared *weight*),
    # exactly duplicating test_parameter_same_split. Use Net2 so the shared
    # operand is the network input x, mirroring test_input_different_split.
    net = Net2(_w, strategy1, strategy2)
    compile_net(net)
def test_input_different_split():
    """A network input shared by two ops with mismatching strategies must fail."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    net = Net2(_w, ((16, 1, 1), (16, 1, 1)), ((4, 4, 1), (4, 4, 1)))
    with pytest.raises(RuntimeError):
        compile_net(net)