|
|
|
@@ -191,16 +191,28 @@ void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap |
|
|
|
run_op_graphs_[graph_info] = kernel_graph; |
|
|
|
} |
|
|
|
|
|
|
|
// Walks base_ref element by element. An element that is itself a VectorRef is
// handled by a recursive call; an element that is a tensor has SetNeedWait(false)
// and data_sync(false) called on it. Other element kinds are left untouched.
// NOTE(review): this span shows two revisions of the function interleaved -- one
// signature with an outputs_tensors out-parameter and one without, plus matching
// recursive calls. Confirm which revision is current before relying on the
// outputs_tensors lines.
void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors) { |
|
|
|
void CPUSession::SetOutputFlags(const VectorRef &base_ref) { |
|
|
|
for (size_t i = 0; i < base_ref.size(); ++i) { |
|
|
|
if (utils::isa<VectorRef>(base_ref[i])) { |
|
|
|
// Nested container: descend into it.
auto ref_iter = utils::cast<VectorRef>(base_ref[i]); |
|
|
|
SetOutputFlags(ref_iter, outputs_tensors); |
|
|
|
SetOutputFlags(ref_iter); |
|
|
|
} else if (utils::isa<tensor::TensorPtr>(base_ref[i])) { |
|
|
|
// Leaf tensor: update its flags in place.
auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref[i]); |
|
|
|
tensor_ptr->SetNeedWait(false); |
|
|
|
tensor_ptr->data_sync(false); |
|
|
|
// Only meaningful for the signature that takes outputs_tensors: records the tensor for the caller.
outputs_tensors->push_back(tensor_ptr); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Refreshes the shape of each tensor in tensor_to_node whose associated kernel
// node is reported as dynamic-shape by AnfAlgo::IsDynamicShape.
//
// tensor_to_node maps a tensor to the (kernel node, output index) pair it is
// paired with. For every qualifying entry, the shape returned by
// AnfAlgo::GetOutputInferShape for that node/index is converted to int64_t
// elements and written to the tensor via set_shape. Entries whose node is not
// dynamic-shape are left unchanged.
void CPUSession::UpdateDynamicOutputShape(const std::map<tensor::TensorPtr, KernelWithIndex> &tensor_to_node) {
  for (auto iter = tensor_to_node.begin(); iter != tensor_to_node.end(); ++iter) {
    const auto &node_with_index = iter->second;
    if (!AnfAlgo::IsDynamicShape(node_with_index.first)) {
      continue;
    }
    const auto &infer_shape = AnfAlgo::GetOutputInferShape(node_with_index.first, node_with_index.second);
    // Element-wise conversion of the returned shape into the int64_t vector passed to set_shape.
    std::vector<int64_t> refresh_shape(infer_shape.begin(), infer_shape.end());
    iter->first->set_shape(refresh_shape);
  }
}
|
|
|
@@ -236,9 +248,12 @@ void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(EXCEPTION) << "Run Op failed"; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<tensor::TensorPtr> output_tensors; |
|
|
|
SetOutputFlags(*outputs, &output_tensors); |
|
|
|
UpdateDynamicOutputShape(tensor_to_node); |
|
|
|
// update output abstract of dynamic op to op_run_info |
|
|
|
if (op_run_info->is_dynamic_shape) { |
|
|
|
UpdateOutputAbstract(kernel_graph, op_run_info); |
|
|
|
} |
|
|
|
SetOutputFlags(*outputs); |
|
|
|
runtime_.RunOpClearMemory(kernel_graph.get()); |
|
|
|
MS_LOG(INFO) << "Run Op end"; |
|
|
|
} |
|
|
|
|