|
|
|
@@ -33,6 +33,29 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
std::string StrategyToString(const Strategys &strategy) { |
|
|
|
std::string strategy_str = ""; |
|
|
|
strategy_str += "("; |
|
|
|
for (size_t i = 0; i < strategy.size(); ++i) { |
|
|
|
strategy_str += "("; |
|
|
|
for (size_t j = 0; j < strategy[i].size(); ++j) { |
|
|
|
strategy_str += std::to_string(strategy[i][j]); |
|
|
|
if (j != strategy[i].size() - 1) { |
|
|
|
strategy_str += ", "; |
|
|
|
} |
|
|
|
} |
|
|
|
strategy_str += ")"; |
|
|
|
if (i != strategy.size() - 1) { |
|
|
|
strategy_str += ", "; |
|
|
|
} |
|
|
|
} |
|
|
|
if (strategy.size() == 1) { |
|
|
|
strategy_str += ","; |
|
|
|
} |
|
|
|
strategy_str += ")"; |
|
|
|
return strategy_str; |
|
|
|
} |
|
|
|
|
|
|
|
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) { |
|
|
|
if (strategy == nullptr) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is null."; |
|
|
|
@@ -41,18 +64,18 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape |
|
|
|
|
|
|
|
size_t strategy_size = strategy->GetInputNumber(); |
|
|
|
size_t inputs_shape_size = inputs_shape.size(); |
|
|
|
Strategys stra = strategy->GetInputDim(); |
|
|
|
if (strategy_size != inputs_shape_size) { |
|
|
|
if (is_auto_parallel_) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": Strategy size: " << strategy_size |
|
|
|
MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size |
|
|
|
<< " is not equal to inputs size: " << inputs_shape_size; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << ": Strategy size: " << strategy_size |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size |
|
|
|
<< " is not equal to inputs size: " << inputs_shape_size; |
|
|
|
} |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
Strategys stra = strategy->GetInputDim(); |
|
|
|
for (size_t i = 0; i < strategy_size; ++i) { |
|
|
|
Shape sub_strategy = stra.at(i); |
|
|
|
Shape sub_input_shape = inputs_shape.at(i); |
|
|
|
@@ -60,11 +83,11 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape |
|
|
|
size_t inputs_len = sub_input_shape.size(); |
|
|
|
if (strategy_len != inputs_len) { |
|
|
|
if (is_auto_parallel_) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len |
|
|
|
<< ", index: " << i; |
|
|
|
MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len |
|
|
|
<< " is not equal to inputs len: " << inputs_len << ", index: " << i; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len |
|
|
|
<< ", index: " << i; |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len |
|
|
|
<< " is not equal to inputs len: " << inputs_len << ", index: " << i; |
|
|
|
} |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -73,18 +96,22 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape |
|
|
|
int64_t strategy_value = sub_strategy.at(j); |
|
|
|
if (strategy_value < MIN_SLICE_NUM) { |
|
|
|
if (is_auto_parallel_) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid strategy value: " << strategy_value; |
|
|
|
MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) |
|
|
|
<< ", the value of strategy must be larger than 0, but get " << strategy_value; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << ": Invalid strategy value: " << strategy_value; |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) |
|
|
|
<< ", the value of strategy must be larger than 0, but get " << strategy_value; |
|
|
|
} |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) { |
|
|
|
if (is_auto_parallel_) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value; |
|
|
|
MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) |
|
|
|
<< ", the value of strategy must be the power of 2, but get " << strategy_value; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value; |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) |
|
|
|
<< ", the value of strategy must be the power of 2, but get " << strategy_value; |
|
|
|
} |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -92,9 +119,11 @@ Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shape |
|
|
|
int64_t shape_value = sub_input_shape.at(j); |
|
|
|
if ((shape_value % strategy_value) != 0) { |
|
|
|
if (is_auto_parallel_) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; |
|
|
|
MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value |
|
|
|
<< " cannot be divisible by strategy value " << strategy_value; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value |
|
|
|
<< " cannot be divisible by strategy value " << strategy_value; |
|
|
|
} |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -184,7 +213,9 @@ Status OperatorInfo::InferRepeatedCalcInfo() { |
|
|
|
} else if (g_dev_list_size % dev_matrix_size == 0) { |
|
|
|
repeated_calc_num_ = ((int64_t)(g_dev_list_size / dev_matrix_size)); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategy_->GetInputDim()) << ", it requires " |
|
|
|
<< dev_matrix_size << " devices, " |
|
|
|
<< "but the device number of this stage is " << g_dev_list_size << ", it can not be divisible by " |
|
|
|
<< dev_matrix_size; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|