@@ -1269,15 +1269,17 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   } else {
     AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
     for (auto &param_pair : param_sub_set) {
-      CNodePtr graph_cnode = param_pair.first->cast<CNodePtr>();
-      if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa<CNode>()) {
-        continue;
+      CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
+      AnfNodePtr graph_value_node;
+      if (param_cnode->input(0)->isa<CNode>()) {
+        graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
+      } else {
+        graph_value_node = param_cnode->input(0);
       }
-      CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
-      if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
+      if (!IsValueNode<FuncGraph>(graph_value_node)) {
         continue;
       }
-      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
+      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
       auto parameters = graph_sub->parameters();
       if (IntToSize(param_pair.second - 1) >= parameters.size()) {
         MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
@@ -1864,7 +1866,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
+    MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
+    return nullptr;
   }
   // return -> loss
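The case downgraded here from a hard error to a warning is a network whose return value is a tuple, which compiles to return -> make_tuple. A minimal sketch (hypothetical, for illustration only) of a forward net that previously made FindLossCNode throw and now only produces a warning and a nullptr loss node:

from mindspore.nn import Cell
from mindspore.ops import operations as P

class MultiOutputNet(Cell):
    def __init__(self):
        super().__init__()
        self.mul = P.Mul()
        self.add = P.TensorAdd()

    def construct(self, x, y):
        # Two outputs -> the graph returns make_tuple(mul, add), the exact
        # pattern the MAKE_TUPLE branch above detects.
        return self.mul(x, y), self.add(x, y)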
@@ -2069,6 +2072,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
       auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
       MS_LOG(DEBUG) << "Find the forward graph success";
       graph_set.insert(graph);
+      auto manager = graph->manager();
+      MS_EXCEPTION_IF_NULL(manager);
+      auto graph_used = manager->func_graphs_used_total(graph);
+      for (auto &sub_graph : graph_used) {
+        graph_set.insert(sub_graph);
+      }
     }
   }
   return graph_set;
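The rationale for the added lines: when the forward network nests Cells (as in the new test below, where ParallelFloorDivBpropNet wraps TwoInputBprop), ops live in sub-graphs that the top forward graph merely calls, so those graphs must be collected too. A rough Python sketch of the intended set construction, assuming a used_total callable that mirrors FuncGraphManager::func_graphs_used_total:

def forward_graphs_with_subgraphs(root_graphs, used_total):
    # used_total(graph) returns every graph transitively used by `graph`,
    # mirroring manager->func_graphs_used_total(graph) in the C++ above.
    graph_set = set()
    for graph in root_graphs:
        graph_set.add(graph)
        graph_set.update(used_total(graph))
    return graph_set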
@@ -2423,7 +2432,7 @@ void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
 void MarkForwardCNode(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
   auto all_nodes = root->nodes();
-  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  auto graph_set = FindForwardGraphByRootNodes(all_nodes);
   if (graph_set.empty()) {
     MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
@@ -145,10 +145,10 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
 Status ParallelInit();
-std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
+std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);

 using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
@@ -0,0 +1,69 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+import mindspore as ms
+import mindspore.context as context
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class TwoInputBprop(Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.Mul()
+
+    def construct(self, x, y):
+        return self.op(x, y)
+
+    def bprop(self, x, y, out, dout):
+        return x * 5, y * 8
+
+
+class ParallelFloorDivBpropNet(Cell):
+    def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
+        super().__init__()
+        mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+        floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
+        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+        self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
+        self.mul = TwoInputBprop()
+        self.floor_div = P.FloorDiv()
+        if strategy is not None:
+            self.mul.op.shard(strategy2)
+            self.floor_div.shard(strategy)
+
+    def construct(self, inputs, label):
+        x = self.mul(inputs, self.mul_weight)
+        x = self.floor_div(x, self.floordiv_weight)
+        return x
+
+
+inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, inputs_, label_)
+    context.reset_auto_parallel_context()
+
+
+def test_net():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy = ((4, 1), (4, 1))
+    net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
+    compile_net(net)
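A usage note, assuming the usual MindSpore unit-test layout: the case needs no device, since _executor.compile only builds and shards the graph, so it can be run on its own, e.g. with pytest -sv <path-to-this-test-file>::test_net. The custom bprop on TwoInputBprop is the point of the test: its backward rule lives in a sub-graph of the forward network, which is the kind of graph the FindForwardGraphByRootNodes change above is meant to sweep into the marked set.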