Merge pull request !8016 from Xiaoda/21-approximate-algo-in-searching-strategytags/v1.1.0
| @@ -28,6 +28,10 @@ namespace mindspore { | |||
| namespace parallel { | |||
| Status Edge::InitEdgeCost() { | |||
| bool has_available_cost = false; | |||
| pre_op_output_.clear(); | |||
| next_op_input_.clear(); | |||
| cost_map_.clear(); | |||
| for (auto &swc : prev_op_->GetStrategyCost()) { | |||
| MS_EXCEPTION_IF_NULL(swc); | |||
| pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); | |||
| @@ -332,5 +336,8 @@ void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) | |||
| next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {})); | |||
| } | |||
| } | |||
| // Return true if there are available strategies in this edge. | |||
| bool Edge::CheckStrategyCostPossibility() { return !cost_map_.empty(); } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -140,6 +140,8 @@ class Edge { | |||
| // In the inference phase, | |||
| Status CalculateMemoryCostForInference(); | |||
| void mark_output_critical() { is_output_critical_ = 1; } | |||
| // Whether there exists any available strategy in 'cost_map_' | |||
| bool CheckStrategyCostPossibility(); | |||
| private: | |||
| std::string edge_name_; | |||
| @@ -41,6 +41,8 @@ bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; | |||
| int32_t RUN_PHASE = DEFAULT_RUN_PHASE; | |||
| bool TRIANGLE_STAR_STRATEGY_OVERWRITE = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| bool DP_ALGO_ENABLE_APPROX = DEFAULT_DP_ALGO_ENABLE_APPROX; | |||
| double DP_ALGO_APPROX_EPSILON = DEFAULT_DP_ALGO_APPROX_EPSILON; | |||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); | |||
| @@ -170,6 +172,21 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||
| } | |||
| RUN_PHASE = phase; | |||
| MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; | |||
| auto enable_approx = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| DP_ALGO_ENABLE_APPROX = enable_approx; | |||
| if (enable_approx) { | |||
| MS_LOG(INFO) << "dp_algo_enable_approx: true."; | |||
| } else { | |||
| MS_LOG(INFO) << "dp_algo_enable_approx: false."; | |||
| } | |||
| auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon(); | |||
| if (epsilon <= 0 || epsilon > 1) { | |||
| MS_LOG(EXCEPTION) << "'epsilon' must be in (0, 1]"; | |||
| } | |||
| DP_ALGO_APPROX_EPSILON = epsilon; | |||
| MS_LOG(INFO) << "epsilon: " << epsilon << "."; | |||
| } | |||
| void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { | |||
| @@ -1901,5 +1918,31 @@ Status CostGraph::CalculateMemoryCost() { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void CostGraph::CheckApproximateCostGraphEdges() { | |||
| auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| if (!approximation) { | |||
| return; | |||
| } | |||
| for (auto &s_edge : edges_) { | |||
| auto &edges_vector = s_edge.second; | |||
| for (auto &edge_ptr : edges_vector) { | |||
| MS_EXCEPTION_IF_NULL(edge_ptr); | |||
| if (edge_ptr->CheckStrategyCostPossibility()) { | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "Checking StrategyCost for edge: " << edge_ptr->edge_name() | |||
| << " impossible, re-initing the operators and edges"; | |||
| auto prev_op = edge_ptr->prev_operator(); | |||
| MS_EXCEPTION_IF_NULL(prev_op); | |||
| auto next_op = edge_ptr->next_operator(); | |||
| MS_EXCEPTION_IF_NULL(next_op); | |||
| // Check the 'prev_op' | |||
| prev_op->ExactStrategiesAndRelatedEdges(); | |||
| // Check the 'next_op' | |||
| next_op->ExactStrategiesAndRelatedEdges(); | |||
| } | |||
| } | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -45,6 +45,8 @@ extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; | |||
| extern bool FULLY_USE_DEVICES; | |||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | |||
| extern bool MULTI_SUBGRAPHS; | |||
| extern bool DP_ALGO_ENABLE_APPROX; | |||
| extern double DP_ALGO_APPROX_EPSILON; | |||
| extern int32_t RUN_PHASE; | |||
| extern bool TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| @@ -193,6 +195,9 @@ class CostGraph { | |||
| // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only | |||
| // once (instead of multiple times), this method is used to correct this. | |||
| Status CorrectOpsMemoryCost(); | |||
| // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies. | |||
| // This method is to re-init those edge involved operators. | |||
| void CheckApproximateCostGraphEdges(); | |||
| // Needed by rec_parser | |||
| void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) { | |||
| inputs_tensor_name_list_.push_back(inputs_tensor_name); | |||
| @@ -65,6 +65,8 @@ void CostModelContext::ResetAlgoParameters() { | |||
| fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; | |||
| elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | |||
| triangle_star_strategy_overwrite_ = DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE; | |||
| dp_algo_enable_approxi_ = DEFAULT_DP_ALGO_ENABLE_APPROX; | |||
| dp_algo_approxi_epsilon_ = DEFAULT_DP_ALGO_APPROX_EPSILON; | |||
| } | |||
| void CostModelContext::set_costmodel_context_for_device(const std::string &device_target) { | |||
| @@ -73,6 +75,10 @@ void CostModelContext::set_costmodel_context_for_device(const std::string &devic | |||
| } | |||
| } | |||
| void CostModelContext::set_dp_algo_approxi_epsilon(double epsilon) { dp_algo_approxi_epsilon_ = epsilon; } | |||
| void CostModelContext::set_dp_algo_enable_approxi(bool approxi) { dp_algo_enable_approxi_ = approxi; } | |||
| void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } | |||
| void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } | |||
| @@ -45,6 +45,8 @@ namespace parallel { | |||
| #define TRAINING_PHASE 0 | |||
| #define INFERENCE_PHASE 1 | |||
| #define DEFAULT_TRIANGLE_STAR_STRATEGY_OVERWRITE true; | |||
| #define DEFAULT_DP_ALGO_ENABLE_APPROX false | |||
| #define DEFAULT_DP_ALGO_APPROX_EPSILON 0.1 | |||
| class CostModelContext { | |||
| public: | |||
| @@ -141,6 +143,12 @@ class CostModelContext { | |||
| void set_run_phase(int32_t); | |||
| int32_t run_phase() const { return run_phase_; } | |||
| void set_dp_algo_approxi_epsilon(double); | |||
| double dp_algo_approxi_epsilon() const { return dp_algo_approxi_epsilon_; } | |||
| void set_dp_algo_enable_approxi(bool); | |||
| bool dp_algo_enable_approxi() const { return dp_algo_enable_approxi_; } | |||
| private: | |||
| CostModelContext(); | |||
| static std::shared_ptr<CostModelContext> cm_context_inst_; | |||
| @@ -176,6 +184,12 @@ class CostModelContext { | |||
| // whether overwrite the right-node strategy | |||
| bool triangle_star_strategy_overwrite_; | |||
| // Whether to enable APPROXIMATION in the DP algorithm. | |||
| bool dp_algo_enable_approxi_; | |||
| // When APPROXIMATION is enabled in the DP algorithm, the 'epsilon' value used in the APPROXIMATION. | |||
| double dp_algo_approxi_epsilon_; | |||
| int32_t run_phase_; // 0: 'training', 1: 'inference' | |||
| int32_t costmodel_allreduce_fusion_algorithm_; | |||
| @@ -1130,6 +1130,67 @@ Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| // Keep at most (1.0 / epsilon) number of available strategies for each operator. | |||
| void OperatorInfo::ApproximateStrategies() { | |||
| auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| if (!enable_approxi) { | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Approximating strategy-cost for: " << name_; | |||
| auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon(); | |||
| auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon)); | |||
| if (strategy_cost_.size() <= target_num) { | |||
| MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size() | |||
| << ", no greater than target-num: " << target_num; | |||
| return; | |||
| } | |||
| std::vector<std::shared_ptr<StrategyWithCost>> ret; | |||
| auto &origin_stra_cost = strategy_cost_; | |||
| auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); | |||
| auto beta = CostModelContext::GetInstance()->costmodel_beta(); | |||
| // sort | |||
| std::sort( | |||
| origin_stra_cost.begin(), origin_stra_cost.end(), | |||
| [&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) { | |||
| if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ < | |||
| alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) { | |||
| return true; | |||
| } | |||
| return false; | |||
| }); | |||
| size_t step_length = origin_stra_cost.size() / target_num; | |||
| for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) { | |||
| ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]); | |||
| } | |||
| strategy_cost_ = ret; | |||
| is_strategy_cost_exact_ = false; | |||
| } | |||
| void OperatorInfo::ExactStrategiesAndRelatedEdges() { | |||
| if (is_strategy_cost_exact()) { | |||
| return; | |||
| } | |||
| ClearStrategyCost(); | |||
| if (GenerateStrategies(0) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Strategy search for Operator " << name() << " failed."; | |||
| return; | |||
| } | |||
| SetIsStrategyCostExactTrue(); | |||
| // re-init the previous edges | |||
| for (auto &prev_edge : prev_edges()) { | |||
| if (prev_edge->InitEdgeCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Edge: " << prev_edge->edge_name() << " cost init failed."; | |||
| } | |||
| } | |||
| // re-init the successive edges | |||
| for (auto &next_edge : succ_edges()) { | |||
| if (next_edge->InitEdgeCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Edge: " << next_edge->edge_name() << " cost init failed."; | |||
| } | |||
| } | |||
| } | |||
| int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { | |||
| if (is_output_parameter_involve_ != -1) { | |||
| return is_output_parameter_involve_; | |||
| @@ -138,6 +138,13 @@ class OperatorInfo { | |||
| } | |||
| StrategyPtr selected_strategy() const { return selected_strategy_; } | |||
| CostPtr selected_cost() const { return selected_cost_; } | |||
| // Approximate the list of available strategies | |||
| void ApproximateStrategies(); | |||
| // Make the list of available strategies exact and re-init the related edges incident to this operator | |||
| void ExactStrategiesAndRelatedEdges(); | |||
| bool is_strategy_cost_exact() { return is_strategy_cost_exact_; } | |||
| void SetIsStrategyCostExactTrue() { is_strategy_cost_exact_ = true; } | |||
| void ClearStrategyCost() { strategy_cost_.clear(); } | |||
| void CheckSelectedStrategy(const StrategyPtr &); | |||
| Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } | |||
| void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; } | |||
| @@ -263,6 +270,8 @@ class OperatorInfo { | |||
| int32_t used_devices_ = -1; | |||
| // the repeated_calc_num_ will be inserted to the last dimension of dev matrix in default | |||
| bool repeated_num_in_dev_matrix_right_ = true; | |||
| // Whether the list of available strategies is exact or approximate | |||
| bool is_strategy_cost_exact_ = true; | |||
| private: | |||
| OperatorCostPtr operator_cost_; | |||
| @@ -408,6 +408,12 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||
| MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed."; | |||
| return nullptr; | |||
| } | |||
| // If 'approximation' is enabled, the 'strategy_cost' of each operator is approximated | |||
| auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| if (approximation) { | |||
| operator_info->ApproximateStrategies(); | |||
| MS_LOG(INFO) << "Approximated StrategyCost for: " << operator_info->name(); | |||
| } | |||
| } else { | |||
| // In this case, the configured strategy should be extracted to help setting cost | |||
| StrategyPtr strategyPtr; | |||
| @@ -695,6 +701,11 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); | |||
| } | |||
| // If 'approximation' is enabled, the edges need to be checked have effective costs. | |||
| auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| if (approximation) { | |||
| entire_costgraph->CheckApproximateCostGraphEdges(); | |||
| } | |||
| MS_LOG(INFO) << "Constructing edges for cost graph ends."; | |||
| } | |||
| @@ -800,6 +811,11 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| std::shared_ptr<Edge> edge_ptr = | |||
| std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true); | |||
| // If 'approximation' is enabled, the edges need to be checked have effective costs. | |||
| auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi(); | |||
| if (approximation) { | |||
| target_op_info->ExactStrategiesAndRelatedEdges(); | |||
| } | |||
| if (edge_ptr->InitEdgeCost() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << "Edge cost initialization failed"; | |||
| @@ -246,6 +246,14 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Set the parameter elementwise_op_strategy_follow in the DP algorithm.") | |||
| .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow, | |||
| "Get the parameter elementwise_op_strategy_follow in the DP algorithm.") | |||
| .def("set_dp_algo_enable_approxi", &CostModelContext::set_dp_algo_enable_approxi, | |||
| "Set the flag whether enabling approximation in the DP algorithm.") | |||
| .def("get_dp_algo_enable_approxi", &CostModelContext::dp_algo_enable_approxi, | |||
| "Get the flag whether enabling approximation in the DP algorithm.") | |||
| .def("set_dp_algo_approxi_epsilon", &CostModelContext::set_dp_algo_approxi_epsilon, | |||
| "Set the epsilon which is used in the approximation of DP algorithm.") | |||
| .def("get_dp_algo_approxi_epsilon", &CostModelContext::dp_algo_approxi_epsilon, | |||
| "Get the epsilon which is used in the approximation of DP algorithm.") | |||
| .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.") | |||
| .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters."); | |||
| @@ -88,6 +88,22 @@ class _AlgoParameterConfig(): | |||
| self.check_config_handle() | |||
| return self._config_handle.get_tensor_slice_align_size() | |||
| def set_dp_algo_enable_approxi(self, enable_flag): | |||
| self.check_config_handle() | |||
| self._config_handle.set_dp_algo_enable_approxi(enable_flag) | |||
| def get_dp_algo_enable_approxi(self): | |||
| self.check_config_handle() | |||
| return self._config_handle.get_dp_algo_enable_approxi() | |||
| def set_dp_algo_approxi_epsilon(self, epsilon): | |||
| self.check_config_handle() | |||
| self._config_handle.set_dp_algo_approxi_epsilon(epsilon) | |||
| def get_dp_algo_approxi_epsilon(self): | |||
| self.check_config_handle() | |||
| return self._config_handle.get_dp_algo_approxi_epsilon() | |||
| def reset_algo_parameters(self): | |||
| self.check_config_handle() | |||
| self._config_handle.reset_algo_parameters() | |||
| @@ -113,18 +129,23 @@ set_algo_parameters_config_func_map = { | |||
| "fully_use_devices": _algo_parameter_config().set_fully_use_devices, | |||
| "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow, | |||
| "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable, | |||
| "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size} | |||
| "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size, | |||
| "enable_algo_approxi": _algo_parameter_config().set_dp_algo_enable_approxi, | |||
| "algo_approxi_epsilon": _algo_parameter_config().set_dp_algo_approxi_epsilon} | |||
| get_algo_parameters_config_func_map = { | |||
| "fully_use_devices": _algo_parameter_config().get_fully_use_devices, | |||
| "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow, | |||
| "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable, | |||
| "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size} | |||
| "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size, | |||
| "enable_algo_approxi": _algo_parameter_config().get_dp_algo_enable_approxi, | |||
| "algo_approxi_epsilon": _algo_parameter_config().get_dp_algo_approxi_epsilon} | |||
| @args_type_check(tensor_slice_align_enable=bool, tensor_slice_align_size=int, | |||
| fully_use_devices=bool, elementwise_op_strategy_follow=bool) | |||
| fully_use_devices=bool, elementwise_op_strategy_follow=bool, | |||
| enable_algo_approxi=bool, algo_approxi_epsilon=float) | |||
| def set_algo_parameters(**kwargs): | |||
| """ | |||
| Set algo parameter config. | |||
| @@ -139,6 +160,8 @@ def set_algo_parameters(**kwargs): | |||
| fully_use_devices (bool): Whether ONLY generating strategies that fully use all available devices. Default: True | |||
| elementwise_op_strategy_follow (bool): Whether the elementwise operator has the same strategies as its | |||
| subsequent operators. Default: False | |||
| enable_algo_approxi (bool): Whether to enable the approximation in the DP algorithms. | |||
| algo_approxi_epsilon (float): The epsilon value used int the approximation DP algorithm. | |||
| Raises: | |||
| ValueError: If context keyword is not recognized. | |||
| @@ -686,6 +686,33 @@ def test_train_8k_8p_gpu(batch_size=32, num_classes=8192): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) | |||
| set_algo_parameters(elementwise_op_strategy_follow=True) | |||
| #set_algo_parameters(enable_algo_approxi=True) | |||
| resset_op_id() | |||
| np.random.seed(6) | |||
| input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32) | |||
| label_np = np.zeros([batch_size]).astype(np.int32) | |||
| for i in range(0, batch_size): | |||
| label_np[i] = i % num_classes | |||
| dataset = DatasetLenet(Tensor(input_np), Tensor(label_np), 1) | |||
| net = resnet50(num_classes) | |||
| loss = SoftmaxCrossEntropyExpand(sparse=True) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) | |||
| model = Model(net, loss_fn=loss, optimizer=opt) | |||
| model.train(5, dataset, dataset_sink_mode=False) | |||
| strategies = _executor._get_shard_strategy(model._train_network) | |||
| for (k, v) in strategies.items(): | |||
| if re.search('Conv2D-op', k) is not None: | |||
| assert v[0][0] == dev_num | |||
| elif re.search('MatMul-op', k) is not None: | |||
| assert v == [[1, 1], [dev_num, 1]] | |||
| elif re.search('ReduceSum-op', k) is not None: | |||
| assert v == [[1, dev_num]] | |||
| def test_train_8k_8p_gpu_approxi(batch_size=32, num_classes=8192): | |||
| dev_num = 8 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) | |||
| set_algo_parameters(enable_algo_approxi=True) | |||
| resset_op_id() | |||
| np.random.seed(6) | |||
| input_np = np.ones([batch_size, 3, 224, 224]).astype(np.float32) | |||
| @@ -105,7 +105,8 @@ def test_two_matmul(): | |||
| assert costmodel_communi_bias == 1024.0 | |||
| set_algo_parameters(tensor_slice_align_enable=False, tensor_slice_align_size=32, | |||
| fully_use_devices=False, elementwise_op_strategy_follow=False) | |||
| fully_use_devices=False, elementwise_op_strategy_follow=False, | |||
| enable_algo_approxi=True, algo_approxi_epsilon=0.001) | |||
| para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable") | |||
| assert not para_slice_align_enable | |||
| para_slice_align_size = get_algo_parameters("tensor_slice_align_size") | |||
| @@ -114,6 +115,10 @@ def test_two_matmul(): | |||
| assert not fully_use_devices | |||
| elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow") | |||
| assert not elementwise_op_strategy_follow | |||
| enable_approxi = get_algo_parameters("enable_algo_approxi") | |||
| assert enable_approxi | |||
| algo_epsilon = get_algo_parameters("algo_approxi_epsilon") | |||
| assert algo_epsilon == 0.001 | |||
| reset_algo_parameters() | |||
| para_slice_align_enable = get_algo_parameters("tensor_slice_align_enable") | |||
| @@ -124,6 +129,10 @@ def test_two_matmul(): | |||
| assert fully_use_devices | |||
| elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow") | |||
| assert not elementwise_op_strategy_follow | |||
| enable_approxi = get_algo_parameters("enable_algo_approxi") | |||
| assert not enable_approxi | |||
| algo_epsilon = get_algo_parameters("algo_approxi_epsilon") | |||
| assert algo_epsilon == 0.1 | |||
| x = Tensor(np.ones([128, 32]), dtype=ms.float32) | |||
| y = Tensor(np.ones([32, 64]), dtype=ms.float32) | |||