Merge pull request !7495 from Xiaoda/27-star-elimination-overloadtags/v1.1.0
| @@ -201,7 +201,7 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| right_edge->set_selected_cost(decision->right_edge_cost_); | |||
| // 'left_node' recovers the strategy. | |||
| left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); | |||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| // 'right_node' recovers the strategy. | |||
| MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination."; | |||
| right_node->SetSelectedStrategyAndCost(decision->right_node_strategy_, decision->right_node_cost_); | |||
| @@ -225,10 +225,16 @@ Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { | |||
| MS_EXCEPTION_IF_NULL(succ_nodes[0]); | |||
| MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); | |||
| MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); | |||
| // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. | |||
| // Star is eliminated into 'succ_nodes[0]' | |||
| succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); | |||
| for (size_t k = 1; k < succ_nodes.size(); ++k) { | |||
| succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| // 'succ_nodes[k]' is overwritten strategy and cost. | |||
| succ_nodes[k]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[k], decision->succ_ops_cost_list_[k]); | |||
| } else { | |||
| // In this case, 'succ_nodes[k]' is NOT overwritten strategy and cost, however, it checks the strategy. | |||
| succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Recover starElimination succeeded."; | |||
| } else { | |||
| @@ -20,9 +20,9 @@ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ir/value.h" | |||
| #include "frontend/parallel/auto_parallel/edge_costmodel.h" | |||
| #include "frontend/parallel/auto_parallel/graph_costmodel.h" | |||
| #include "ir/value.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -22,11 +22,11 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "utils/ms_utils.h" | |||
| #include "frontend/parallel/auto_parallel/costmodel.h" | |||
| #include "frontend/parallel/ops_info/operator_info.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_info.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_layout.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -40,7 +40,7 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | |||
| bool TRIANGLE_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; | |||
| bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| @@ -155,12 +155,12 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_LOG(INFO) << "multi_subgraphs: false."; | |||
| } | |||
| auto overwrite = CostModelContext::GetInstance()->triangle_strategy_overwrite(); | |||
| TRIANGLE_STRATEGY_OVERWRITE = overwrite; | |||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||
| MS_LOG(INFO) << "triangle_strategy_overwrite: true."; | |||
| auto overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite(); | |||
| TRIANGLE_STAR_STRATEGY_OVERWRITE = overwrite; | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "triangle_strategy_overwrite: false."; | |||
| MS_LOG(INFO) << "triangle_star_strategy_overwrite: false."; | |||
| } | |||
| // RUN_PHASE | |||
| @@ -1303,7 +1303,7 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, | |||
| elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + | |||
| left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; | |||
| if (TRIANGLE_STRATEGY_OVERWRITE) { | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| new_computation += right_op_cost->computation_cost_; | |||
| new_memory += right_op_cost->memory_with_reuse_; | |||
| new_commu_cost += right_op_cost->communication_cost_; | |||
| @@ -1399,7 +1399,9 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, | |||
| } | |||
| if (!valid) { | |||
| MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed."; | |||
| MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() | |||
| << " failed. It may be caused by " | |||
| "configuring inconsistent strategies for operators."; | |||
| } | |||
| elimi_op->SetNotAlive(); | |||
| MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; | |||
| @@ -1440,6 +1442,13 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_n | |||
| commu_cost += succ_edges_costs[i]->communication_cost_; | |||
| commu_forward += succ_edges_costs[i]->communication_forward_; | |||
| commu_without += succ_edges_costs[i]->communication_without_parameter_; | |||
| if (TRIANGLE_STAR_STRATEGY_OVERWRITE) { | |||
| computation_cost += succ_nodes_costs[i]->computation_cost_; | |||
| memory_cost += succ_nodes_costs[i]->memory_with_reuse_; | |||
| commu_cost += succ_nodes_costs[i]->communication_cost_; | |||
| commu_forward += succ_nodes_costs[i]->communication_forward_; | |||
| commu_without += succ_nodes_costs[i]->communication_without_parameter_; | |||
| } | |||
| } | |||
| } | |||
| @@ -1544,7 +1553,9 @@ std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfo | |||
| } | |||
| if (!valid) { | |||
| MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; | |||
| MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() | |||
| << " failed. It may be caused by " | |||
| "configuring inconsistent strategies for operators."; | |||
| } | |||
| merged_op->SetNotAlive(); | |||
| @@ -22,11 +22,11 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "utils/ms_utils.h" | |||
| #include "frontend/parallel/auto_parallel/edge_costmodel.h" | |||
| #include "frontend/parallel/costmodel_context.h" | |||
| #include "frontend/parallel/ops_info/operator_info.h" | |||
| #include "frontend/parallel/ops_info/tmp_identity_info.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -46,7 +46,7 @@ extern bool FULLY_USE_DEVICES; | |||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | |||
| extern bool MULTI_SUBGRAPHS; | |||
| extern int32_t RUN_PHASE; | |||
| extern bool TRIANGLE_STRATEGY_OVERWRITE; | |||
| extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| class CostGraph { | |||
| // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have | |||
| @@ -16,8 +16,8 @@ | |||
| #include "frontend/parallel/auto_parallel/operator_costmodel.h" | |||
| #include <random> | |||
| #include <algorithm> | |||
| #include <random> | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| @@ -19,9 +19,9 @@ | |||
| #include <algorithm> | |||
| #include <cstdint> | |||
| #include <functional> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <map> | |||
| #include "frontend/parallel/device_manager.h" | |||
| @@ -18,18 +18,18 @@ | |||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_ | |||
| #include <cstdint> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "abstract/abstract_value.h" | |||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||
| #include "frontend/parallel/status.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/info.h" | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -64,7 +64,7 @@ void CostModelContext::ResetAlgoParameters() { | |||
| tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; | |||
| elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| triangle_strategy_overwrite_ = DEFAULT_TRIANGLE_STRATEGY_OVERWRITE; | |||
| triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| } | |||
| void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | |||
| @@ -134,7 +134,9 @@ void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | |||
| elementwise_stra_follow_ = elementwise_follow; | |||
| } | |||
| void CostModelContext::set_triangle_strategy_overwrite(bool overwrite) { triangle_strategy_overwrite_ = overwrite; } | |||
| void CostModelContext::set_triangle_star_strategy_overwrite(bool overwrite) { | |||
| triangle_star_strategy_overwrite_ = overwrite; | |||
| } | |||
| void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } | |||
| @@ -44,7 +44,7 @@ namespace parallel { | |||
| #define DEFAULT_RUN_PHASE 0 | |||
| #define TRAINING_PHASE 0 | |||
| #define INFERENCE_PHASE 1 | |||
| #define DEFAULT_TRIANGLE_STRATEGY_OVERWRITE true; | |||
| #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; | |||
| class CostModelContext { | |||
| public: | |||
| @@ -135,8 +135,8 @@ class CostModelContext { | |||
| void set_elementwise_stra_follow(bool); | |||
| bool elementwise_stra_follow() const { return elementwise_stra_follow_; } | |||
| void set_triangle_strategy_overwrite(bool); | |||
| bool triangle_strategy_overwrite() const { return triangle_strategy_overwrite_; } | |||
| void set_triangle_star_strategy_overwrite(bool); | |||
| bool triangle_star_strategy_overwrite() const { return triangle_star_strategy_overwrite_; } | |||
| void set_run_phase(int32_t); | |||
| int32_t run_phase() const { return run_phase_; } | |||
| @@ -172,9 +172,9 @@ class CostModelContext { | |||
| // MULTI_SUBGRAPHS | |||
| bool is_multi_subgraphs_; | |||
| // In the recovery phase of DP algorithm, when encountering triangle structure, | |||
| // In the recovery phase of DP algorithm, when encountering triangle structure and star structure, | |||
| // whether overwrite the right-node strategy | |||
| bool triangle_strategy_overwrite_; | |||
| bool triangle_star_strategy_overwrite_; | |||
| int32_t run_phase_; // 0: 'training', 1: 'inference' | |||
| @@ -25,13 +25,13 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "utils/ms_utils.h" | |||
| #include "frontend/parallel/device.h" | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/group_manager.h" | |||
| #include "frontend/parallel/status.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -17,8 +17,8 @@ | |||
| #include "frontend/parallel/group_manager.h" | |||
| #include <algorithm> | |||
| #include <vector> | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "backend/session/executor_manager.h" | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -26,10 +26,8 @@ | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <unordered_set> | |||
| #include "ir/anf.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/tensor.h" | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" | |||
| @@ -39,11 +37,14 @@ | |||
| #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" | |||
| #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "frontend/parallel/ops_info/tmp_identity_info.h" | |||
| #include "frontend/parallel/ops_info/reshape_info.h" | |||
| #include "frontend/parallel/graph_util/node_info.h" | |||
| #include "frontend/parallel/ops_info/reshape_info.h" | |||
| #include "frontend/parallel/ops_info/tmp_identity_info.h" | |||
| #include "frontend/parallel/step_parallel.h" | |||
| #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/tensor.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -21,9 +21,9 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ir/anf.h" | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "frontend/parallel/status.h" | |||
| #include "ir/anf.h" | |||
| #include "pipeline/jit/pipeline.h" | |||
| namespace mindspore { | |||
| @@ -27,8 +27,6 @@ | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "ir/tensor.h" | |||
| #include "ir/param_info.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/parallel/auto_parallel/graph_costmodel.h" | |||
| @@ -41,9 +39,11 @@ | |||
| #include "frontend/parallel/node_check.h" | |||
| #include "frontend/parallel/ops_info/matmul_info.h" | |||
| #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/tensor.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/symbolic.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| using mindspore::tensor::Tensor; | |||
| @@ -21,10 +21,10 @@ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <set> | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "frontend/parallel/strategy.h" | |||
| @@ -23,8 +23,8 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "frontend/parallel/status.h" | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/status.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -0,0 +1,134 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.parallel._utils import _reset_op_id as reset_op_id | |||
| from mindspore import context, Tensor, Parameter | |||
| from mindspore.parallel import set_algo_parameters | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| grad_all = C.GradOperation(get_all=True) | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x): | |||
| predict = self.network(x) | |||
| return self.loss(predict) | |||
| class GradWarp(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWarp, self).__init__() | |||
| self.network = network | |||
| def construct(self, x): | |||
| return grad_all(self.network)(x) | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy_dict=None): | |||
| super(Net, self).__init__() | |||
| self.mul1 = P.Mul() | |||
| self.mul2 = P.Mul() | |||
| self.mul3 = P.Mul() | |||
| self.mul4 = P.Mul() | |||
| self.relu1 = P.ReLU() | |||
| self.relu2 = P.ReLU() | |||
| self.ba1 = P.BiasAdd() | |||
| self.add = P.TensorAdd() | |||
| self.weight = Parameter(Tensor(np.ones([128, 1000]), dtype=ms.float32), name="weight") | |||
| self.bias = Parameter(Tensor(np.ones([1000]), dtype=ms.float32), name="bias") | |||
| if strategy_dict is not None: | |||
| self.mul1.shard(strategy_dict["mul1"]) | |||
| self.mul2.shard(strategy_dict["mul2"]) | |||
| self.relu1.shard(strategy_dict["relu1"]) | |||
| self.relu2.shard(strategy_dict["relu2"]) | |||
| self.ba1.shard(strategy_dict["bias_add"]) | |||
| self.add.shard(strategy_dict["add"]) | |||
| def construct(self, inputs): | |||
| x = self.mul1(inputs, self.weight) | |||
| y = self.relu1(x) | |||
| y = self.mul2(y, self.weight) | |||
| z = self.mul3(x, self.weight) | |||
| z = self.ba1(z, self.bias) | |||
| x = self.add(y, z) | |||
| x = self.mul4(x, self.weight) | |||
| x = self.relu2(x) | |||
| return x | |||
| def test_star_strategy_consistency1(): | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| set_algo_parameters(fully_use_devices=False) | |||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||
| strategy_dict = {"mul1": ((2, 4), (2, 4)), "mul2": None, "relu1": ((4, 1),), "bias_add": ((8, 1), (1,)), | |||
| "relu2": ((2, 2),), "add": ((1, 8), (1, 8))} | |||
| net = NetWithLoss(Net(strategy_dict)) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| reset_op_id() | |||
| _executor.compile(net, x, phase='train') | |||
| def test_star_strategy_consistency2(): | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| set_algo_parameters(fully_use_devices=False) | |||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||
| strategy_dict = {"mul1": None, "mul2": ((1, 4), (1, 4)), "relu1": ((2, 1),), "bias_add": ((4, 2), (2,)), | |||
| "relu2": ((2, 2),), "add": ((8, 1), (8, 1))} | |||
| net = NetWithLoss(Net(strategy_dict)) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| reset_op_id() | |||
| _executor.compile(net, x, phase='train') | |||
| def test_star_strategy_consistency3(): | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| set_algo_parameters(fully_use_devices=False) | |||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||
| strategy_dict = {"mul1": None, "mul2": None, "relu1": ((8, 1),), "bias_add": ((1, 4), (4,)), | |||
| "relu2": ((4, 1),), "add": ((2, 2), (2, 2))} | |||
| net = NetWithLoss(Net(strategy_dict)) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| reset_op_id() | |||
| _executor.compile(net, x, phase='train') | |||
| def test_star_strategy_consistency4(): | |||
| size = 8 | |||
| context.set_auto_parallel_context(device_num=size, global_rank=0) | |||
| set_algo_parameters(fully_use_devices=False) | |||
| x = Tensor(np.ones([128, 1000]), dtype=ms.float32) | |||
| strategy_dict = {"mul1": ((1, 8), (1, 8)), "mul2": ((1, 4), (1, 4)), "relu1": None, "bias_add": None, | |||
| "relu2": None, "add": None} | |||
| net = NetWithLoss(Net(strategy_dict)) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| net.set_auto_parallel() | |||
| reset_op_id() | |||
| with pytest.raises(RuntimeError): | |||
| _executor.compile(net, x, phase='train') | |||