|
|
|
@@ -29,9 +29,14 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
// The first dimension of input can not be split. |
|
|
|
// The indices can not be split. |
|
|
|
// The strategy of input and the strategy of updates must be equal. |
|
|
|
// The first dimension of input or updates can not be split. |
|
|
|
// The first n dimensions (where n is the number of dimensions of indices) of updates cannot be split. |
|
|
|
// The shape of input: [A, B, ..., M], the strategy of input: (1, b, ..., m) |
|
|
|
// The shape of indices: [N, O, ..., Z], the strategy of indices: (1, 1, ..., 1) |
|
|
|
// The shape of updates: [N, O, ..., Z, B, ..., M], the strategy of updates: (1, 1, ..., 1, b, ..., m) |
|
|
|
// The shape of output: [A, B, ..., M], the strategy of output: (1, b, ..., m) |
|
|
|
// The dev matrix: (1, b, ..., m) |
|
|
|
Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
MS_EXCEPTION_IF_NULL(strategy); |
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { |
|
|
|
@@ -45,11 +50,6 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (stra[0] != stra[2]) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy[0] and strategy[2] must be equal"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (stra[0].empty()) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy[0] is empty"; |
|
|
|
return FAILED; |
|
|
|
@@ -65,6 +65,16 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (stra[2].empty()) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The strategy[2] is empty"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (std::accumulate(stra[2].begin(), stra[2].begin() + stra[1].size(), 1, std::multiplies<int64_t>()) != 1) { |
|
|
|
MS_LOG(ERROR) << name_ << ": The indices can not be split"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -81,22 +91,28 @@ Status ScatterUpdateInfo::InferDevMatrixShape() { |
|
|
|
} |
|
|
|
|
|
|
|
Status ScatterUpdateInfo::InferTensorMap() { |
|
|
|
TensorMap input_tensor_map; |
|
|
|
TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE); |
|
|
|
if (inputs_shape_.size() != 3) { |
|
|
|
MS_LOG(ERROR) << name_ << "The size of inputs shape must be 3"; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
TensorMap input_tensor_map, updates_tensor_map; |
|
|
|
TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE); |
|
|
|
|
|
|
|
// Cannot use dev_matrix_shape_ in place of inputs_shape_[0], because it may not be fully split across all devices. |
|
|
|
int64_t size = SizeToLong(inputs_shape_[0].size()); |
|
|
|
for (int64_t i = 0; i < size; ++i) { |
|
|
|
input_tensor_map.push_back(size - i - 1); |
|
|
|
} |
|
|
|
|
|
|
|
// updates_tensor_map = indices_tensor_map + input_tensor_map[1:] |
|
|
|
updates_tensor_map = indices_tensor_map; |
|
|
|
for (size_t i = 1; i < input_tensor_map.size(); ++i) { |
|
|
|
updates_tensor_map.push_back(input_tensor_map[i]); |
|
|
|
} |
|
|
|
inputs_tensor_map_.push_back(input_tensor_map); // input |
|
|
|
inputs_tensor_map_.push_back(indices_tensor_map); // indices |
|
|
|
inputs_tensor_map_.push_back(input_tensor_map); // updates |
|
|
|
inputs_tensor_map_.push_back(updates_tensor_map); // updates |
|
|
|
|
|
|
|
outputs_tensor_map_.push_back(input_tensor_map); |
|
|
|
return SUCCESS; |
|
|
|
@@ -169,9 +185,15 @@ Status ScatterUpdateInfo::GenerateStrategies(int64_t stage_id) { |
|
|
|
Strategys tmp_strategy; |
|
|
|
Dimensions first_input_strategy = sp->GetInputDim()[0]; |
|
|
|
Dimensions indices_strategy(inputs_shape_[1].size(), 1); |
|
|
|
// updates_strategy = indices_strategy + input_strategy[1:] |
|
|
|
Dimensions updates_strategy = indices_strategy; |
|
|
|
for (size_t i = 1; i < first_input_strategy.size(); ++i) { |
|
|
|
updates_strategy.push_back(first_input_strategy[i]); |
|
|
|
} |
|
|
|
|
|
|
|
tmp_strategy.push_back(first_input_strategy); // input |
|
|
|
tmp_strategy.push_back(indices_strategy); // indices |
|
|
|
tmp_strategy.push_back(first_input_strategy); // updates |
|
|
|
tmp_strategy.push_back(updates_strategy); // updates |
|
|
|
|
|
|
|
sp->ResetInputs(tmp_strategy); |
|
|
|
} |
|
|
|
|