Browse Source

fix reshape strategy search bug in auto parallel

tags/v1.1.0
yao_yf 5 years ago
parent
commit
4c1d4924cb
2 changed files with 3 additions and 1 deletions
  1. +1
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
  2. +2
    -1
      mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc

+ 1
- 0
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc View File

@@ -423,6 +423,7 @@ void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &stra
std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
swc->cost_list.push_back(result);
strategy_cost_.emplace_back(swc);
ResetQueueMember();
}

Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {


+ 2
- 1
mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc View File

@@ -223,7 +223,8 @@ Status TensorRedistribution::ComputeCost() {
} else {
prev_shape = from_.tensor_shape().array();
}
double prev_prod = std::accumulate(prev_shape.begin(), prev_shape.end(), 1, std::multiplies<int>());
double prev_prod =
std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
computation_cost_ += 2.0 * prev_prod;
memory_cost_ += 2.0 * prev_prod;
}


Loading…
Cancel
Save