| @@ -44,6 +44,7 @@ namespace parallel { | |||
| #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 | |||
| #define DEFAULT_FULLY_USE_DEVICES true | |||
| #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | |||
| #define DEFAULT_IS_MULTI_SUBGRAPHS false | |||
| class CostGraph; | |||
| using CostGraphPtr = std::shared_ptr<CostGraph>; | |||
| @@ -46,6 +46,7 @@ void CostModelContext::ResetCostModel() { | |||
| costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; | |||
| costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; | |||
| costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; | |||
| is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; | |||
| costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; | |||
| costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; | |||
| @@ -84,6 +85,7 @@ void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { | |||
| void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } | |||
| void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } | |||
| void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) { | |||
| costmodel_allreduce_fusion_algorithm_ = algorithm; | |||
| } | |||
| @@ -67,6 +67,9 @@ class CostModelContext { | |||
| void set_costmodel_communi_bias(double); | |||
| double costmodel_communi_bias() const { return costmodel_communi_bias_; } | |||
| void set_multi_subgraphs(bool); | |||
| bool is_multi_subgraphs() const { return is_multi_subgraphs_; } | |||
| void set_costmodel_allreduce_fusion_algorithm(int32_t); | |||
| int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; } | |||
| @@ -138,6 +141,8 @@ class CostModelContext { | |||
| // COST_MODEL_COMMUNI_BIAS | |||
| double costmodel_communi_bias_; | |||
| bool is_multi_subgraphs_; | |||
| int32_t costmodel_allreduce_fusion_algorithm_; | |||
| int32_t costmodel_allreduce_fusion_times_; | |||
| @@ -426,13 +426,13 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| return operator_info; | |||
| } | |||
| Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | |||
| // Using CNode's UniqueIds to construct nodes | |||
| Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph begins."; | |||
| entire_costgraph = std::make_shared<CostGraph>(); | |||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||
| bool new_operator = true, first_operator = true; | |||
| std::string first_operator_cnode; | |||
| size_t current_op_index = 0; | |||
| // The map from CNode's UniqueId to its operatorInfo | |||
| std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | |||
| // Step 1 | |||
| for (auto &node : all_nodes) { | |||
| @@ -449,12 +449,8 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| // When visiting the second subgraph, use the corresponding operatorInfo which already created | |||
| bool modify_new_operator = (new_operator) && (!first_operator) && (cnode->UniqueId() == first_operator_cnode); | |||
| if (modify_new_operator) { | |||
| new_operator = false; | |||
| } | |||
| if (new_operator) { | |||
| auto search_cnode = from_cnode_to_info.find(cnode->UniqueId()); | |||
| if (search_cnode == from_cnode_to_info.end()) { | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode); | |||
| if (operator_info == nullptr) { | |||
| return FAILED; | |||
| @@ -465,14 +461,67 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F | |||
| entire_costgraph->AddOperator(operator_info); | |||
| (void)cnode->set_operator_info(operator_info); | |||
| if (first_operator) { | |||
| first_operator_cnode = cnode->UniqueId(); | |||
| first_operator = false; | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); | |||
| // Needed by rec_parser | |||
| entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); | |||
| } else { | |||
| // Two CNODEs' UniqueIds should not be equal | |||
| MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name(); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph ends."; | |||
| return SUCCESS; | |||
| } | |||
| // Using CNode's UniqueIdThroughCopys to construct nodes | |||
| Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph begins."; | |||
| entire_costgraph = std::make_shared<CostGraph>(); | |||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||
| // The map from CNode's UniqueIdThroughCopy to its operatorInfo | |||
| std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | |||
| for (auto &node : all_nodes) { | |||
| // NOTE: we only care about splittable Primitive operators | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0))); | |||
| if (bool_result) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (!IsAutoParallelCareNode(cnode)) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| // Find the operatorInfo if it exists | |||
| auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy()); | |||
| if (search_cnode == from_cnode_to_info.end()) { | |||
| // In this case, the corresponding OperatorInfo is not created, create the new one. | |||
| auto operator_info = CreateTheOperatorInfo(prim, cnode); | |||
| if (operator_info == nullptr) { | |||
| return FAILED; | |||
| } | |||
| // Needed by rec_parser | |||
| operator_info->set_type(prim->name()); | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| (void)cnode->set_operator_info(operator_info); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); | |||
| // Needed by rec_parser | |||
| entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); | |||
| } else { | |||
| auto current_op_ptr = entire_costgraph->FindOperatorByIndex(current_op_index); | |||
| auto current_op_ptr = search_cnode->second; | |||
| if (current_op_ptr == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; | |||
| } else { | |||
| @@ -484,14 +533,12 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F | |||
| << " does not match the Prim: " << prim->name(); | |||
| } | |||
| (void)cnode->set_operator_info(current_op_ptr); | |||
| current_op_index++; | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | |||
| } | |||
| } | |||
| } | |||
| if ((!new_operator) && (current_op_index != entire_costgraph->GetOperators().size())) { | |||
| MS_LOG(EXCEPTION) << "The second subgraph's operator number: " << current_op_index | |||
| << " does not match the first ones: " << entire_costgraph->GetOperators().size(); | |||
| } | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph ends."; | |||
| return SUCCESS; | |||
| @@ -844,11 +891,20 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||
| // OUTPUT: the determined strategy for each operator. | |||
| // Step 1 | |||
| if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) { | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() | |||
| << " operators."; | |||
| if (CostModelContext::GetInstance()->is_multi_subgraphs()) { | |||
| if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " | |||
| << entire_costgraph->GetOperators().size() << " operators."; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | |||
| if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " | |||
| << entire_costgraph->GetOperators().size() << " operators."; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; | |||
| } | |||
| } | |||
| // Step 2 | |||
| @@ -916,7 +972,7 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st | |||
| } | |||
| Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { | |||
| if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) { | |||
| if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { | |||
| MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() | |||
| << " operators."; | |||
| } else { | |||
| @@ -43,7 +43,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node); | |||
| std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node); | |||
| Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | |||
| void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes); | |||
| @@ -24,6 +24,7 @@ | |||
| #include <functional> | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "parallel/costmodel_context.h" | |||
| #include "pipeline/pass.h" | |||
| #include "pipeline/parse/parse_base.h" | |||
| #include "pipeline/parse/data_converter.h" | |||
| @@ -341,7 +342,10 @@ static std::vector<ActionItem> CommonPipeline() { | |||
| // Resolve the python func | |||
| actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); | |||
| actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | |||
| auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs(); | |||
| if (!multi_graphs) { | |||
| actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | |||
| } | |||
| actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | |||
| // Evaluate type and shape, and specialize | |||
| actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | |||
| @@ -222,6 +222,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Set the parameter cost_model_communi_bias of the DP algorithm.") | |||
| .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias, | |||
| "Get the parameter cost_model_communi_bias of the DP algorithm.") | |||
| .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") | |||
| .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") | |||
| .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, | |||
| "Set the parameter gradient AllReduce fusion algorithm.") | |||
| .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, | |||
| @@ -214,6 +214,31 @@ class _CostModelContext: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| return self._context_handle.get_costmodel_communi_bias() | |||
| def set_multi_subgraphs(self, multi_subgraph): | |||
| """ | |||
| Set the flag of ANF graph containing multiple subgraphs. | |||
| Args: | |||
| multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag. | |||
| Raises: | |||
| ValueError: If context handle is none. | |||
| """ | |||
| if self._context_handle is None: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| self._context_handle.set_multi_subgraphs(multi_subgraph) | |||
| def get_multi_subgraphs(self): | |||
| """ | |||
| Get the flag of ANF graph containing multiple subgraphs. | |||
| Raises: | |||
| ValueError: If context handle is none. | |||
| """ | |||
| if self._context_handle is None: | |||
| raise ValueError("Context handle is none in context!!!") | |||
| return self._context_handle.get_multi_subgraphs() | |||
| def set_costmodel_allreduce_fusion_algorithm(self, algorithm): | |||
| """ | |||
| Set costmodel allreduce fusion algorithm. | |||
| @@ -427,6 +452,7 @@ set_cost_model_context_func_map = { | |||
| "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold, | |||
| "costmodel_communi_const": cost_model_context().set_costmodel_communi_const, | |||
| "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias, | |||
| "multi_subgraphs": cost_model_context().set_multi_subgraphs, | |||
| "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm, | |||
| "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times, | |||
| "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent, | |||
| @@ -447,6 +473,7 @@ get_cost_model_context_func_map = { | |||
| "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold, | |||
| "costmodel_communi_const": cost_model_context().get_costmodel_communi_const, | |||
| "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias, | |||
| "multi_subgraphs": cost_model_context().get_multi_subgraphs(), | |||
| "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm, | |||
| "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times, | |||
| "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent, | |||
| @@ -461,6 +488,7 @@ get_cost_model_context_func_map = { | |||
| @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float, | |||
| costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float, | |||
| multi_subgraphs=bool, | |||
| costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int, | |||
| costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float, | |||
| costmodel_allreduce_fusion_allreduce_inherent_time=float, | |||
| @@ -481,6 +509,7 @@ def set_cost_model_context(**kwargs): | |||
| costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice. | |||
| costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice. | |||
| costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice. | |||
| multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs. | |||
| costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm. | |||
| 0: bypass allreduce fusion; | |||
| 1: only use backward computation time to group allreduce; | |||
| @@ -0,0 +1,101 @@ | |||
| import numpy as np | |||
| from mindspore import context | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.nn.optim import Adam, FTRL | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import functional as F | |||
| from mindspore import Tensor, Parameter, ParameterTuple | |||
| from mindspore.ops import composite as C | |||
| from mindspore.parallel import _cost_model_context as cost_model_context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters | |||
| from mindspore.parallel._utils import _reset_op_id as reset_op_id | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.mul = P.Mul() | |||
| self.relu = P.ReLU() | |||
| self.wd = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide") | |||
| self.wt = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="l") | |||
| def construct(self, x): | |||
| out = self.mul(x, self.wd) | |||
| out = self.mul(out, self.wt) | |||
| out = self.relu(out) | |||
| return out | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.sum = P.ReduceSum() | |||
| self.mean = P.ReduceMean() | |||
| self.net = network | |||
| def construct(self, x): | |||
| predict = self.net(x) | |||
| loss1 = self.sum(predict, -1) | |||
| loss2 = self.mean(predict, -1) | |||
| return loss1, loss2 | |||
| class IthOutputCell(nn.Cell): | |||
| def __init__(self, network, output_index): | |||
| super(IthOutputCell, self).__init__() | |||
| self.network = network | |||
| self.output_index = output_index | |||
| def construct(self, x): | |||
| predict = self.network(x)[self.output_index] | |||
| return predict | |||
| class TrainStepWarp(nn.Cell): | |||
| def __init__(self, network, sens=1000.0): | |||
| super(TrainStepWarp, self).__init__() | |||
| self.network = network | |||
| self.network.set_train() | |||
| self.trainable_params = network.trainable_params() | |||
| weights_w = [] | |||
| weights_d = [] | |||
| for params in self.trainable_params: | |||
| weights_w.append(params) | |||
| weights_d.append(params) | |||
| self.weights_w = ParameterTuple(weights_w) | |||
| self.weights_d = ParameterTuple(weights_d) | |||
| self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, l1=1e-8, | |||
| l2=1e-8, initial_accum=1.0) | |||
| self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8, | |||
| loss_scale=sens) | |||
| self.hyper_map = C.HyperMap() | |||
| self.grad_w = C.GradOperation('grad_w', get_by_list=True, sens_param=True) | |||
| self.grad_d = C.GradOperation('grad_d', get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| self.loss_net_w = IthOutputCell(network, output_index=0) | |||
| self.loss_net_d = IthOutputCell(network, output_index=1) | |||
| def construct(self, x): | |||
| weights_w = self.weights_w | |||
| weights_d = self.weights_d | |||
| loss_w, loss_d = self.network(x) | |||
| sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) | |||
| sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) | |||
| grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w) | |||
| grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d) | |||
| return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d)) | |||
| def test_double_subgraphs(): | |||
| cost_model_context.set_cost_model_context(multi_subgraphs=True) | |||
| context.set_context(save_graphs=True) | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| net = TrainStepWarp(NetWithLoss(Net())) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32) | |||
| reset_op_id() | |||
| _executor.compile(net, x, phase='train') | |||
| strategies = _executor._get_strategy(net) | |||
| expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]], | |||
| 'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]} | |||
| assert strategies == expected_strategies | |||
| @@ -0,0 +1,70 @@ | |||
| import numpy as np | |||
| from mindspore import context | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore import Tensor | |||
| from mindspore.common.api import _executor | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| from mindspore.parallel import set_algo_parameters | |||
| from mindspore.parallel._utils import _reset_op_id as reset_op_id | |||
| import re | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x): | |||
| predict = self.network(x) | |||
| return self.loss(predict) | |||
| class Blockcell(nn.Cell): | |||
| def __init__(self): | |||
| super(Blockcell, self).__init__() | |||
| self.bn = nn.BatchNorm2d(64, momentum=0.9) | |||
| def construct(self, x): | |||
| out = self.bn(x) | |||
| return out | |||
| def getBlock(): | |||
| return Blockcell() | |||
| def test_two_bn(): | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.block1 = getBlock() | |||
| self.block2 = getBlock() | |||
| self.relu = P.ReLU() | |||
| self.add = P.TensorAdd() | |||
| self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| def construct(self, x): | |||
| out = self.block1(x) | |||
| out = self.relu(out) | |||
| out = self.add(out, self.bias) | |||
| out = self.block2(out) | |||
| return out | |||
| net = NetWithLoss(Net()) | |||
| x = Tensor(np.ones([64, 64]), dtype=ms.float32) | |||
| context.set_context(save_graphs=True) | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| set_algo_parameters(elementwise_op_strategy_follow=True) | |||
| reset_op_id() | |||
| _executor.compile(net, x, phase='train') | |||
| strategies = _executor._get_strategy(net) | |||
| assert len(strategies) == 4 | |||
| for (k, v) in strategies.items(): | |||
| if re.search('BatchNorm-op', k) is not None: | |||
| assert v == [[8, 1], [1], [1], [1], [1]] | |||
| elif re.search('TensorAdd-op', k) is not None: | |||
| assert v == [[8, 1], [8, 1]] | |||
| elif re.search('ReLU-op', k) is not None: | |||
| assert v == [[8, 1]] | |||