| @@ -38,7 +38,7 @@ void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr | |||
| } | |||
| std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { | |||
| auto kernel_info = apply_kernel->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_build_Info); | |||
| @@ -137,7 +137,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string & | |||
| } | |||
| GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { | |||
| auto kernel_info = apply_kernel->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(kernel_build_Info); | |||
| @@ -63,7 +63,7 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr | |||
| kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, | |||
| TypeId output_type) { | |||
| MS_EXCEPTION_IF_NULL(cast); | |||
| auto kernel_info = cast->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(cast->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto cast_build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(cast_build_info); | |||
| @@ -23,8 +23,8 @@ namespace { | |||
| bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(main); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto main_kernel_info = main->kernel_info(); | |||
| auto node_kernel_info = node->kernel_info(); | |||
| auto main_kernel_info = dynamic_cast<device::KernelInfo *>(main->kernel_info()); | |||
| auto node_kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| if (main_kernel_info == nullptr && node_kernel_info == nullptr) { | |||
| return true; | |||
| } | |||
| @@ -338,7 +338,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t | |||
| if (!AnfAlgo::IsRealKernel(node)) { | |||
| return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -360,7 +360,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i | |||
| if (!IsRealKernel(node)) { | |||
| GetPrevNodeOutputFormat(node, input_idx); | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -467,7 +467,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode | |||
| if (!IsRealKernel(node)) { | |||
| return GetPrevNodeOutputReshapeType(node, input_idx); | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -486,7 +486,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod | |||
| if (!IsRealKernel(node)) { | |||
| return GetPrevNodeOutputReshapeType(node, output_idx); | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -546,7 +546,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size | |||
| if (!IsRealKernel(node)) { | |||
| return GetPrevNodeOutputDeviceDataType(node, output_idx); | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -567,7 +567,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ | |||
| if (!IsRealKernel(node)) { | |||
| return GetPrevNodeOutputDeviceDataType(node, 0); | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -597,7 +597,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, | |||
| MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; | |||
| } | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto addr = kernel_info->GetOutputAddr(output_idx); | |||
| if (addr == nullptr) { | |||
| @@ -619,7 +619,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod | |||
| MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; | |||
| } | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto addr = kernel_info->GetMutableOutputAddr(output_idx); | |||
| if (addr == nullptr) { | |||
| @@ -636,7 +636,7 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ | |||
| MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " | |||
| << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->OutputAddrExist(output_idx); | |||
| } | |||
| @@ -656,7 +656,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode | |||
| // set output device addr of anf_node | |||
| void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| if (!kernel_info->SetOutputAddr(addr, output_idx)) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; | |||
| @@ -666,7 +666,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out | |||
| // set workspace device addr of anf_node | |||
| void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { | |||
| MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; | |||
| @@ -676,7 +676,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t | |||
| // get workspace device addr of anf_node | |||
| DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto addr = kernel_info->GetWorkspaceAddr(output_idx); | |||
| if (addr == nullptr) { | |||
| @@ -720,7 +720,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ | |||
| kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| // select_kernel_build_info() has checked whether return pointer is null | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| @@ -731,7 +731,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { | |||
| // get KernelBuildType of node, such as ATT,RT,FWK and so on | |||
| KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| // select_kernel_build_info() has checked whether return pointer is null | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| @@ -741,7 +741,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { | |||
| kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -750,7 +750,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { | |||
| kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| auto build_info = kernel_info->select_kernel_build_info(); | |||
| MS_EXCEPTION_IF_NULL(build_info); | |||
| @@ -760,7 +760,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { | |||
| // set select kernel_build_info | |||
| void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->set_select_kernel_build_info(select_kernel_build_info); | |||
| } | |||
| @@ -768,7 +768,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel | |||
| // get select kernel_build_info | |||
| KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->GetMutableSelectKernelBuildInfo(); | |||
| } | |||
| @@ -776,7 +776,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt | |||
| // get kernelMode | |||
| KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->MutableKernelMod(); | |||
| } | |||
| @@ -784,7 +784,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { | |||
| // set kernel mod | |||
| void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| kernel_info->set_kernel_mod(kernel_mod); | |||
| } | |||
| @@ -850,42 +850,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { | |||
| void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| kernel_info->set_stream_id(stream_id); | |||
| } | |||
| uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->stream_id(); | |||
| } | |||
| void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| kernel_info->set_stream_distinction_label(stream_label); | |||
| } | |||
| uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->stream_distinction_label(); | |||
| } | |||
| void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| kernel_info->set_graph_id(graph_id); | |||
| } | |||
| uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->graph_id(); | |||
| } | |||
| @@ -913,7 +913,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { | |||
| if (node->isa<ValueNode>()) { | |||
| return false; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| auto kernel_info = dynamic_cast<const device::KernelInfo *>(node->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| return kernel_info->is_feature_map(); | |||
| } | |||
| @@ -38,6 +38,8 @@ namespace mindspore { | |||
| namespace session { | |||
| using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>; | |||
| using KernelWithIndex = std::pair<AnfNodePtr, size_t>; | |||
| using DeviceAddress = device::DeviceAddress; | |||
| using DeviceAddressPtr = device::DeviceAddressPtr; | |||
| class AnfRuntimeAlgorithm { | |||
| public: | |||
| // get input_anf_node's real kernel by recurse | |||
| @@ -121,7 +121,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| auto pk_node = input_node->cast<ParameterPtr>(); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
| auto tensor_address = tensor->device_address(); | |||
| auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||
| bool need_sync = false; | |||
| if (ms_context->enable_pynative_infer()) { | |||
| if (tensor_address == nullptr || tensor_address != device_address) { | |||
| @@ -230,13 +230,14 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, | |||
| // set the kernel info of parameter | |||
| auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| MS_EXCEPTION_IF_NULL(input_tensor); | |||
| if (input_tensor->device_address().get() == nullptr) { | |||
| auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address()); | |||
| if (device_address == nullptr) { | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT}); | |||
| TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type(); | |||
| kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type}); | |||
| } else { | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()}); | |||
| kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()}); | |||
| kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()}); | |||
| kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()}); | |||
| } | |||
| AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); | |||
| // construct abstract of parameter | |||
| @@ -319,7 +320,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const | |||
| if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) && | |||
| node_graph->IsFinalOutputKernel(ref_real_node)) { | |||
| auto kernel_info = ref_real_node->kernel_info(); | |||
| if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { | |||
| if (kernel_info == nullptr || !kernel_info->has_build_info()) { | |||
| MS_LOG(INFO) << "No kernel info"; | |||
| return; | |||
| } | |||
| @@ -330,9 +331,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const | |||
| } | |||
| auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); | |||
| auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); | |||
| parameter->set_kernel_info(std::make_shared<device::KernelInfo>()); | |||
| auto d_kernel_info = parameter->kernel_info(); | |||
| auto d_kernel_info = std::make_shared<device::KernelInfo>(); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| parameter->set_kernel_info(d_kernel_info); | |||
| kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsDeviceType({type}); | |||
| builder.SetOutputsFormat({format}); | |||
| @@ -128,7 +128,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo> | |||
| return; | |||
| } | |||
| auto kernel_info = node->kernel_info(); | |||
| if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { | |||
| if (kernel_info == nullptr || !kernel_info->has_build_info()) { | |||
| return; | |||
| } | |||
| @@ -179,7 +179,7 @@ void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMa | |||
| // print parameters' type and shape | |||
| PrintNodeOutputType(buffer, p); | |||
| auto kernel_info = p->kernel_info(); | |||
| if (kernel_info != nullptr && kernel_info->select_kernel_build_info() != nullptr) { | |||
| if (kernel_info != nullptr && kernel_info->has_build_info()) { | |||
| buffer << " : "; | |||
| auto type = AnfAlgo::GetOutputDeviceDataType(p, 0); | |||
| auto format = AnfAlgo::GetOutputFormat(p, 0); | |||
| @@ -362,8 +362,7 @@ void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNod | |||
| continue; | |||
| } | |||
| for (auto &node_user : iter->second) { | |||
| if (node_user.first->kernel_info() == nullptr || | |||
| node_user.first->kernel_info()->select_kernel_build_info() == nullptr) { | |||
| if (node_user.first->kernel_info() == nullptr || !node_user.first->kernel_info()->has_build_info()) { | |||
| // maybe not a real kernel. | |||
| continue; | |||
| } | |||
| @@ -21,8 +21,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ir/dtype.h" | |||
| using std::string; | |||
| #include "ir/device_sync.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| @@ -51,15 +50,12 @@ namespace device { | |||
| enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; | |||
| enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; | |||
| class DeviceAddress { | |||
| class DeviceAddress : public mindspore::DeviceSync { | |||
| public: | |||
| explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {} | |||
| explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) | |||
| : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {} | |||
| virtual ~DeviceAddress() { ptr_ = nullptr; } | |||
| virtual bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const = 0; | |||
| virtual bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, | |||
| const void *host_ptr) const = 0; | |||
| const void *GetPtr() const { return ptr_; } | |||
| size_t GetSize() const { return size_; } | |||
| std::string format() const { return format_; } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ir/kernel_info_dev.h" | |||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||
| #include "runtime/device/ascend/ascend_device_address.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| @@ -27,7 +28,7 @@ namespace mindspore { | |||
| const uint32_t kInvalidGraphId = UINT32_MAX; | |||
| const uint32_t kInvalidDistincLabel = UINT32_MAX; | |||
| namespace device { | |||
| class KernelInfo { | |||
| class KernelInfo : public KernelInfoDevice { | |||
| public: | |||
| KernelInfo() { | |||
| kernel_mod_ = nullptr; | |||
| @@ -41,6 +42,7 @@ class KernelInfo { | |||
| } | |||
| virtual ~KernelInfo() = default; | |||
| bool has_build_info() const override { return select_kernel_build_info() != nullptr; } | |||
| const kernel::KernelBuildInfo *select_kernel_build_info() const; | |||
| kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const; | |||
| void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { | |||
| @@ -214,8 +214,10 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> | |||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | |||
| for (size_t index = 0; index < output_size; index++) { | |||
| MS_EXCEPTION_IF_NULL(input_tensors[input_index]); | |||
| if (input_tensors[input_index]->device_address().get() != nullptr) { | |||
| AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get()); | |||
| auto output_address = | |||
| std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address()); | |||
| if (output_address != nullptr) { | |||
| AnfAlgo::SetOutputAddr(output_address, index, item.get()); | |||
| continue; | |||
| } | |||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); | |||
| @@ -27,8 +27,9 @@ | |||
| #include <utility> | |||
| #include "base/base.h" | |||
| #include "debug/info.h" | |||
| #include "ir/kernel_info_dev.h" | |||
| #include "ir/scope.h" | |||
| #include "debug/info.h" | |||
| // A MindSpore ANF IR defined here. | |||
| // with BNF followed: | |||
| @@ -71,12 +72,6 @@ class BaseRef; | |||
| class Var; | |||
| using VarPtr = std::shared_ptr<Var>; | |||
| namespace device { | |||
| class KernelInfo; | |||
| } // namespace device | |||
| using KernelInfoDevice = device::KernelInfo; | |||
| using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>; | |||
| class AnfVisitor; | |||
| class ParamValue; | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ | |||
| #define MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ir/dtype/type.h" | |||
| using std::string; | |||
| namespace mindspore { | |||
| // Interface for data synchronization between device and host. | |||
| class DeviceSync { | |||
| public: | |||
| virtual bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const = 0; | |||
| virtual bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, | |||
| const void *host_ptr) const = 0; | |||
| }; | |||
| using DeviceSyncPtr = std::shared_ptr<DeviceSync>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ | |||
| #define MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ | |||
| #include <memory> | |||
| namespace mindspore { | |||
| // Interface for device kernel program information. | |||
| class KernelInfoDevice { | |||
| public: | |||
| // If kernel program was built and build info is set. | |||
| virtual bool has_build_info() const = 0; | |||
| }; | |||
| using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ | |||
| @@ -326,7 +326,7 @@ Tensor::Tensor(const Tensor &tensor) | |||
| data_(tensor.data_), | |||
| dirty_(tensor.dirty_), | |||
| id_(tensor.id_), | |||
| device_address_(tensor.device_address_) {} | |||
| device_sync_(tensor.device_sync_) {} | |||
| Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| : MetaTensor(data_type, tensor.shape_), | |||
| @@ -334,7 +334,7 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) | |||
| data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), | |||
| dirty_(tensor.dirty_), | |||
| id_(tensor.id_), | |||
| device_address_(tensor.device_address_) {} | |||
| device_sync_(tensor.device_sync_) {} | |||
| Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data) | |||
| : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} | |||
| @@ -379,10 +379,10 @@ bool Tensor::ValueEqual(const Tensor &tensor) const { | |||
| Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||
| if (this != &tensor) { | |||
| MetaTensor::operator=(tensor); | |||
| dirty_ = tensor.is_dirty(); | |||
| device_address_ = tensor.device_address(); | |||
| dirty_ = tensor.dirty_; | |||
| device_sync_ = tensor.device_sync_; | |||
| data_ = tensor.data_; | |||
| id_ = tensor.id(); | |||
| id_ = tensor.id_; | |||
| } | |||
| return *this; | |||
| } | |||
| @@ -425,8 +425,8 @@ std::string Tensor::ToStringRepr() const { | |||
| } | |||
| void Tensor::data_sync() const { | |||
| if (device_address_ != nullptr) { | |||
| if (!device_address_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | |||
| if (device_sync_ != nullptr) { | |||
| if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | |||
| MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy."; | |||
| } | |||
| } | |||
| @@ -23,15 +23,13 @@ | |||
| #include <numeric> | |||
| #include "Eigen/Core" | |||
| #include "runtime/device/device_address.h" | |||
| #include "ir/device_sync.h" | |||
| #include "ir/meta_tensor.h" | |||
| #include "include/ms_tensor.h" | |||
| #include "utils/log_adapter.h" | |||
| using float16 = Eigen::half; | |||
| using mindspore::device::DeviceAddress; | |||
| using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>; | |||
| // brief mindspore namespace. | |||
| // | |||
| // mindspore namespace is the top level namespace of MindSpore project. | |||
| @@ -222,8 +220,8 @@ class Tensor : public MetaTensor { | |||
| bool is_dirty() const { return dirty_; } | |||
| void set_dirty(const bool dirty) { dirty_ = dirty; } | |||
| DeviceAddressPtr device_address() const { return device_address_; } | |||
| void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } | |||
| DeviceSyncPtr device_address() const { return device_sync_; } | |||
| void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } | |||
| std::string id() const { return id_; } | |||
| @@ -234,7 +232,7 @@ class Tensor : public MetaTensor { | |||
| TensorDataPtr data_{nullptr}; | |||
| bool dirty_{true}; | |||
| std::string id_{""}; | |||
| DeviceAddressPtr device_address_{nullptr}; | |||
| DeviceSyncPtr device_sync_{nullptr}; | |||
| }; | |||
| using TensorPtr = std::shared_ptr<Tensor>; | |||
| using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; | |||
| @@ -22,7 +22,6 @@ | |||
| #include <sstream> | |||
| #include <string> | |||
| #include "runtime/device/device_address.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| #include "abstract/abstract_value.h" | |||
| @@ -81,8 +81,6 @@ struct type_caster<float16> : public npy_scalar_caster<float16> { | |||
| } // namespace detail | |||
| } // namespace pybind11 | |||
| using mindspore::device::DeviceAddress; | |||
| using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>; | |||
| // brief mindspore namespace. | |||
| // | |||
| // mindspore namespace is the top level namespace of the MindSpore project. | |||
| @@ -255,7 +255,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { | |||
| AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get()); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); | |||
| @@ -274,7 +274,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); | |||
| @@ -293,7 +293,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputFormat) { | |||
| auto pre_add = kernel_graph->NewCNode(pre_node_inputs); | |||
| MS_EXCEPTION_IF_NULL(pre_add); | |||
| pre_add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = pre_add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(pre_add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsDeviceType({kFloat32->type_id()}); | |||
| @@ -373,7 +373,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) { | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_abstract(tuple_abstract); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ}); | |||
| @@ -404,7 +404,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC}); | |||
| @@ -457,7 +457,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsDeviceType({kFloat32->type_id()}); | |||
| @@ -474,7 +474,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); | |||
| @@ -492,7 +492,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputDeviceDataType) { | |||
| auto pre_add = kernel_graph->NewCNode(pre_add_inputs); | |||
| MS_EXCEPTION_IF_NULL(pre_add); | |||
| pre_add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = pre_add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(pre_add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetOutputsDeviceType({kFloat32->type_id()}); | |||
| @@ -513,7 +513,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputAddr) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| int *addr = nullptr; | |||
| auto device_address = std::make_shared<AscendDeviceAddress>(addr, 1); | |||
| @@ -528,7 +528,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputAddr) { | |||
| auto pre_add = kernel_graph->NewCNode(pre_add_inputs); | |||
| MS_EXCEPTION_IF_NULL(pre_add); | |||
| pre_add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = pre_add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(pre_add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| int *addr = nullptr; | |||
| auto device_address = std::make_shared<AscendDeviceAddress>(addr, 1); | |||
| @@ -561,7 +561,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetWorkspaceAddr) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| int *addr = nullptr; | |||
| auto device_address = std::make_shared<AscendDeviceAddress>(addr, 1); | |||
| @@ -643,7 +643,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelType) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetKernelType(AKG_KERNEL); | |||
| @@ -659,7 +659,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetProcessor) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetProcessor(kernel::AICORE); | |||
| @@ -675,7 +675,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| KernelBuildInfoBuilder builder; | |||
| builder.SetFusionType(kernel::CONVLUTION); | |||
| @@ -703,7 +703,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| d_kernel_info->set_kernel_mod(nullptr); | |||
| EXPECT_EQ(AnfAlgo::GetKernelMod(add), nullptr); | |||
| @@ -779,7 +779,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetStreamId) { | |||
| auto add = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(add); | |||
| add->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto d_kernel_info = add->kernel_info(); | |||
| auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(d_kernel_info); | |||
| d_kernel_info->set_stream_id(0); | |||
| EXPECT_EQ(AnfAlgo::GetStreamId(add), 0); | |||
| @@ -42,7 +42,7 @@ TEST_F(KernelGraphTest, NewValueNode) { | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape); | |||
| add_value->set_abstract(x_abstract); | |||
| add_value->set_kernel_info(std::make_shared<KernelInfo>()); | |||
| auto mutable_kernel_info = add_value->kernel_info(); | |||
| auto mutable_kernel_info = dynamic_cast<device::KernelInfo *>(add_value->kernel_info()); | |||
| MS_EXCEPTION_IF_NULL(mutable_kernel_info); | |||
| std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>(); | |||
| builder->SetOutputsFormat({kOpFormat_FRAC_Z}); | |||