|
|
|
@@ -375,7 +375,7 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodeP |
|
|
|
MS_LOG(EXCEPTION) << "Not real kernel:" |
|
|
|
<< "#node [" << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -389,7 +389,7 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePt |
|
|
|
MS_LOG(EXCEPTION) << "Not real kernel:" |
|
|
|
<< "#node [" << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -403,7 +403,7 @@ std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) { |
|
|
|
MS_LOG(EXCEPTION) << "Not real kernel:" |
|
|
|
<< "#node [" << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -421,7 +421,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t |
|
|
|
if (!AnfAlgo::IsRealKernel(node)) { |
|
|
|
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -443,7 +443,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i |
|
|
|
if (!IsRealKernel(node)) { |
|
|
|
return GetPrevNodeOutputFormat(node, input_idx); |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -549,7 +549,7 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &nod |
|
|
|
if (!IsRealKernel(node)) { |
|
|
|
return GetPrevNodeOutputReshapeType(node, input_idx); |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -568,7 +568,7 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no |
|
|
|
if (!IsRealKernel(node)) { |
|
|
|
return GetPrevNodeOutputReshapeType(node, output_idx); |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -624,7 +624,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size |
|
|
|
if (!IsRealKernel(node)) { |
|
|
|
return GetPrevNodeOutputDeviceDataType(node, output_idx); |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -645,7 +645,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ |
|
|
|
if (!IsRealKernel(node)) { |
|
|
|
return GetPrevNodeOutputDeviceDataType(node, 0); |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -675,7 +675,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, |
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; |
|
|
|
} |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto addr = kernel_info->GetOutputAddr(output_idx); |
|
|
|
if (addr == nullptr) { |
|
|
|
@@ -697,7 +697,8 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod |
|
|
|
MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; |
|
|
|
} |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto addr = kernel_info->GetMutableOutputAddr(output_idx); |
|
|
|
if (addr == nullptr) { |
|
|
|
@@ -710,11 +711,8 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod |
|
|
|
// get output device addr of anf_node |
|
|
|
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (output_idx > GetOutputTensorNum(node)) { |
|
|
|
MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " |
|
|
|
<< GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
// Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->OutputAddrExist(output_idx); |
|
|
|
} |
|
|
|
@@ -734,7 +732,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode |
|
|
|
// set output device addr of anf_node |
|
|
|
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
if (!kernel_info->SetOutputAddr(addr, output_idx)) { |
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; |
|
|
|
@@ -744,7 +742,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out |
|
|
|
// set workspace device addr of anf_node |
|
|
|
void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { |
|
|
|
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; |
|
|
|
@@ -754,7 +752,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t |
|
|
|
// get workspace device addr of anf_node |
|
|
|
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto addr = kernel_info->GetWorkspaceAddr(output_idx); |
|
|
|
if (addr == nullptr) { |
|
|
|
@@ -767,7 +765,7 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz |
|
|
|
// get workspace device mutable addr of anf_node |
|
|
|
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto addr = kernel_info->GetMutableWorkspaceAddr(index); |
|
|
|
if (addr == nullptr) { |
|
|
|
@@ -810,7 +808,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ |
|
|
|
|
|
|
|
kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
// select_kernel_build_info() has checked whether return pointer is null |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
@@ -821,7 +819,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { |
|
|
|
// get KernelBuildType of node, such as ATT,RT,FWK and so on |
|
|
|
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
// select_kernel_build_info() has checked whether return pointer is null |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
@@ -831,7 +829,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -840,7 +838,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
auto build_info = kernel_info->select_kernel_build_info(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_info); |
|
|
|
@@ -850,7 +848,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { |
|
|
|
// set select kernel_build_info |
|
|
|
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->set_select_kernel_build_info(select_kernel_build_info); |
|
|
|
} |
|
|
|
@@ -858,7 +856,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel |
|
|
|
// get select kernel_build_info |
|
|
|
KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->GetMutableSelectKernelBuildInfo(); |
|
|
|
} |
|
|
|
@@ -866,7 +864,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt |
|
|
|
// get kernelMode |
|
|
|
KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->MutableKernelMod(); |
|
|
|
} |
|
|
|
@@ -874,7 +872,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { |
|
|
|
// set kernel mod |
|
|
|
void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
kernel_info->set_kernel_mod(kernel_mod); |
|
|
|
} |
|
|
|
@@ -940,42 +938,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { |
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
kernel_info->set_stream_id(stream_id); |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->stream_id(); |
|
|
|
} |
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
kernel_info->set_stream_distinction_label(stream_label); |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->stream_distinction_label(); |
|
|
|
} |
|
|
|
|
|
|
|
void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
kernel_info->set_graph_id(graph_id); |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->graph_id(); |
|
|
|
} |
|
|
|
@@ -1003,7 +1001,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { |
|
|
|
if (node->isa<ValueNode>()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info()); |
|
|
|
auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info()); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info); |
|
|
|
return kernel_info->is_feature_map(); |
|
|
|
} |
|
|
|
|