|
|
|
@@ -112,6 +112,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr |
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); |
|
|
|
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { |
|
|
|
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0); |
|
|
|
} else if (opt::IsNopNode(cnode)) { |
|
|
|
if (cnode->inputs().size() == 2) { |
|
|
|
return VisitKernelWithReturnType(cnode->input(1), 0); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; |
|
|
|
} |
|
|
|
} else { |
|
|
|
return std::make_pair(anf_node, index); |
|
|
|
} |
|
|
|
@@ -299,20 +305,23 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i |
|
|
|
return build_info->GetInputFormat(input_idx); |
|
|
|
} |
|
|
|
|
|
|
|
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { |
|
|
|
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
if (!anf_node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf_node is not CNode."; |
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; |
|
|
|
} |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode) |
|
|
|
<< "."; |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); |
|
|
|
} |
|
|
|
auto node = cnode->input(input_idx + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0); |
|
|
|
return VisitKernel(node, 0); |
|
|
|
} |
|
|
|
|
|
|
|
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { |
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); |
|
|
|
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -346,18 +355,7 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &n |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << "anf_node is not CNode."; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode) |
|
|
|
<< "."; |
|
|
|
} |
|
|
|
auto input_node = cnode->input(input_idx + 1); |
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0); |
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); |
|
|
|
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -459,17 +457,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_ |
|
|
|
} |
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "is not a CNode"; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); |
|
|
|
} |
|
|
|
auto input_node = cnode->input(input_idx + 1); |
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(input_node, 0); |
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); |
|
|
|
return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -492,17 +480,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ |
|
|
|
} |
|
|
|
|
|
|
|
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) { |
|
|
|
if (!anf_node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; |
|
|
|
} |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); |
|
|
|
} |
|
|
|
auto node = cnode->input(input_idx + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0); |
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); |
|
|
|
return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -558,32 +536,12 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ |
|
|
|
} |
|
|
|
|
|
|
|
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { |
|
|
|
if (!anf_node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf node is not a CNode"; |
|
|
|
} |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); |
|
|
|
} |
|
|
|
auto node = cnode->input(input_idx + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0); |
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); |
|
|
|
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second); |
|
|
|
} |
|
|
|
|
|
|
|
DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) { |
|
|
|
if (!anf_node->isa<CNode>()) { |
|
|
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; |
|
|
|
} |
|
|
|
auto cnode = anf_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (input_idx + 1 >= cnode->inputs().size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); |
|
|
|
} |
|
|
|
auto node = cnode->input(input_idx + 1); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
KernelWithIndex kernel_with_index = VisitKernel(node, 0); |
|
|
|
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); |
|
|
|
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second); |
|
|
|
} |
|
|
|
|
|
|
|
|