Merge pull request !304 from Xiaoda/modify-not-fully-use-devices-and-strategy-checking (tags/v0.2.0-alpha)
| @@ -85,10 +85,10 @@ Status Edge::InitEdgeCost() { | |||||
| } | } | ||||
| } | } | ||||
| if (!has_available_cost) { | if (!has_available_cost) { | ||||
| if (!NOT_FULLY_USE_DEVICES) { | |||||
| if (FULLY_USE_DEVICES) { | |||||
| MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ | MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ | ||||
| << " failed, it may be caused by setting 'not_fully_use_devices' false. Try to set " | |||||
| "'not_fully_use_devices' true."; | |||||
| << " failed, it may be caused by setting 'fully_use_devices' true. Try to set " | |||||
| "'fully_use_devices' false."; | |||||
| } else if (ELEMENTWISE_OP_STRA_FOLLOW) { | } else if (ELEMENTWISE_OP_STRA_FOLLOW) { | ||||
| MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ | MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ | ||||
| << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. " | << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. " | ||||
| @@ -36,7 +36,7 @@ double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST; | |||||
| double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; | double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; | ||||
| bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; | bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; | ||||
| size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | ||||
| bool NOT_FULLY_USE_DEVICES = DEFAULT_NOT_FULLY_USE_DEVICES; | |||||
| bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; | |||||
| bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| void CostGraph::SetDeviceMemoryAndCostParameter() { | void CostGraph::SetDeviceMemoryAndCostParameter() { | ||||
| @@ -125,13 +125,13 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { | |||||
| TENSOR_SLICE_ALIGNMENT_SIZE = align_size; | TENSOR_SLICE_ALIGNMENT_SIZE = align_size; | ||||
| MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; | MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; | ||||
| // NOT_FULLY_USE_DEVICES | |||||
| auto not_fully_devices = CostModelContext::GetInstance()->not_fully_use_device(); | |||||
| NOT_FULLY_USE_DEVICES = not_fully_devices; | |||||
| if (NOT_FULLY_USE_DEVICES) { | |||||
| MS_LOG(INFO) << "not_fully_use_devices: true."; | |||||
| // FULLY_USE_DEVICES | |||||
| auto fully_devices = CostModelContext::GetInstance()->fully_use_device(); | |||||
| FULLY_USE_DEVICES = fully_devices; | |||||
| if (FULLY_USE_DEVICES) { | |||||
| MS_LOG(INFO) << "fully_use_devices: true."; | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "not_fully_use_devices: false."; | |||||
| MS_LOG(INFO) << "fully_use_devices: false."; | |||||
| } | } | ||||
| // ELEMENTWISE_OP_STRA_FOLLOW | // ELEMENTWISE_OP_STRA_FOLLOW | ||||
| @@ -42,7 +42,7 @@ namespace parallel { | |||||
| #define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 | #define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 | ||||
| #define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false | #define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false | ||||
| #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 | #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 | ||||
| #define DEFAULT_NOT_FULLY_USE_DEVICES false | |||||
| #define DEFAULT_FULLY_USE_DEVICES true | |||||
| #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false | ||||
| class CostGraph; | class CostGraph; | ||||
| @@ -57,7 +57,7 @@ extern double COST_MODEL_COMMUNI_CONST; | |||||
| extern double COST_MODEL_COMMUNI_BIAS; | extern double COST_MODEL_COMMUNI_BIAS; | ||||
| extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; | extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; | ||||
| extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; | extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; | ||||
| extern bool NOT_FULLY_USE_DEVICES; | |||||
| extern bool FULLY_USE_DEVICES; | |||||
| extern bool ELEMENTWISE_OP_STRA_FOLLOW; | extern bool ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| class CostGraph { | class CostGraph { | ||||
| @@ -60,7 +60,7 @@ void CostModelContext::ResetAlgoParameters() { | |||||
| costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; | costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; | ||||
| tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; | tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; | ||||
| tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; | ||||
| not_fully_use_device_ = DEFAULT_NOT_FULLY_USE_DEVICES; | |||||
| fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; | |||||
| elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; | ||||
| } | } | ||||
| @@ -118,7 +118,7 @@ void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { | |||||
| tensor_slice_alignment_size_ = ts_align_size; | tensor_slice_alignment_size_ = ts_align_size; | ||||
| } | } | ||||
| void CostModelContext::set_not_fully_use_device(bool not_fully_use) { not_fully_use_device_ = not_fully_use; } | |||||
| void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } | |||||
| void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { | ||||
| elementwise_stra_follow_ = elementwise_follow; | elementwise_stra_follow_ = elementwise_follow; | ||||
| @@ -102,9 +102,9 @@ class CostModelContext { | |||||
| void set_tensor_slice_alignment_size(size_t); | void set_tensor_slice_alignment_size(size_t); | ||||
| size_t tensor_slice_alignment_size() const { return tensor_slice_alignment_size_; } | size_t tensor_slice_alignment_size() const { return tensor_slice_alignment_size_; } | ||||
| // NOT_FULLY_USE_DEVICES | |||||
| void set_not_fully_use_device(bool); | |||||
| bool not_fully_use_device() const { return not_fully_use_device_; } | |||||
| // FULLY_USE_DEVICES | |||||
| void set_fully_use_device(bool); | |||||
| bool fully_use_device() const { return fully_use_device_; } | |||||
| // ELEMENTWISE_OP_STRA_FOLLOW | // ELEMENTWISE_OP_STRA_FOLLOW | ||||
| void set_elementwise_stra_follow(bool); | void set_elementwise_stra_follow(bool); | ||||
| @@ -158,8 +158,8 @@ class CostModelContext { | |||||
| // TENSOR_SLICE_ALIGNMENT_SIZE | // TENSOR_SLICE_ALIGNMENT_SIZE | ||||
| size_t tensor_slice_alignment_size_; | size_t tensor_slice_alignment_size_; | ||||
| // NOT_FULLY_USE_DEVICES | |||||
| bool not_fully_use_device_; | |||||
| // FULLY_USE_DEVICES | |||||
| bool fully_use_device_; | |||||
| // ELEMENTWISE_OP_STRA_FOLLOW | // ELEMENTWISE_OP_STRA_FOLLOW | ||||
| bool elementwise_stra_follow_; | bool elementwise_stra_follow_; | ||||
| @@ -465,7 +465,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, | |||||
| mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, | mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, | ||||
| size_t input1_shape_size, mindspore::parallel::StrategyPtr* const sp) { | size_t input1_shape_size, mindspore::parallel::StrategyPtr* const sp) { | ||||
| int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int>()); | int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int>()); | ||||
| if (NOT_FULLY_USE_DEVICES) { | |||||
| if (!FULLY_USE_DEVICES) { | |||||
| if (IntToSize(product) > dev_num) { | if (IntToSize(product) > dev_num) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -675,7 +675,7 @@ Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes& input | |||||
| for (auto& input_partition : inputs_partitions) { | for (auto& input_partition : inputs_partitions) { | ||||
| product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int>()); | product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int>()); | ||||
| } | } | ||||
| if (NOT_FULLY_USE_DEVICES) { | |||||
| if (!FULLY_USE_DEVICES) { | |||||
| if (IntToSize(product) > dev_num) { | if (IntToSize(product) > dev_num) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -110,8 +110,6 @@ std::vector<std::string> splittable_op_ = {MATMUL, | |||||
| std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, | std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, | ||||
| CAST, POW, EXP, LOG, COS, ACOS, LOGICALNOT}; | CAST, POW, EXP, LOG, COS, ACOS, LOGICALNOT}; | ||||
| std::vector<std::string> ignore_manual_strategy_op_ = {BATCH_NORM}; | |||||
| bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | ||||
| MS_EXCEPTION_IF_NULL(root); | MS_EXCEPTION_IF_NULL(root); | ||||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | ||||
| @@ -308,16 +306,6 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) { | |||||
| return outputs_type; | return outputs_type; | ||||
| } | } | ||||
| // Be careful the argument is cnode_full_name, not the op_name | |||||
| bool IsIgnoreStrategyOperator(const std::string &cnode_full_name) { | |||||
| for (auto &ignore_op : ignore_manual_strategy_op_) { | |||||
| if (cnode_full_name.find(ignore_op) != std::string::npos) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool IsElementWiseOperator(const std::string &op_name) { | bool IsElementWiseOperator(const std::string &op_name) { | ||||
| auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name); | auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name); | ||||
| return (iter != elementwise_op_.end()); | return (iter != elementwise_op_.end()); | ||||
| @@ -414,18 +402,20 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||||
| // Set cost for this configured strategy | // Set cost for this configured strategy | ||||
| if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { | if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { | ||||
| MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; | MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; | ||||
| } else if (!NOT_FULLY_USE_DEVICES) { | |||||
| if (!IsIgnoreStrategyOperator(cnode->fullname_with_scope())) { | |||||
| // If configured to fully use devices, then checking for the user-specified strategy | |||||
| int32_t used_devices = operator_info->used_devices(); | |||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||||
| auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||||
| // 'used_devices == -1' means that 'used_devices_' is not set | |||||
| if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) { | |||||
| MS_LOG(EXCEPTION) << "In configuration 'NOT_FULLY_USE_DEVICES' = False, " | |||||
| << "but the specified strategy uses device: " << used_devices | |||||
| << ", total devices: " << total_device_num; | |||||
| } | |||||
| } else if (FULLY_USE_DEVICES) { | |||||
| // If configured to fully use devices, then checking for the user-specified strategy | |||||
| int32_t used_devices = operator_info->used_devices(); | |||||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||||
| auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||||
| // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel | |||||
| if (used_devices == 1) { | |||||
| return operator_info; | |||||
| } | |||||
| // 'used_devices == -1' means that 'used_devices_' is not set | |||||
| if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) { | |||||
| MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " | |||||
| << "but the specified strategy uses device: " << used_devices | |||||
| << ", total devices: " << total_device_num; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -261,10 +261,10 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| "Set the parameter tensor_slice_size in strategy generation.") | "Set the parameter tensor_slice_size in strategy generation.") | ||||
| .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size, | .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size, | ||||
| "Get the parameter tensor_slice_size in strategy generation.") | "Get the parameter tensor_slice_size in strategy generation.") | ||||
| .def("set_not_fully_use_devices", &CostModelContext::set_not_fully_use_device, | |||||
| "Set the parameter not_fully_use_devices in the DP algorithm.") | |||||
| .def("get_not_fully_use_devices", &CostModelContext::not_fully_use_device, | |||||
| "Get the parameter not_fully_use_devices in the DP algorithm.") | |||||
| .def("set_fully_use_devices", &CostModelContext::set_fully_use_device, | |||||
| "Set the parameter fully_use_devices in the DP algorithm.") | |||||
| .def("get_fully_use_devices", &CostModelContext::fully_use_device, | |||||
| "Get the parameter fully_use_devices in the DP algorithm.") | |||||
| .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow, | .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow, | ||||
| "Set the parameter elementwise_op_strategy_follow in the DP algorithm.") | "Set the parameter elementwise_op_strategy_follow in the DP algorithm.") | ||||
| .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow, | .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow, | ||||
| @@ -53,13 +53,13 @@ class _AlgoParameterConfig(): | |||||
| self.check_config_handle() | self.check_config_handle() | ||||
| return self._config_handle.get_simplify_cal() | return self._config_handle.get_simplify_cal() | ||||
| def set_not_fully_use_devices(self, not_fully): | |||||
| def set_fully_use_devices(self, not_fully): | |||||
| self.check_config_handle() | self.check_config_handle() | ||||
| self._config_handle.set_not_fully_use_devices(not_fully) | |||||
| self._config_handle.set_fully_use_devices(not_fully) | |||||
| def get_not_fully_use_devices(self): | |||||
| def get_fully_use_devices(self): | |||||
| self.check_config_handle() | self.check_config_handle() | ||||
| return self._config_handle.get_not_fully_use_devices() | |||||
| return self._config_handle.get_fully_use_devices() | |||||
| def set_elementwise_op_strategy_follow(self, element_strategy_follow): | def set_elementwise_op_strategy_follow(self, element_strategy_follow): | ||||
| self.check_config_handle() | self.check_config_handle() | ||||
| @@ -119,7 +119,7 @@ def _algo_parameter_config(): | |||||
| set_algo_parameters_config_func_map = { | set_algo_parameters_config_func_map = { | ||||
| "simplify_cal": _algo_parameter_config().set_simplify_cal, | "simplify_cal": _algo_parameter_config().set_simplify_cal, | ||||
| "not_fully_use_devices": _algo_parameter_config().set_not_fully_use_devices, | |||||
| "fully_use_devices": _algo_parameter_config().set_fully_use_devices, | |||||
| "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow, | "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow, | ||||
| "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable, | "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable, | ||||
| "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size} | "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size} | ||||
| @@ -127,14 +127,14 @@ set_algo_parameters_config_func_map = { | |||||
| get_algo_parameters_config_func_map = { | get_algo_parameters_config_func_map = { | ||||
| "simplify_cal": _algo_parameter_config().get_simplify_cal, | "simplify_cal": _algo_parameter_config().get_simplify_cal, | ||||
| "not_fully_use_devices": _algo_parameter_config().get_not_fully_use_devices, | |||||
| "fully_use_devices": _algo_parameter_config().get_fully_use_devices, | |||||
| "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow, | "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow, | ||||
| "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable, | "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable, | ||||
| "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size} | "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size} | ||||
| @args_type_check(simplify_cal=bool, tensor_slice_align_enable=bool, tensor_slice_align_size=int, | @args_type_check(simplify_cal=bool, tensor_slice_align_enable=bool, tensor_slice_align_size=int, | ||||
| not_fully_use_devices=bool, elementwise_op_strategy_follow=bool) | |||||
| fully_use_devices=bool, elementwise_op_strategy_follow=bool) | |||||
| def set_algo_parameters(**kwargs): | def set_algo_parameters(**kwargs): | ||||
| """ | """ | ||||
| Set algo parameter config. | Set algo parameter config. | ||||
| @@ -146,7 +146,7 @@ def set_algo_parameters(**kwargs): | |||||
| simplify_cal (bool): Whether simplifying calculations in strategy-searching algorithm. Default: True | simplify_cal (bool): Whether simplifying calculations in strategy-searching algorithm. Default: True | ||||
| tensor_slice_align_enable (bool): Whether checking tensor slice shape. Default: False | tensor_slice_align_enable (bool): Whether checking tensor slice shape. Default: False | ||||
| tensor_slice_align_size (int): The minimum tensor slice shape, the value must be in [1, 1024]. Default: 16 | tensor_slice_align_size (int): The minimum tensor slice shape, the value must be in [1, 1024]. Default: 16 | ||||
| not_fully_use_devices (bool): Whether generating strategies that not fully use devices. Default: False | |||||
| fully_use_devices (bool): Whether generating strategies that fully use all available devices. Default: True | |||||
| elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its | elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its | ||||
| subsequent operators. Default: False | subsequent operators. Default: False | ||||
| @@ -100,7 +100,7 @@ def test_two_matmul(): | |||||
| set_algo_parameters(simplify_cal=True, | set_algo_parameters(simplify_cal=True, | ||||
| tensor_slice_align_enable=False, | tensor_slice_align_enable=False, | ||||
| tensor_slice_align_size=32, | tensor_slice_align_size=32, | ||||
| not_fully_use_devices=True, | |||||
| fully_use_devices=False, | |||||
| elementwise_op_strategy_follow=False) | elementwise_op_strategy_follow=False) | ||||
| para_simplify_cal = get_algo_parameters("simplify_cal") | para_simplify_cal = get_algo_parameters("simplify_cal") | ||||
| assert para_simplify_cal == True | assert para_simplify_cal == True | ||||
| @@ -108,8 +108,8 @@ def test_two_matmul(): | |||||
| assert para_slice_align_enable == False | assert para_slice_align_enable == False | ||||
| para_slice_align_size = get_algo_parameters("tensor_slice_align_size") | para_slice_align_size = get_algo_parameters("tensor_slice_align_size") | ||||
| assert para_slice_align_size == 32 | assert para_slice_align_size == 32 | ||||
| not_fully_use_devices = get_algo_parameters("not_fully_use_devices") | |||||
| assert not_fully_use_devices == True | |||||
| fully_use_devices = get_algo_parameters("fully_use_devices") | |||||
| assert fully_use_devices == False | |||||
| elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow") | elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow") | ||||
| assert elementwise_op_strategy_follow == False | assert elementwise_op_strategy_follow == False | ||||
| @@ -120,8 +120,8 @@ def test_two_matmul(): | |||||
| assert para_slice_align_enable == False | assert para_slice_align_enable == False | ||||
| para_slice_align_size = get_algo_parameters("tensor_slice_align_size") | para_slice_align_size = get_algo_parameters("tensor_slice_align_size") | ||||
| assert para_slice_align_size == 16 | assert para_slice_align_size == 16 | ||||
| not_fully_use_devices = get_algo_parameters("not_fully_use_devices") | |||||
| assert not_fully_use_devices == False | |||||
| fully_use_devices = get_algo_parameters("fully_use_devices") | |||||
| assert fully_use_devices == True | |||||
| elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow") | elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow") | ||||
| assert elementwise_op_strategy_follow == False | assert elementwise_op_strategy_follow == False | ||||
| @@ -576,7 +576,7 @@ def test_flatten_reshape2(parallel_mode="auto_parallel"): | |||||
| epoch_size = 2 | epoch_size = 2 | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | ||||
| set_algo_parameters(not_fully_use_devices=True) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3), strategy=((4, 1, 1, 1),)) | net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3), strategy=((4, 1, 1, 1),)) | ||||
| loss = CrossEntropyLoss() | loss = CrossEntropyLoss() | ||||
| predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32) | predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32) | ||||
| @@ -617,7 +617,7 @@ def test_flatten_reshape3(parallel_mode="auto_parallel"): | |||||
| epoch_size = 2 | epoch_size = 2 | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | ||||
| set_algo_parameters(not_fully_use_devices=True) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),)) | net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),)) | ||||
| loss = CrossEntropyLoss() | loss = CrossEntropyLoss() | ||||
| predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32) | predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32) | ||||
| @@ -646,7 +646,7 @@ def test_flatten_reshape4(parallel_mode="semi_auto_parallel"): | |||||
| epoch_size = 2 | epoch_size = 2 | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8) | ||||
| set_algo_parameters(not_fully_use_devices=True) | |||||
| set_algo_parameters(fully_use_devices=False) | |||||
| net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True, strategy=((4, 1, 1, 1),)) | net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True, strategy=((4, 1, 1, 1),)) | ||||
| loss = CrossEntropyLoss2() | loss = CrossEntropyLoss2() | ||||
| predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32) | predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32) | ||||