|
|
|
@@ -52,6 +52,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param()); |
|
|
|
MS_EXCEPTION_IF_NULL(param_value); |
|
|
|
auto py_param = param_value->value(); |
|
|
|
return py_param.ptr(); |
|
|
|
} |
|
|
|
@@ -69,7 +70,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne |
|
|
|
} |
|
|
|
if (node->isa<Parameter>()) { |
|
|
|
for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { |
|
|
|
if (input_idx > input_tensors.size()) { |
|
|
|
if (input_idx >= input_tensors.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size(); |
|
|
|
} |
|
|
|
if (graph.inputs()[input_idx] == node) { |
|
|
|
@@ -149,6 +150,8 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, |
|
|
|
} |
|
|
|
|
|
|
|
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto value_node = anf->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
auto value = value_node->value(); |
|
|
|
@@ -229,6 +232,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(input_tensor); |
|
|
|
auto value_node = std::make_shared<ValueNode>(input_tensor); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
// construct abstract of value node |
|
|
|
auto type_of_tensor = input_tensor->Dtype(); |
|
|
|
auto shape_of_tensor = input_tensor->shape(); |
|
|
|
@@ -242,6 +246,7 @@ ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, |
|
|
|
|
|
|
|
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor, |
|
|
|
int tensor_mask) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto param = graph->NewParameter(); |
|
|
|
MS_EXCEPTION_IF_NULL(param); |
|
|
|
if (tensor_mask == kParameterWeightTensorMask) { |
|
|
|
@@ -295,6 +300,7 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { |
|
|
|
} |
|
|
|
|
|
|
|
bool ExistSummaryNode(const KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto ret = graph->get_return(); |
|
|
|
MS_EXCEPTION_IF_NULL(ret); |
|
|
|
auto all_nodes = DeepLinkedGraphSearch(ret); |
|
|
|
@@ -315,7 +321,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf |
|
|
|
if (!anf->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto m_tensor = GetParamDefaultInputTensor(anf); |
|
|
|
auto valid_inputs = graph->MutableValidInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(valid_inputs); |
|
|
|
@@ -344,6 +350,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf |
|
|
|
|
|
|
|
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; |
|
|
|
auto parameters = CreateParameterFromTuple(anf, valid_input, graph); |
|
|
|
if (parameters.empty()) { |
|
|
|
@@ -482,6 +489,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) |
|
|
|
|
|
|
|
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto value_node = anf->cast<ValueNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); |
|
|
|
@@ -509,6 +517,7 @@ ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, Ker |
|
|
|
|
|
|
|
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
if (!anf->isa<Parameter>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; |
|
|
|
} |
|
|
|
@@ -536,6 +545,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph |
|
|
|
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode; |
|
|
|
auto graph = NewKernelGraph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); |
|
|
|
size_t from_other_graph_depend_num = 0; |
|
|
|
for (const auto &node : lst) { |
|
|
|
@@ -585,6 +595,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP |
|
|
|
MS_EXCEPTION_IF_NULL(all_out_graph); |
|
|
|
auto node_list = TopoSort(func_graph->get_return()); |
|
|
|
auto graph = NewKernelGraph(); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
front_backend_graph_map_[func_graph] = graph; |
|
|
|
MS_LOG(INFO) << "Create graph: " << graph->graph_id(); |
|
|
|
|
|
|
|
@@ -724,8 +735,8 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
} |
|
|
|
auto anf_outputs = kernel_graph->outputs(); |
|
|
|
for (auto &item : anf_outputs) { |
|
|
|
MS_LOG(INFO) << "update output[" << item->DebugString() << "]"; |
|
|
|
MS_EXCEPTION_IF_NULL(item); |
|
|
|
MS_LOG(INFO) << "update output[" << item->DebugString() << "]"; |
|
|
|
if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) { |
|
|
|
outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors)); |
|
|
|
continue; |
|
|
|
@@ -761,6 +772,7 @@ void SessionBasic::GetSummaryNodes(KernelGraph *graph) { |
|
|
|
auto node = cnode->input(kSummaryGetItem); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); |
|
|
|
MS_EXCEPTION_IF_NULL(item_with_index.first); |
|
|
|
if (!AnfAlgo::IsRealKernel(item_with_index.first)) { |
|
|
|
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); |
|
|
|
} |
|
|
|
@@ -812,6 +824,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
std::vector<AnfNodePtr> output_args; |
|
|
|
for (const auto &output : outputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
|
MS_LOG(INFO) << "output:" << output->DebugString(); |
|
|
|
} |
|
|
|
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { |
|
|
|
@@ -883,7 +896,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf |
|
|
|
} |
|
|
|
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]); |
|
|
|
inputs.push_back(parameter); |
|
|
|
graph->MutableInputs()->push_back(parameter); |
|
|
|
auto mutable_inputs = graph->MutableInputs(); |
|
|
|
MS_EXCEPTION_IF_NULL(mutable_inputs); |
|
|
|
mutable_inputs->push_back(parameter); |
|
|
|
} |
|
|
|
// set execution order |
|
|
|
auto cnode = graph->NewCNode(inputs); |
|
|
|
|