| @@ -1176,9 +1176,12 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) | |||||
| bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { | bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (node->isa<ValueNode>() || IsPrimitiveCNode(node, prim::kPrimLoad)) { | |||||
| if (node->isa<ValueNode>()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (IsPrimitiveCNode(node, prim::kPrimLoad)) { | |||||
| return IsFeatureMapOutput(node->cast<CNodePtr>()->input(1)); | |||||
| } | |||||
| auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info()); | auto kernel_info = static_cast<const device::KernelInfo *>(node->kernel_info()); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info); | MS_EXCEPTION_IF_NULL(kernel_info); | ||||
| return kernel_info->is_feature_map(); | return kernel_info->is_feature_map(); | ||||
| @@ -1220,13 +1220,9 @@ class AutoMonadConverter { | |||||
| } | } | ||||
| CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) { | CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) { | ||||
| static const std::string primitive_target = "primitive_target"; | |||||
| // Create Load cnode. | // Create Load cnode. | ||||
| auto load_prim = NewValueNode(prim::kPrimLoad); | auto load_prim = NewValueNode(prim::kPrimLoad); | ||||
| auto load_cnode = func_graph_->NewCNode({load_prim, ref, u}); | auto load_cnode = func_graph_->NewCNode({load_prim, ref, u}); | ||||
| // Set device target for Load CNode. | |||||
| std::string target = GetCNodeTarget(cnode); | |||||
| load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target)); | |||||
| // Set load_cnode abstract to Tensor according the input Ref[Tensor]. | // Set load_cnode abstract to Tensor according the input Ref[Tensor]. | ||||
| auto ref_abs = dyn_cast<abstract::AbstractRef>(ref->abstract()); | auto ref_abs = dyn_cast<abstract::AbstractRef>(ref->abstract()); | ||||
| MS_EXCEPTION_IF_NULL(ref_abs); | MS_EXCEPTION_IF_NULL(ref_abs); | ||||
| @@ -454,6 +454,8 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { | |||||
| if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) { | if (inputs.size() == 3 && !IsPrimitiveCNode(inputs[1], prim::kPrimMakeTuple)) { | ||||
| return GetCNodeTarget(inputs[1]); | return GetCNodeTarget(inputs[1]); | ||||
| } | } | ||||
| } else if (IsPrimitiveCNode(node, prim::kPrimLoad)) { | |||||
| return GetCNodeTarget(cnode->input(1)); | |||||
| } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { | ||||
| return GetMaketupleNodeTarget(cnode); | return GetMaketupleNodeTarget(cnode); | ||||
| } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | ||||