
!304 [Auto parallel] Change 'NOT_FULLY_USE_DEVICES' to 'FULLY_USE_DEVICES' and make ALL-1 user-specified-strategy valid in auto-parallel

Merge pull request !304 from Xiaoda/modify-not-fully-use-devices-and-strategy-checking
tags/v0.2.0-alpha
mindspore-ci-bot, 5 years ago
commit ba55a8ed0b
12 changed files with 55 additions and 65 deletions
 1. mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc (+3, -3)
 2. mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc (+7, -7)
 3. mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h (+2, -2)
 4. mindspore/ccsrc/parallel/costmodel_context.cc (+2, -2)
 5. mindspore/ccsrc/parallel/costmodel_context.h (+5, -5)
 6. mindspore/ccsrc/parallel/ops_info/matmul_info.cc (+1, -1)
 7. mindspore/ccsrc/parallel/ops_info/operator_info.cc (+1, -1)
 8. mindspore/ccsrc/parallel/step_auto_parallel.cc (+14, -24)
 9. mindspore/ccsrc/pipeline/init.cc (+4, -4)
10. mindspore/parallel/algo_parameter_config.py (+8, -8)
11. tests/ut/python/parallel/test_auto_parallel_two_matmul.py (+5, -5)
12. tests/ut/python/parallel/test_reshape.py (+3, -3)
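
Net effect: the boolean option is renamed and its polarity inverted. The old default not_fully_use_devices=False becomes fully_use_devices=True, so the default behavior (user-specified strategies must occupy all devices) is unchanged. The one behavioral addition is that a strategy splitting nothing (the ALL-1 strategy, used_devices == 1) is now accepted even when fully_use_devices is true. A minimal before/after of the Python-side call (assuming set_algo_parameters is imported as in the test files below):

    # before this commit
    set_algo_parameters(not_fully_use_devices=True)
    # after this commit, the equivalent configuration is
    set_algo_parameters(fully_use_devices=False)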

mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc (+3, -3)

@@ -85,10 +85,10 @@ Status Edge::InitEdgeCost() {
     }
   }
   if (!has_available_cost) {
-    if (!NOT_FULLY_USE_DEVICES) {
+    if (FULLY_USE_DEVICES) {
       MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
-                        << " failed, it may be caused by setting 'not_fully_use_devices' false. Try to set "
-                           "'not_fully_use_devices' true.";
+                        << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
+                           "'fully_use_devices' false.";
     } else if (ELEMENTWISE_OP_STRA_FOLLOW) {
       MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
                         << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "


mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc (+7, -7)

@@ -36,7 +36,7 @@ double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
 double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
 bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
 size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
-bool NOT_FULLY_USE_DEVICES = DEFAULT_NOT_FULLY_USE_DEVICES;
+bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
 bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;

 void CostGraph::SetDeviceMemoryAndCostParameter() {

@@ -125,13 +125,13 @@ void CostGraph::SetDeviceMemoryAndCostParameter() {
   TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
   MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";

-  // NOT_FULLY_USE_DEVICES
-  auto not_fully_devices = CostModelContext::GetInstance()->not_fully_use_device();
-  NOT_FULLY_USE_DEVICES = not_fully_devices;
-  if (NOT_FULLY_USE_DEVICES) {
-    MS_LOG(INFO) << "not_fully_use_devices: true.";
+  // FULLY_USE_DEVICES
+  auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
+  FULLY_USE_DEVICES = fully_devices;
+  if (FULLY_USE_DEVICES) {
+    MS_LOG(INFO) << "fully_use_devices: true.";
   } else {
-    MS_LOG(INFO) << "not_fully_use_devices: false.";
+    MS_LOG(INFO) << "fully_use_devices: false.";
   }

   // ELEMENTWISE_OP_STRA_FOLLOW


mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h (+2, -2)

@@ -42,7 +42,7 @@ namespace parallel {
 #define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0
 #define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false
 #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
-#define DEFAULT_NOT_FULLY_USE_DEVICES false
+#define DEFAULT_FULLY_USE_DEVICES true
 #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false

 class CostGraph;

@@ -57,7 +57,7 @@ extern double COST_MODEL_COMMUNI_CONST;
 extern double COST_MODEL_COMMUNI_BIAS;
 extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
 extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
-extern bool NOT_FULLY_USE_DEVICES;
+extern bool FULLY_USE_DEVICES;
 extern bool ELEMENTWISE_OP_STRA_FOLLOW;

 class CostGraph {


mindspore/ccsrc/parallel/costmodel_context.cc (+2, -2)

@@ -60,7 +60,7 @@ void CostModelContext::ResetAlgoParameters() {
   costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
   tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
   tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
-  not_fully_use_device_ = DEFAULT_NOT_FULLY_USE_DEVICES;
+  fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
   elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
 }

@@ -118,7 +118,7 @@ void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) {
   tensor_slice_alignment_size_ = ts_align_size;
 }

-void CostModelContext::set_not_fully_use_device(bool not_fully_use) { not_fully_use_device_ = not_fully_use; }
+void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; }

 void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
   elementwise_stra_follow_ = elementwise_follow;


mindspore/ccsrc/parallel/costmodel_context.h (+5, -5)

@@ -102,9 +102,9 @@ class CostModelContext {
   void set_tensor_slice_alignment_size(size_t);
   size_t tensor_slice_alignment_size() const { return tensor_slice_alignment_size_; }

-  // NOT_FULLY_USE_DEVICES
-  void set_not_fully_use_device(bool);
-  bool not_fully_use_device() const { return not_fully_use_device_; }
+  // FULLY_USE_DEVICES
+  void set_fully_use_device(bool);
+  bool fully_use_device() const { return fully_use_device_; }

   // ELEMENTWISE_OP_STRA_FOLLOW
   void set_elementwise_stra_follow(bool);

@@ -158,8 +158,8 @@ class CostModelContext {
   // TENSOR_SLICE_ALIGNMENT_SIZE
   size_t tensor_slice_alignment_size_;

-  // NOT_FULLY_USE_DEVICES
-  bool not_fully_use_device_;
+  // FULLY_USE_DEVICES
+  bool fully_use_device_;

   // ELEMENTWISE_OP_STRA_FOLLOW
   bool elementwise_stra_follow_;


mindspore/ccsrc/parallel/ops_info/matmul_info.cc (+1, -1)

@@ -465,7 +465,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num,
                                    mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
                                    size_t input1_shape_size, mindspore::parallel::StrategyPtr* const sp) {
   int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int>());
-  if (NOT_FULLY_USE_DEVICES) {
+  if (!FULLY_USE_DEVICES) {
     if (IntToSize(product) > dev_num) {
       return FAILED;
     }


mindspore/ccsrc/parallel/ops_info/operator_info.cc (+1, -1)

@@ -675,7 +675,7 @@ Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes& input
   for (auto& input_partition : inputs_partitions) {
     product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int>());
   }
-  if (NOT_FULLY_USE_DEVICES) {
+  if (!FULLY_USE_DEVICES) {
     if (IntToSize(product) > dev_num) {
       return FAILED;
     }
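
The two hunks above gate how many devices a candidate strategy may request: the product of all partition counts is the number of shards the strategy needs. A rough Python sketch of that check (the helper name and the fully-use else-branch, which these hunks do not show, are assumptions):

    from functools import reduce
    from operator import mul

    def partition_product_ok(partitions, dev_num, fully_use_devices):
        # product of partition counts = devices the strategy would occupy
        product = reduce(mul, partitions, 1)
        if not fully_use_devices:
            # branch shown in the diff: accept anything that fits on the cluster
            return product <= dev_num
        # branch not shown here; presumably the product must match dev_num
        return product == dev_num

    # On 8 devices, partitions (2, 2) needs 4 shards: accepted only when
    # fully_use_devices is False.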


mindspore/ccsrc/parallel/step_auto_parallel.cc (+14, -24)

@@ -110,8 +110,6 @@ std::vector<std::string> splittable_op_ = {MATMUL,
 std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT,
                                             CAST, POW, EXP, LOG, COS, ACOS, LOGICALNOT};

-std::vector<std::string> ignore_manual_strategy_op_ = {BATCH_NORM};
-
 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
@@ -308,16 +306,6 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
   return outputs_type;
 }

-// Be careful the argument is cnode_full_name, not the op_name
-bool IsIgnoreStrategyOperator(const std::string &cnode_full_name) {
-  for (auto &ignore_op : ignore_manual_strategy_op_) {
-    if (cnode_full_name.find(ignore_op) != std::string::npos) {
-      return true;
-    }
-  }
-  return false;
-}
-
 bool IsElementWiseOperator(const std::string &op_name) {
   auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name);
   return (iter != elementwise_op_.end());
@@ -414,18 +402,20 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   // Set cost for this configured strategy
   if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
     MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
-  } else if (!NOT_FULLY_USE_DEVICES) {
-    if (!IsIgnoreStrategyOperator(cnode->fullname_with_scope())) {
-      // If configured to fully use devices, then checking for the user-specified strategy
-      int32_t used_devices = operator_info->used_devices();
-      MS_EXCEPTION_IF_NULL(g_device_manager);
-      auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
-      // 'used_devices == -1' means that 'used_devices_' is not set
-      if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
-        MS_LOG(EXCEPTION) << "In configuration 'NOT_FULLY_USE_DEVICES' = False, "
-                          << "but the specified strategy uses device: " << used_devices
-                          << ", total devices: " << total_device_num;
-      }
+  } else if (FULLY_USE_DEVICES) {
+    // If configured to fully use devices, then checking for the user-specified strategy
+    int32_t used_devices = operator_info->used_devices();
+    MS_EXCEPTION_IF_NULL(g_device_manager);
+    auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
+    // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel
+    if (used_devices == 1) {
+      return operator_info;
+    }
+    // 'used_devices == -1' means that 'used_devices_' is not set
+    if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
+      MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
+                        << "but the specified strategy uses device: " << used_devices
+                        << ", total devices: " << total_device_num;
     }
   }
 }
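
With this rewrite, a user-specified strategy passes the FULLY_USE_DEVICES check in two cases: it occupies every device, or it occupies exactly one (the ALL-1 strategy, in which no dimension is split). A hedged Python paraphrase of the new validation (the function name is hypothetical; the message mirrors the diff):

    def check_user_strategy(used_devices, total_device_num, fully_use_devices=True):
        if not fully_use_devices:
            return  # no full-utilization constraint to enforce
        if used_devices == 1:
            return  # ALL-1 strategy: valid as of this commit
        # used_devices == -1 means 'used_devices_' was never set
        if used_devices == -1 or used_devices != total_device_num:
            raise RuntimeError(
                "In configuration 'FULLY_USE_DEVICES' = True, but the specified "
                "strategy uses device: %d, total devices: %d"
                % (used_devices, total_device_num))

    # e.g. with 8 devices, a MatMul strategy ((1, 1), (1, 1)) yields
    # used_devices == 1 and now passes; ((2, 1), (1, 2)) yields 4 and raises.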


mindspore/ccsrc/pipeline/init.cc (+4, -4)

@@ -261,10 +261,10 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set the parameter tensor_slice_size in strategy generation.")
     .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size,
          "Get the parameter tensor_slice_size in strategy generation.")
-    .def("set_not_fully_use_devices", &CostModelContext::set_not_fully_use_device,
-         "Set the parameter not_fully_use_devices in the DP algorithm.")
-    .def("get_not_fully_use_devices", &CostModelContext::not_fully_use_device,
-         "Get the parameter not_fully_use_devices in the DP algorithm.")
+    .def("set_fully_use_devices", &CostModelContext::set_fully_use_device,
+         "Set the parameter fully_use_devices in the DP algorithm.")
+    .def("get_fully_use_devices", &CostModelContext::fully_use_device,
+         "Get the parameter fully_use_devices in the DP algorithm.")
     .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow,
          "Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
     .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,


mindspore/parallel/algo_parameter_config.py (+8, -8)

@@ -53,13 +53,13 @@ class _AlgoParameterConfig():
         self.check_config_handle()
         return self._config_handle.get_simplify_cal()

-    def set_not_fully_use_devices(self, not_fully):
+    def set_fully_use_devices(self, not_fully):
         self.check_config_handle()
-        self._config_handle.set_not_fully_use_devices(not_fully)
+        self._config_handle.set_fully_use_devices(not_fully)

-    def get_not_fully_use_devices(self):
+    def get_fully_use_devices(self):
         self.check_config_handle()
-        return self._config_handle.get_not_fully_use_devices()
+        return self._config_handle.get_fully_use_devices()

     def set_elementwise_op_strategy_follow(self, element_strategy_follow):
         self.check_config_handle()

@@ -119,7 +119,7 @@ def _algo_parameter_config():

 set_algo_parameters_config_func_map = {
     "simplify_cal": _algo_parameter_config().set_simplify_cal,
-    "not_fully_use_devices": _algo_parameter_config().set_not_fully_use_devices,
+    "fully_use_devices": _algo_parameter_config().set_fully_use_devices,
     "elementwise_op_strategy_follow": _algo_parameter_config().set_elementwise_op_strategy_follow,
     "tensor_slice_align_enable": _algo_parameter_config().set_tensor_slice_align_enable,
     "tensor_slice_align_size": _algo_parameter_config().set_tensor_slice_align_size}

@@ -127,14 +127,14 @@ set_algo_parameters_config_func_map = {

 get_algo_parameters_config_func_map = {
     "simplify_cal": _algo_parameter_config().get_simplify_cal,
-    "not_fully_use_devices": _algo_parameter_config().get_not_fully_use_devices,
+    "fully_use_devices": _algo_parameter_config().get_fully_use_devices,
     "elementwise_op_strategy_follow": _algo_parameter_config().get_elementwise_op_strategy_follow,
     "tensor_slice_align_enable": _algo_parameter_config().get_tensor_slice_align_enable,
     "tensor_slice_align_size": _algo_parameter_config().get_tensor_slice_align_size}


 @args_type_check(simplify_cal=bool, tensor_slice_align_enable=bool, tensor_slice_align_size=int,
-                 not_fully_use_devices=bool, elementwise_op_strategy_follow=bool)
+                 fully_use_devices=bool, elementwise_op_strategy_follow=bool)
 def set_algo_parameters(**kwargs):
     """
     Set algo parameter config.

@@ -146,7 +146,7 @@ def set_algo_parameters(**kwargs):
         simplify_cal (bool): Whether simplifying calculations in strategy-searching algorithm. Default: True
         tensor_slice_align_enable (bool): Whether checking tensor slice shape. Default: False
         tensor_slice_align_size (int): The minimum tensor slice shape, the value must be in [1, 1024]. Default: 16
-        not_fully_use_devices (bool): Whether generating strategies that not fully use devices. Default: False
+        fully_use_devices (bool): Whether generating strategies that fully use all available devices. Default: True
         elementwise_op_strategy_follow (bool): Whether the elementwise operator have the same strategies as its
             subsequent operators. Default: False
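
End to end, the renamed parameter is driven through set_algo_parameters/get_algo_parameters exactly as before. A short usage sketch (the import path is an assumption; the tests below use the same two functions):

    from mindspore.parallel.algo_parameter_config import set_algo_parameters, get_algo_parameters

    # default after this commit: strategies must fully use all devices
    assert get_algo_parameters("fully_use_devices") == True
    # opt out, equivalent to the old not_fully_use_devices=True
    set_algo_parameters(fully_use_devices=False)
    assert get_algo_parameters("fully_use_devices") == False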




tests/ut/python/parallel/test_auto_parallel_two_matmul.py (+5, -5)

@@ -100,7 +100,7 @@ def test_two_matmul():
     set_algo_parameters(simplify_cal=True,
                         tensor_slice_align_enable=False,
                         tensor_slice_align_size=32,
-                        not_fully_use_devices=True,
+                        fully_use_devices=False,
                         elementwise_op_strategy_follow=False)
     para_simplify_cal = get_algo_parameters("simplify_cal")
     assert para_simplify_cal == True

@@ -108,8 +108,8 @@ def test_two_matmul():
     assert para_slice_align_enable == False
     para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
     assert para_slice_align_size == 32
-    not_fully_use_devices = get_algo_parameters("not_fully_use_devices")
-    assert not_fully_use_devices == True
+    fully_use_devices = get_algo_parameters("fully_use_devices")
+    assert fully_use_devices == False
     elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow")
     assert elementwise_op_strategy_follow == False

@@ -120,8 +120,8 @@ def test_two_matmul():
     assert para_slice_align_enable == False
     para_slice_align_size = get_algo_parameters("tensor_slice_align_size")
     assert para_slice_align_size == 16
-    not_fully_use_devices = get_algo_parameters("not_fully_use_devices")
-    assert not_fully_use_devices == False
+    fully_use_devices = get_algo_parameters("fully_use_devices")
+    assert fully_use_devices == True
     elementwise_op_strategy_follow = get_algo_parameters("elementwise_op_strategy_follow")
     assert elementwise_op_strategy_follow == False




tests/ut/python/parallel/test_reshape.py (+3, -3)

@@ -576,7 +576,7 @@ def test_flatten_reshape2(parallel_mode="auto_parallel"):
     epoch_size = 2
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
-    set_algo_parameters(not_fully_use_devices=True)
+    set_algo_parameters(fully_use_devices=False)
     net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3), strategy=((4, 1, 1, 1),))
     loss = CrossEntropyLoss()
     predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)

@@ -617,7 +617,7 @@ def test_flatten_reshape3(parallel_mode="auto_parallel"):
     epoch_size = 2
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
-    set_algo_parameters(not_fully_use_devices=True)
+    set_algo_parameters(fully_use_devices=False)
     net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),))
     loss = CrossEntropyLoss()
     predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)

@@ -646,7 +646,7 @@ def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
     epoch_size = 2
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
-    set_algo_parameters(not_fully_use_devices=True)
+    set_algo_parameters(fully_use_devices=False)
     net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True, strategy=((4, 1, 1, 1),))
     loss = CrossEntropyLoss2()
     predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)

