| @@ -1269,15 +1269,17 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode | |||
| } else { | |||
| AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; | |||
| for (auto ¶m_pair : param_sub_set) { | |||
| CNodePtr graph_cnode = param_pair.first->cast<CNodePtr>(); | |||
| if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa<CNode>()) { | |||
| continue; | |||
| CNodePtr param_cnode = param_pair.first->cast<CNodePtr>(); | |||
| AnfNodePtr graph_value_node; | |||
| if (param_cnode->input(0)->isa<CNode>()) { | |||
| graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1); | |||
| } else { | |||
| graph_value_node = param_cnode->input(0); | |||
| } | |||
| CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>(); | |||
| if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) { | |||
| if (!IsValueNode<FuncGraph>(graph_value_node)) { | |||
| continue; | |||
| } | |||
| FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1)); | |||
| FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node); | |||
| auto parameters = graph_sub->parameters(); | |||
| if (IntToSize(param_pair.second - 1) >= parameters.size()) { | |||
| MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " | |||
| @@ -1864,7 +1866,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { | |||
| // return -> make_tuple | |||
| if (current_prim->name() == MAKE_TUPLE) { | |||
| MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; | |||
| MS_LOG(WARNING) << "The loss have make_tuple, it is not supported"; | |||
| return nullptr; | |||
| } | |||
| // return -> loss | |||
| @@ -2069,6 +2072,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no | |||
| auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| MS_LOG(DEBUG) << "Find the forward graph success"; | |||
| graph_set.insert(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto graph_used = manager->func_graphs_used_total(graph); | |||
| for (auto &sub_graph : graph_used) { | |||
| graph_set.insert(sub_graph); | |||
| } | |||
| } | |||
| } | |||
| return graph_set; | |||
| @@ -2423,7 +2432,7 @@ void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) { | |||
| void MarkForwardCNode(const FuncGraphPtr &root) { | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| auto all_nodes = root->nodes(); | |||
| std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes); | |||
| auto graph_set = FindForwardGraphByRootNodes(all_nodes); | |||
| if (graph_set.empty()) { | |||
| MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph"; | |||
| @@ -145,10 +145,10 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode); | |||
| Status ParallelInit(); | |||
| std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node); | |||
| std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root); | |||
| std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node); | |||
| bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name); | |||
| using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>; | |||
| @@ -0,0 +1,69 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.context as context | |||
| from mindspore.common.api import _executor | |||
| from mindspore import Tensor, Parameter | |||
| from mindspore.nn import Cell, TrainOneStepCell, Momentum | |||
| from mindspore.ops import operations as P | |||
| class TwoInputBprop(Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.op = P.Mul() | |||
| def construct(self, x, y): | |||
| return self.op(x, y) | |||
| def bprop(self, x, y, out, dout): | |||
| return x * 5, y * 8 | |||
| class ParallelFloorDivBpropNet(Cell): | |||
| def __init__(self, mul_size, test_size, strategy=None, strategy2=None): | |||
| super().__init__() | |||
| mul_np = np.full(mul_size, 0.5, dtype=np.float32) | |||
| floordiv_np = np.full(test_size, 0.1, dtype=np.float32) | |||
| self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight") | |||
| self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight") | |||
| self.mul = TwoInputBprop() | |||
| self.floor_div = P.FloorDiv() | |||
| if strategy is not None: | |||
| self.mul.op.shard(strategy2) | |||
| self.floor_div.shard(strategy) | |||
| def construct(self, inputs, label): | |||
| x = self.mul(inputs, self.mul_weight) | |||
| x = self.floor_div(x, self.floordiv_weight) | |||
| return x | |||
| inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32) | |||
| label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32) | |||
| def compile_net(net): | |||
| optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| train_net = TrainOneStepCell(net, optimizer) | |||
| train_net.set_auto_parallel() | |||
| _executor.compile(train_net, inputs_, label_) | |||
| context.reset_auto_parallel_context() | |||
| def test_net(): | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0) | |||
| strategy = ((4, 1), (4, 1)) | |||
| net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy) | |||
| compile_net(net) | |||