|
|
|
@@ -268,23 +268,12 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, |
|
|
|
return make_tuple; |
|
|
|
} |
|
|
|
|
|
|
|
bool NeedInsertSwitch() { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && |
|
|
|
ConfigManager::GetInstance().iter_num() > 1); |
|
|
|
} |
|
|
|
|
|
|
|
size_t LoadCtrlInputTensor(const std::shared_ptr<Context> &context, std::vector<tensor::TensorPtr> *inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(context); |
|
|
|
if (!NeedInsertSwitch()) { |
|
|
|
(void)context->results_.erase(kInputCtrlTensors); |
|
|
|
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) { |
|
|
|
MS_LOG(INFO) << "Load kInputCtrlTensors"; |
|
|
|
auto inputs_params = graph->input_ctrl_tensors(); |
|
|
|
if (inputs_params == nullptr) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Load kInputCtrlTensors"; |
|
|
|
auto inputs_params = |
|
|
|
context->GetResult(kInputCtrlTensors).cast<const std::shared_ptr<std::vector<tensor::TensorPtr>>>(); |
|
|
|
MS_EXCEPTION_IF_NULL(inputs_params); |
|
|
|
if (inputs_params->empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Illegal empty inputs_params"; |
|
|
|
} |
|
|
|
@@ -689,11 +678,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) const { |
|
|
|
std::vector<tensor::TensorPtr> inputs(inputs_const); |
|
|
|
size_t input_ctrl_size = 1; |
|
|
|
MS_EXCEPTION_IF_NULL(context_); |
|
|
|
if (context_->HasResult(kInputCtrlTensors)) { |
|
|
|
input_ctrl_size = LoadCtrlInputTensor(context_, &inputs); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
if (kernel_graph->input_ctrl_tensors()) { |
|
|
|
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); |
|
|
|
} |
|
|
|
auto input_nodes = kernel_graph->inputs(); |
|
|
|
if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() |
|
|
|
|