| @@ -54,31 +54,6 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||
| return dict; | |||
| } | |||
| py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| py::dict dict; | |||
| auto ret = graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(ret); | |||
| auto nodes = DeepScopedGraphSearch(ret); | |||
| for (auto node : nodes) { | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto distributed_operation_info = cnode->user_data<OperatorInfo>(); | |||
| if (distributed_operation_info != nullptr) { | |||
| auto strategyPtr = distributed_operation_info->strategy(); | |||
| if (strategyPtr != nullptr) { | |||
| auto strategy = strategyPtr->GetInputDim(); | |||
| auto name = cnode->fullname_with_scope(); | |||
| dict[py::str(name)] = strategy; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return dict; | |||
| } | |||
| py::dict GetAllreduceFusion(const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| py::dict dict; | |||
| @@ -25,7 +25,6 @@ namespace py = pybind11; | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| py::dict GetParameterLayout(const FuncGraphPtr &graph); | |||
| py::dict GetCNodeStrategy(const FuncGraphPtr &graph); | |||
| py::dict GetAllreduceFusion(const FuncGraphPtr &graph); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -1524,7 +1524,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // Get global rank after the checkpoint? | |||
| int32_t global_rank = ParallelContext::GetInstance()->global_rank(); | |||
| std::vector<int32_t> stages = ParallelContext::GetInstance()->stage(); | |||
| for (auto &node : all_nodes) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| @@ -2478,17 +2477,32 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG | |||
| InsertNode(op, node, 2, pre_node, root, "shape"); | |||
| } | |||
| void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) { | |||
| void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // If root graph has reshape op. Find the corresponding parameter. | |||
| // Reshape's shape is the shape of the parameter. | |||
| auto executor = pipeline::ExecutorPy::GetInstance(); | |||
| for (auto &node : all_nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!IsValueNode<Primitive>(cnode->input(0)) || cnode->in_forward_flag()) { | |||
| if (!IsValueNode<Primitive>(cnode->input(0)) || cnode == nullptr) { | |||
| continue; | |||
| } | |||
| if (cnode->in_forward_flag()) { | |||
| // Save strategy in executor | |||
| OperatorInfoPtr op_info = cnode->user_data<OperatorInfo>(); | |||
| if (op_info) { | |||
| auto stra_ptr = op_info->strategy(); | |||
| if (stra_ptr) { | |||
| auto strategy = stra_ptr->GetInputDim(); | |||
| // fullname with scope should be found in step parallel end ir | |||
| executor->SetCNodeStrategy(cnode->fullname_with_scope(), strategy); | |||
| } | |||
| } | |||
| continue; | |||
| } | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->name() != RESHAPE) { | |||
| continue; | |||
| @@ -2844,7 +2858,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) | |||
| ReshapeInit(all_nodes); | |||
| } | |||
| HandleRootReshape(all_nodes); | |||
| HandleRootReshapeAndSaveStrategy(all_nodes); | |||
| HandleForwardMakeTupleAndMakeList(all_nodes); | |||
| @@ -29,6 +29,7 @@ | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| #include "pipeline/jit/pipeline.h" | |||
| using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>; | |||
| @@ -243,9 +243,12 @@ py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { | |||
| py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { | |||
| MS_LOG(DEBUG) << "GetCNodeStrategy!"; | |||
| std::string layout_graph = phase + kStepParallelGraph; | |||
| auto graph = GetFuncGraph(layout_graph); | |||
| return mindspore::parallel::GetCNodeStrategy(graph); | |||
| return stra_dict_[phase]; | |||
| } | |||
| void ExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy) { | |||
| MS_LOG(DEBUG) << "SetCNodeStrategy!"; | |||
| stra_dict_[phase_][py::str(name)] = strategy; | |||
| } | |||
| py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { | |||
| @@ -449,6 +452,7 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons | |||
| #endif | |||
| ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>(); | |||
| auto phase_s = py::cast<std::string>(phase); | |||
| phase_ = phase_s; | |||
| MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!"; | |||
| ResourcePtr resource = std::make_shared<Resource>(obj); | |||
| @@ -92,6 +92,7 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> { | |||
| void RunInitGraph(const py::dict &init_params, const std::string &phase); | |||
| py::dict GetParameterLayout(const std::string &phase); | |||
| py::dict GetCNodeStrategy(const std::string &phase); | |||
| void SetCNodeStrategy(const std::string &name, const parallel::Strategys &strategy); | |||
| py::dict GetAllreduceFusion(const std::string &phase); | |||
| void DelNetRes(const std::string &id); | |||
| void ReleaseResource(const py::object &phase); | |||
| @@ -114,6 +115,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> { | |||
| static std::shared_ptr<ExecutorPy> executor_; | |||
| static std::mutex instance_lock_; | |||
| static bool debugger_terminate_; | |||
| std::map<std::string, py::dict> stra_dict_; | |||
| std::string phase_ = ""; | |||
| }; | |||
| using ExecutorPyPtr = std::shared_ptr<ExecutorPy>; | |||
| @@ -12,6 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import re | |||
| import numpy as np | |||
| import mindspore as ms | |||
| @@ -96,15 +97,15 @@ def test_all_to_all(): | |||
| _reset_op_id() | |||
| strategys = all_to_all_common(strategy1) | |||
| print(strategys) | |||
| expect_dict = {'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits' | |||
| '/SoftmaxCrossEntropyWithLogits-op3': [[8, 1], [8, 1]], | |||
| 'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_loss_fn-SoftmaxCrossEntropyWithLogits/' | |||
| 'OneHot-op4': [[8, 1], [], []], | |||
| 'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/Transpose-op1': [ | |||
| [8, 1]], | |||
| 'Default/network-_VirtualDatasetCell/_backbone-WithLossCell/_backbone-AllToAllNet/MatMul-op0': [ | |||
| [1, 1], [1, 8]]} | |||
| assert strategys == expect_dict | |||
| for (k, v) in strategys.items(): | |||
| if re.search('SoftmaxCrossEntropyWithLogits-op', k) is not None: | |||
| assert v == [[8, 1], [8, 1]] | |||
| elif re.search('OneHot-op', k) is not None: | |||
| assert v == [[8, 1], [], []] | |||
| elif re.search('Transpose-op', k) is not None: | |||
| assert v == [[8, 1]] | |||
| elif re.search('MatMul-op', k) is not None: | |||
| assert v == [[1, 1], [1, 8]] | |||
| context.set_context(save_graphs=False) | |||
| @@ -77,8 +77,8 @@ def test_auto_parallel_arithmetic(): | |||
| b = Tensor(np.ones([64, 128]), dtype=ms.float32) | |||
| compile_net(net, x, y, b, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]], | |||
| 'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]} | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op1': [[2, 4], [2, 4]], | |||
| 'Default/network-Net/MatMul-op0': [[2, 1], [1, 4]]} | |||
| assert strategies == expected_strategies | |||
| @@ -104,8 +104,8 @@ def test_auto_parallel_arithmetic_broadcast_both(): | |||
| b = Tensor(np.ones([1, 64]), dtype=ms.float32) | |||
| compile_net(net, x, y, b, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]} | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op1': [[8, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]]} | |||
| assert strategies == expected_strategies | |||
| @@ -131,8 +131,8 @@ def test_auto_parallel_arithmetic_broadcast_right(): | |||
| b = Tensor(np.ones([32]), dtype=ms.float32) | |||
| compile_net(net, x, y, b, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]], | |||
| 'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]} | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op1': [[4, 2], [2]], | |||
| 'Default/network-Net/MatMul-op0': [[4, 1], [1, 2]]} | |||
| assert strategies == expected_strategies | |||
| @@ -158,6 +158,6 @@ def test_auto_parallel_arithmetic_broadcast_left(): | |||
| b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) | |||
| compile_net(net, x, y, b, phase="train") | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]], | |||
| 'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]} | |||
| expected_strategies = {'Default/network-Net/FloorDiv-op1': [[4, 2], [1, 4, 2]], | |||
| 'Default/network-Net/MatMul-op0': [[4, 1], [1, 2]]} | |||
| assert strategies == expected_strategies | |||
| @@ -86,6 +86,6 @@ def test_double_star_graph(): | |||
| expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]], | |||
| 'Default/network-Net/Cast-op1': [[1, 8]], | |||
| 'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]], | |||
| 'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]} | |||
| 'Default/network-Net/MatMul-op2': [[1, 1], [1, 8]], | |||
| 'Default/network-Net/MatMul-op4': [[1, 8], [8, 1]]} | |||
| assert strategies == expected_strategies | |||
| @@ -12,6 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import re | |||
| import numpy as np | |||
| import mindspore as ms | |||
| @@ -114,12 +115,16 @@ def test_double_subgraphs(): | |||
| reset_op_id() | |||
| _executor.compile(net, x, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]} | |||
| assert strategies == expected_strategies | |||
| for (k, v) in strategies.items(): | |||
| if re.search('ReduceMean-op', k) is not None: | |||
| assert v == [[8, 1, 1, 1]] | |||
| elif re.search('ReLU-op', k) is not None: | |||
| assert v == [[8, 1, 1, 1]] | |||
| elif re.search('Mul-op', k) is not None: | |||
| assert v == [[8, 1, 1, 1], [8, 1, 1, 1]] | |||
| elif re.search('ReduceSum-op', k) is not None: | |||
| assert v == [[8, 1, 1, 1]] | |||
| class DatasetLenet(): | |||
| def __init__(self, predict, label, length=3): | |||
| @@ -160,10 +165,14 @@ def test_double_subgraphs_train(): | |||
| model = Model(net) | |||
| model.train(1, ds_train, dataset_sink_mode=False) | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]} | |||
| assert strategies == expected_strategies | |||
| for (k, v) in strategies.items(): | |||
| if re.search('ReduceMean-op', k) is not None: | |||
| assert v == [[1, 1, 1, 1]] | |||
| elif re.search('ReLU-op', k) is not None: | |||
| assert v == [[1, 1, 1, 1]] | |||
| elif re.search('Mul-op', k) is not None: | |||
| assert v == [[1, 1, 1, 1], [1, 1, 1, 1]] | |||
| elif re.search('Cast-op', k) is not None: | |||
| assert v == [[1, 1, 1, 1]] | |||
| elif re.search('ReduceSum-op', k) is not None: | |||
| assert v == [[1, 1, 1, 1]] | |||
| @@ -78,8 +78,8 @@ def test_two_matmul_transpose(): | |||
| _executor.compile(net, x, y, b, phase='train') | |||
| strategies = _executor._get_shard_strategy(net) | |||
| expected_strategies = {'Default/network-Net/Transpose-op0': [[1, 16]], | |||
| 'Default/network-Net/Transpose-op1': [[16, 1]], | |||
| 'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]} | |||
| expected_strategies = {'Default/network-Net/Transpose-op3': [[1, 16]], | |||
| 'Default/network-Net/Transpose-op2': [[16, 1]], | |||
| 'Default/network-Net/MatMul-op0': [[16, 1], [1, 1]], | |||
| 'Default/network-Net/MatMul-op1': [[16, 1], [1, 1]]} | |||
| assert strategies == expected_strategies | |||