diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
index 59be491852..4906034964 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
@@ -28,6 +28,10 @@ namespace mindspore {
 namespace parallel {
 Status Edge::InitEdgeCost() {
   bool has_available_cost = false;
+  pre_op_output_.clear();
+  next_op_input_.clear();
+  cost_map_.clear();
+
   for (auto &swc : prev_op_->GetStrategyCost()) {
     MS_EXCEPTION_IF_NULL(swc);
     pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
@@ -332,5 +336,8 @@ void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map)
     next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
   }
 }
+
+// Return true if there are available strategies in this edge.
+bool Edge::CheckStrategyCostPossibility() { return !cost_map_.empty(); }
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
index b636048723..0a4e2967ec 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
@@ -140,6 +140,8 @@ class Edge {
   // In the inference phase,
   Status CalculateMemoryCostForInference();
   void mark_output_critical() { is_output_critical_ = 1; }
+  // Whether there exists any available strategy in 'cost_map_'
+  bool CheckStrategyCostPossibility();
 
  private:
   std::string edge_name_;
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
index abf4ed8eef..66b50e3a4b 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc
@@ -41,6 +41,8 @@ bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
 int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
 bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
+bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX;
+double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON;
 
 void CostGraph::SetDeviceMemoryAndCostParameter() {
   MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
@@ -170,6 +172,21 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
   }
   RUN_PHASE = phase;
   MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
+
+  auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  DP_ALGO_ENABLE_APPROX = enable_approx;
+  if (enable_approx) {
+    MS_LOG(INFO) << "dp_algo_enable_approx: true.";
+  } else {
+    MS_LOG(INFO) << "dp_algo_enable_approx: false.";
+  }
+
+  auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
+  if (epsilon <= 0 || epsilon > 1) {
+    MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]";
+  }
+  DP_ALGO_APPROX_EPSILON = epsilon;
+  MS_LOG(INFO) << "epsilon: " << epsilon << ".";
 }
 
 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
@@ -1901,5 +1918,31 @@ Status CostGraph::CalculateMemoryCost() {
   }
   return SUCCESS;
 }
+
+void CostGraph::CheckApproximateCostGraphEdges() {
+  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  if (!approximation) {
+    return;
+  }
+  for (auto &s_edge : edges_) {
+    auto &edges_vector = s_edge.second;
+    for (auto &edge_ptr : edges_vector) {
+      MS_EXCEPTION_IF_NULL(edge_ptr);
+      if (edge_ptr->CheckStrategyCostPossibility()) {
+        continue;
+      }
+      MS_LOG(INFO) << "No available strategy found for edge: " << edge_ptr->edge_name()
+                   << ", re-initializing its incident operators and edges";
+      auto prev_op = edge_ptr->prev_operator();
+      MS_EXCEPTION_IF_NULL(prev_op);
+      auto next_op = edge_ptr->next_operator();
+      MS_EXCEPTION_IF_NULL(next_op);
+      // Restore exact strategies for 'prev_op' and re-init its edges
+      prev_op->ExactStrategiesAndRelatedEdges();
+      // Restore exact strategies for 'next_op' and re-init its edges
+      next_op->ExactStrategiesAndRelatedEdges();
+    }
+  }
+}
 }  // namespace parallel
 }  // namespace mindspore
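Note: the repair pass above only fires when approximation is on. An empty cost_map_ on an edge means the pruned strategy lists of its two endpoint operators no longer contain any compatible pair. A minimal sketch of the control flow, using toy Python stand-ins rather than the MindSpore classes:

# Toy stand-ins (not MindSpore APIs) for the repair pass in
# CheckApproximateCostGraphEdges above.

class ToyOp:
    def __init__(self, name):
        self.name = name
        self.exact = False

    def exact_strategies_and_related_edges(self):
        # Stand-in for OperatorInfo::ExactStrategiesAndRelatedEdges():
        # regenerate the full strategy list and re-init incident edge costs.
        self.exact = True


class ToyEdge:
    def __init__(self, prev_op, next_op, cost_map):
        self.prev_op, self.next_op, self.cost_map = prev_op, next_op, cost_map


def check_approximate_cost_graph_edges(edges, approximation=True):
    if not approximation:
        return
    for edge in edges:
        if edge.cost_map:  # Edge::CheckStrategyCostPossibility()
            continue
        edge.prev_op.exact_strategies_and_related_edges()
        edge.next_op.exact_strategies_and_related_edges()


a, b = ToyOp("MatMul-op0"), ToyOp("ReLU-op1")
check_approximate_cost_graph_edges([ToyEdge(a, b, cost_map={})])
assert a.exact and b.exact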
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
index 34e56361fc..de95a3ae92 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
@@ -45,6 +45,8 @@ extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
 extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;
 extern bool MULTI_SUBGRAPHS;
+extern bool DP_ALGO_ENABLE_APPROX;
+extern double DP_ALGO_APPROX_EPSILON;
 extern int32_t RUN_PHASE;
 extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE;
 
@@ -193,6 +195,9 @@ class CostGraph {
   // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated
   // only once (instead of multiple times); this method is used to correct this.
   Status CorrectOpsMemoryCost();
+  // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies.
+  // This method re-initializes the operators incident to those edges.
+  void CheckApproximateCostGraphEdges();
   // Needed by rec_parser
   void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
     inputs_tensor_name_list_.push_back(inputs_tensor_name);
diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
index 030d2201f1..3f23ac424a 100644
--- a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
+++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
@@ -65,6 +65,8 @@ void CostModelContext::ResetAlgoParameters() {
   fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
   triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE;
+  dp_algo_enable_approxi_ = DEFAULT_DP_ALGO_ENABLE_APPROX;
+  dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON;
 }
 
 void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
@@ -73,6 +75,10 @@ void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) {
   }
 }
 
+void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; }
+
+void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; }
+
 void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; }
 
 void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; }
diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h
index 6da7a67ce7..04efe224e1 100644
--- a/mindspore/ccsrc/frontend/parallel/costmodel_context.h
+++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.h
@@ -45,6 +45,8 @@ namespace parallel {
 #define TRAINING_PHASE 0
 #define INFERENCE_PHASE 1
 #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true;
+#define DEFAULT_DP_ALGO_ENABLE_APPROX false
+#define DEFAULT_DP_ALGO_APPROX_EPSILON 0.1
 
 class CostModelContext {
  public:
@@ -141,6 +143,12 @@ class CostModelContext {
   void set_run_phase(int32_t);
   int32_t run_phase() const { return run_phase_; }
 
+  void set_dp_algo_approxi_epsilon(double);
+  double dp_algo_approxi_epsilon() const { return dp_algo_approxi_epsilon_; }
+
+  void set_dp_algo_enable_approxi(bool);
+  bool dp_algo_enable_approxi() const { return dp_algo_enable_approxi_; }
+
  private:
   CostModelContext();
   static std::shared_ptr<CostModelContext> cm_context_inst_;
@@ -176,6 +184,12 @@ class CostModelContext {
   // whether overwrite the right-node strategy
   bool triangle_star_strategy_overwrite_;
 
+  // Whether to enable APPROXIMATION in the DP algorithm.
+  bool dp_algo_enable_approxi_;
+
+  // When APPROXIMATION is enabled in the DP algorithm, the 'epsilon' value used in the APPROXIMATION.
+  double dp_algo_approxi_epsilon_;
+
   int32_t run_phase_;  // 0: 'training', 1: 'inference'
 
   int32_t costmodel_allreduce_fusion_algorithm_;
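The defaults above make the approximation opt-in (disabled, epsilon = 0.1). Epsilon is purely a budget knob: the DP preprocessing keeps at most ceil(1.0 / epsilon) strategies per operator (see ApproximateStrategies below), and SetDeviceMemoryAndCostParameter rejects values outside (0, 1]. A small Python illustration of that relationship (the helper name is made up):

import math

def strategy_budget(epsilon):
    # Mirrors the validation in CostGraph::SetDeviceMemoryAndCostParameter
    # and the target_num computation in OperatorInfo::ApproximateStrategies.
    if epsilon <= 0 or epsilon > 1:
        raise ValueError("'epsilon' must be in (0, 1]")
    return math.ceil(1.0 / epsilon)

assert strategy_budget(1.0) == 1       # keep only the cheapest strategy
assert strategy_budget(0.1) == 10      # the default budget per operator
assert strategy_budget(0.001) == 1000  # the value used in the unit test below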
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index 62283892c8..b519b41314 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -1130,6 +1130,67 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
   return SUCCESS;
 }
 
+// Keep at most ceil(1.0 / epsilon) available strategies for each operator.
+void OperatorInfo::ApproximateStrategies() {
+  auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  if (!enable_approxi) {
+    return;
+  }
+  MS_LOG(INFO) << "Approximating strategy-cost for: " << name_;
+  auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
+  auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon));
+  if (strategy_cost_.size() <= target_num) {
+    MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size()
+                 << ", no greater than target-num: " << target_num;
+    return;
+  }
+  std::vector<std::shared_ptr<StrategyWithCost>> ret;
+  auto &origin_stra_cost = strategy_cost_;
+  auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
+  auto beta = CostModelContext::GetInstance()->costmodel_beta();
+  // Sort strategies in ascending order of combined cost: alpha * computation + beta * communication
+  std::sort(
+    origin_stra_cost.begin(), origin_stra_cost.end(),
+    [&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) {
+      if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ <
+          alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) {
+        return true;
+      }
+      return false;
+    });
+  // Keep every 'step_length'-th strategy so the survivors stay spread across the cost range
+  size_t step_length = origin_stra_cost.size() / target_num;
+  for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) {
+    ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]);
+  }
+
+  strategy_cost_ = ret;
+  is_strategy_cost_exact_ = false;
+}
+
+void OperatorInfo::ExactStrategiesAndRelatedEdges() {
+  if (is_strategy_cost_exact()) {
+    return;
+  }
+  ClearStrategyCost();
+  if (GenerateStrategies(0) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Strategy search for Operator " << name() << " failed.";
+    return;
+  }
+  SetIsStrategyCostExactTrue();
+  // Re-init the previous edges
+  for (auto &prev_edge : prev_edges()) {
+    if (prev_edge->InitEdgeCost() != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Edge: " << prev_edge->edge_name() << " cost init failed.";
+    }
+  }
+  // Re-init the successive edges
+  for (auto &next_edge : succ_edges()) {
+    if (next_edge->InitEdgeCost() != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Edge: " << next_edge->edge_name() << " cost init failed.";
+    }
+  }
+}
+
 int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
   if (is_output_parameter_involve_ != -1) {
     return is_output_parameter_involve_;
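The selection rule in ApproximateStrategies is easier to see in isolation: sort by the combined cost alpha * computation_cost_ + beta * communication_with_partial_para_, then sample at a fixed stride so the survivors stay spread across the whole cost range instead of clustering at the cheap end. A minimal re-implementation sketch with made-up costs (plain Python, not the MindSpore code):

import math

def approximate(strategies, alpha, beta, epsilon):
    target_num = math.ceil(1.0 / epsilon)
    if len(strategies) <= target_num:
        return strategies
    # Ascending order of the combined cost model value
    ordered = sorted(strategies, key=lambda s: alpha * s["comp"] + beta * s["comm"])
    step = len(ordered) // target_num
    ret, i = [], 0
    while len(ret) < target_num and i * step < len(ordered):
        ret.append(ordered[i * step])  # stride sampling over the sorted list
        i += 1
    return ret

strategies = [{"comp": float(c), "comm": float(c % 7)} for c in range(100)]
kept = approximate(strategies, alpha=1.0, beta=400.0, epsilon=0.25)
assert len(kept) == 4  # ceil(1 / 0.25) survivors, spread over the cost range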
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index d3fd2a645e..bec113b493 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -138,6 +138,13 @@ class OperatorInfo {
   }
   StrategyPtr selected_strategy() const { return selected_strategy_; }
   CostPtr selected_cost() const { return selected_cost_; }
+  // Approximate the list of available strategies
+  void ApproximateStrategies();
+  // Make the list of available strategies exact and re-init the edges incident to this operator
+  void ExactStrategiesAndRelatedEdges();
+  bool is_strategy_cost_exact() { return is_strategy_cost_exact_; }
+  void SetIsStrategyCostExactTrue() { is_strategy_cost_exact_ = true; }
+  void ClearStrategyCost() { strategy_cost_.clear(); }
   void CheckSelectedStrategy(const StrategyPtr &);
   Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
   void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
@@ -263,6 +270,8 @@ class OperatorInfo {
   int32_t used_devices_ = -1;
   // the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default
   bool repeated_num_in_dev_matrix_right_ = true;
+  // Whether the list of available strategies is exact or approximate
+  bool is_strategy_cost_exact_ = true;
 
  private:
   OperatorCostPtr operator_cost_;
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 12b4fa1a7c..e4a4ebc586 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -408,6 +408,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
       MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
       return nullptr;
     }
+    // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated
+    auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+    if (approximation) {
+      operator_info->ApproximateStrategies();
+      MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name();
+    }
   } else {
     // In this case, the configured strategy should be extracted to help setting cost
     StrategyPtr strategyPtr;
@@ -695,6 +701,11 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     }
     MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
   }
+  // If 'approximation' is enabled, check whether the edges still have effective costs.
+  auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+  if (approximation) {
+    entire_costgraph->CheckApproximateCostGraphEdges();
+  }
 
   MS_LOG(INFO) << "Constructing edges for cost graph ends.";
 }
@@ -800,6 +811,11 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
     }
     std::shared_ptr<Edge> edge_ptr =
       std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
+    // If 'approximation' is enabled, restore the target operator's exact strategies before initializing this edge's cost.
+    auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
+    if (approximation) {
+      target_op_info->ExactStrategiesAndRelatedEdges();
+    }
 
     if (edge_ptr->InitEdgeCost() != SUCCESS) {
       MS_LOG(EXCEPTION) << "Edge cost initialization failed";
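Why can an edge end up with no cost at all? An edge cost entry is only generated for a pair of surviving strategies whose tensor layouts are compatible, and each operator prunes its list independently, so two adjacent operators can keep disjoint samples. A contrived illustration, with compatibility reduced to an equal layout tag (toy code, not the MindSpore matching logic):

def edge_cost_map(prev_kept, next_kept):
    # An edge cost entry exists only for strategy pairs whose layouts are
    # compatible; compatibility is reduced here to an equal layout tag.
    return {(p, n): 1.0 for p in prev_kept for n in next_kept if p == n}

prev_kept = [0, 2, 4, 6]  # layout tags kept by the producer after pruning
next_kept = [1, 3, 5, 7]  # layout tags kept by the consumer after pruning

# The surviving sets are disjoint, so the edge is left without any cost;
# CheckApproximateCostGraphEdges detects this and restores exact strategies
# for both endpoint operators.
assert not edge_cost_map(prev_kept, next_kept)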
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index b5cbabe53d..952e114c57 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -246,6 +246,14 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
     .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,
          "Get the parameter elementwise_op_strategy_follow in the DP algorithm.")
+    .def("set_dp_algo_enable_approxi", &CostModelContext::set_dp_algo_enable_approxi,
+         "Set the flag of whether to enable approximation in the DP algorithm.")
+    .def("get_dp_algo_enable_approxi", &CostModelContext::dp_algo_enable_approxi,
+         "Get the flag of whether approximation is enabled in the DP algorithm.")
+    .def("set_dp_algo_approxi_epsilon", &CostModelContext::set_dp_algo_approxi_epsilon,
+         "Set the epsilon used in the approximation of the DP algorithm.")
+    .def("get_dp_algo_approxi_epsilon", &CostModelContext::dp_algo_approxi_epsilon,
+         "Get the epsilon used in the approximation of the DP algorithm.")
     .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.")
     .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters.");
diff --git a/mindspore/parallel/algo_parameter_config.py b/mindspore/parallel/algo_parameter_config.py
index 462ee0b371..0b6979f83f 100644
--- a/mindspore/parallel/algo_parameter_config.py
+++ b/mindspore/parallel/algo_parameter_config.py
@@ -88,6 +88,22 @@ class _AlgoParameterConfig():
         self.check_config_handle()
         return self._config_handle.get_tensor_slice_align_size()
 
+    def set_dp_algo_enable_approxi(self, enable_flag):
+        self.check_config_handle()
+        self._config_handle.set_dp_algo_enable_approxi(enable_flag)
+
+    def get_dp_algo_enable_approxi(self):
+        self.check_config_handle()
+        return self._config_handle.get_dp_algo_enable_approxi()
+
+    def set_dp_algo_approxi_epsilon(self, epsilon):
+        self.check_config_handle()
+        self._config_handle.set_dp_algo_approxi_epsilon(epsilon)
+
+    def get_dp_algo_approxi_epsilon(self):
+        self.check_config_handle()
+        return self._config_handle.get_dp_algo_approxi_epsilon()
+
     def reset_algo_parameters(self):
         self.check_config_handle()
         self._config_handle.reset_algo_parameters()
@@ -113,18 +129,23 @@ set_algo_parameters_config_func_map = {
     "fully_use_devices": _algo_parameter_config().set_fully_use_devices,
     "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
     "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
-    "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size}
+    "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size,
+    "enable_algo_approxi": _algo_parameter_config().set_dp_algo_enable_approxi,
+    "algo_approxi_epsilon": _algo_parameter_config().set_dp_algo_approxi_epsilon}
 
 
 get_algo_parameters_config_func_map = {
     "fully_use_devices": _algo_parameter_config().get_fully_use_devices,
     "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
     "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
-    "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}
+    "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size,
+    "enable_algo_approxi": _algo_parameter_config().get_dp_algo_enable_approxi,
+    "algo_approxi_epsilon": _algo_parameter_config().get_dp_algo_approxi_epsilon}
 
 
 @args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int,
-                 fully_use_devices=bool, elementwise_op_strategy_follow=bool)
+                 fully_use_devices=bool, elementwise_op_strategy_follow=bool,
+                 enable_algo_approxi=bool, algo_approxi_epsilon=float)
 def set_algo_parameters(**kwargs):
     """
     Set algo parameter config.
@@ -139,6 +160,8 @@ def set_algo_parameters(**kwargs):
         fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices.
             Default: True
         elementwise_op_strategy_follow (bool): Whether the elementwise operator has the same strategies as its
             subsequent operators. Default: False
+        enable_algo_approxi (bool): Whether to enable approximation in the DP algorithm. Default: False
+        algo_approxi_epsilon (float): The epsilon value used in the approximation of the DP algorithm. Default: 0.1
 
     Raises:
         ValueError: If context keyword is not recognized.
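With the plumbing above, the new knobs are reachable from user code. A hypothetical usage sketch (assuming the helpers exported by this module; the epsilon value is arbitrary):

# Illustrative only; requires a MindSpore installation with this patch.
from mindspore.parallel.algo_parameter_config import (
    set_algo_parameters, get_algo_parameters, reset_algo_parameters)

set_algo_parameters(enable_algo_approxi=True, algo_approxi_epsilon=0.05)
assert get_algo_parameters("enable_algo_approxi")
assert get_algo_parameters("algo_approxi_epsilon") == 0.05

reset_algo_parameters()  # restores the defaults: False and 0.1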
diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py
index 6be053f69e..da3ded3209 100644
--- a/tests/ut/python/parallel/test_auto_parallel_resnet.py
+++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py
@@ -686,6 +686,33 @@ def test_train_8k_8p_gpu(batch_size=32, num_classes=8192):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
     set_algo_parameters(elementwise_op_strategy_follow=True)
+    #set_algo_parameters(enable_algo_approxi=True)
+    resset_op_id()
+    np.random.seed(6)
+    input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
+    label_np = np.zeros([batch_size]).astype(np.int32)
+    for i in range(0, batch_size):
+        label_np[i] = i % num_classes
+    dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1)
+    net = resnet50(num_classes)
+    loss = SoftmaxCrossEntropyExpand(sparse=True)
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
+    model = Model(net, loss_fn=loss, optimizer=opt)
+    model.train(5, dataset, dataset_sink_mode=False)
+    strategies = _executor._get_shard_strategy(model._train_network)
+    for (k, v) in strategies.items():
+        if re.search('Conv2D-op', k) is not None:
+            assert v[0][0] == dev_num
+        elif re.search('MatMul-op', k) is not None:
+            assert v == [[1, 1], [dev_num, 1]]
+        elif re.search('ReduceSum-op', k) is not None:
+            assert v == [[1, dev_num]]
+
+def test_train_8k_8p_gpu_approxi(batch_size=32, num_classes=8192):
+    dev_num = 8
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
+    set_algo_parameters(enable_algo_approxi=True)
     resset_op_id()
     np.random.seed(6)
     input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32)
diff --git a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py
index 2f9c91625b..d20b3b39bd 100644
--- a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py
+++ b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py
@@ -105,7 +105,8 @@ def test_two_matmul():
     assert costmodel_communi_bias == 1024.0
 
     set_algo_parameters(tensor_slice_align_enable=False, tensor_slice_align_size=32,
-                        fully_use_devices=False, elementwise_op_strategy_follow=False)
+                        fully_use_devices=False, elementwise_op_strategy_follow=False,
+                        enable_algo_approxi=True, algo_approxi_epsilon=0.001)
     para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
     assert not para_slice_align_enable
     para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
@@ -114,6 +115,10 @@ def test_two_matmul():
     assert not fully_use_devices
     elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow")
     assert not elementwise_op_strategy_follow
+    enable_approxi = get_algo_parameters("enable_algo_approxi")
+    assert enable_approxi
+    algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
+    assert algo_epsilon == 0.001
 
     reset_algo_parameters()
     para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable")
@@ -124,6 +129,10 @@ def test_two_matmul():
     assert fully_use_devices
     elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow")
     assert not elementwise_op_strategy_follow
+    enable_approxi = get_algo_parameters("enable_algo_approxi")
+    assert not enable_approxi
+    algo_epsilon = get_algo_parameters("algo_approxi_epsilon")
+    assert algo_epsilon == 0.1
 
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)