@@ -149,45 +149,10 @@ Status TensorRedistribution::ComputeCost() {
     double prod =
       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
     std::string str = op.first;
-    if (str == PERMUTE_BY_AXIS) {
-      // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
-      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
-      backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
-      comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
-      int32_t concat_dim = op.second[2];
-      if (concat_dim == 0) {
-        // memory cost = all_gather
-        computation_cost_ += prod;
-        memory_cost_ += prod;
-      } else {
-        // memory cost = all_gather + split + concat
-        int32_t dev_num = op.second[4];
-        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
-        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
-      }
-    } else if (str == CONCAT_BY_AXIS) {
-      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      // computation cost = before_slice_shape
-      if (op.second.size() < 3) {
-        MS_LOG(ERROR) << "op.second size should not be less than 3!";
-        return Status::FAILED;
-      }
-      double dev_num = op.second[2];
-      // here, communication cost = all_gather + reduce_scatter
-      forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      int32_t concat_dim = op.second[0];
-      if (concat_dim == 0) {
-        // computation cost = all_gather
-        computation_cost_ += prod;
-        memory_cost_ += prod * dev_num;
-      } else {
-        // computation cost = all_gather + split + concat
-        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
-        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
-      }
+    if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
+      return Status::FAILED;
+    } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
+      return Status::FAILED;
     } else {
       // There is only computation cost in SplitByAxis.
       // computation cost = before_slice_shape
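The arithmetic being factored out above is easy to sanity-check with concrete numbers. Below is a minimal standalone sketch of the CONCAT_BY_AXIS terms, assuming an illustrative slice shape of {2, 4} (so prod = 8), dev_num = 4, and a placeholder value standing in for the real ALLGATHER_REDUCESCATTER_SCALE_FACTOR constant, which is defined elsewhere in the source tree:

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // Placeholder scale factor, for illustration only.
  const double kAllGatherReduceScatterScale = 0.5;
  std::vector<double> slice_shape = {2.0, 4.0};
  double prod = std::accumulate(slice_shape.begin(), slice_shape.end(), 1.0, std::multiplies<double>());
  double dev_num = 4.0;
  // Mirrors the accumulation in ComputeConcatCost: the forward pass pays for the
  // all_gather (scaled by dev_num), the backward pass for the reduce_scatter.
  double forward = prod * dev_num * kAllGatherReduceScatterScale;        // 16
  double backward = prod * kAllGatherReduceScatterScale;                 // 4
  double total = prod * (dev_num + 1.0) * kAllGatherReduceScatterScale;  // 20 == forward + backward
  std::printf("forward=%g backward=%g total=%g\n", forward, backward, total);
  return 0;
}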
@@ -204,5 +169,55 @@ Status TensorRedistribution::ComputeCost() {
   }
   return Status::SUCCESS;
 }
+
+Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs) {
+  // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
+  // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
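+  // attrs layout, as used below: attrs[2] is the concat dimension, attrs[4] the device count.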
+  if (attrs.size() < 5) {
+    MS_LOG(ERROR) << "attrs size should not be less than 5!";
+    return Status::FAILED;
+  }
+  forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
+  backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
+  comm_cost_ += 2.0 * input_size * ALLTOALL_SCALE_FACTOR;
+  int32_t concat_dim = attrs[2];
+  if (concat_dim == 0) {
+    // memory cost = all_gather
+    computation_cost_ += input_size;
+    memory_cost_ += input_size;
+  } else {
+    // memory cost = all_gather + split + concat
+    int32_t dev_num = attrs[4];
+    computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
+    memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
+  }
+  return Status::SUCCESS;
+}
+
+Status TensorRedistribution::ComputeConcatCost(double input_size, Shape attrs) {
+  // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+  // computation cost = before_slice_shape
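+  // attrs layout, as used below: attrs[0] is the concat dimension, attrs[2] the device count.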
+  if (attrs.size() < 3) {
MS_LOG(ERROR) << "op.second size should not be less than 3!"; |
|
|
|
+    return Status::FAILED;
+  }
+  double dev_num = attrs[2];
+  // here, communication cost = all_gather + reduce_scatter
+  forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  int32_t concat_dim = attrs[0];
+  if (concat_dim == 0) {
+    // computation cost = all_gather
+    computation_cost_ += input_size;
+    memory_cost_ += input_size * dev_num;
+  } else {
+    // computation cost = all_gather + split + concat
+    computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
+    memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
+  }
+  return Status::SUCCESS;
+}
+
 }  // namespace parallel
 }  // namespace mindspore
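These helpers are member functions of TensorRedistribution, so matching declarations must also land in the class definition; that header change is not shown in this excerpt. A hedged, compilable sketch of what the declarations might look like, where the Status enum and Shape alias are stand-ins for the project's own types and the cost members mirror the fields used in the diff:

// Hypothetical sketch only; the actual header (tensor_redistribution.h or similar) may differ.
#include <cstdint>
#include <vector>

enum class Status { SUCCESS, FAILED };  // stand-in for the project's Status
using Shape = std::vector<int32_t>;     // stand-in for the project's Shape

class TensorRedistribution {
 public:
  Status ComputeCost();

 private:
  // New helpers factored out of ComputeCost() by this change.
  Status ComputePermuteCost(double input_size, Shape attrs);
  Status ComputeConcatCost(double input_size, Shape attrs);

  double forward_comm_cost_ = 0.0;
  double backward_comm_cost_ = 0.0;
  double comm_cost_ = 0.0;
  double computation_cost_ = 0.0;
  double memory_cost_ = 0.0;
};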