|
|
|
@@ -456,8 +456,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; |
|
|
|
TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); |
|
|
|
SetInputLayout(pre_out_tensor_layout); |
|
|
|
SetInputLayout(pre_out_tensor_info.tensor_layout()); |
|
|
|
// infer pre_node output strategy from output_layout. |
|
|
|
Dimensions stra = pre_out_tensor_info.InferStrategy(); |
|
|
|
if (stra.empty()) { |
|
|
|
@@ -481,15 +480,17 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; |
|
|
|
TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); |
|
|
|
SetOutputLayout(next_in_tensor_layout); |
|
|
|
SetOutputLayout(next_in_tensor_info.tensor_layout()); |
|
|
|
if (Init(nullptr) == FAILED) { |
|
|
|
MS_LOG(ERROR) << "Failure:operator reshape init failed"; |
|
|
|
return FAILED; |
|
|
|
MS_LOG(DEBUG) << "Failure:operator reshape init failed"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
SetCostForReshape(reshape_stra); |
|
|
|
} |
|
|
|
} |
|
|
|
if (strategy_cost_.empty()) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} // namespace parallel |
|
|
|
|