|
|
|
@@ -31,6 +31,8 @@ namespace session { |
|
|
|
namespace { |
|
|
|
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; |
|
|
|
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; |
|
|
|
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(), |
|
|
|
prim::kPrimAssignSub->name()}; |
|
|
|
void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que, |
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
@@ -417,21 +419,41 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::ResetAssignInputFeaatureMapFlag(const CNodePtr &cnode) const { |
|
|
|
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "Only supported to change the node [Assign , AssignSub, AssignAdd] node's input feature map " |
|
|
|
"flag but got the node :" |
|
|
|
<< cnode->DebugString(); |
|
|
|
} |
|
|
|
auto input_node = AnfAlgo::GetInputNode(cnode, 0); |
|
|
|
auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1); |
|
|
|
if (AnfAlgo::IsFeatureMapOutput(input_node)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) { |
|
|
|
auto kernel_info = static_cast<device::KernelInfo *>(input_node->kernel_info()); |
|
|
|
kernel_info->set_feature_map_flag(true); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>(); |
|
|
|
node->set_kernel_info(kernel_info); |
|
|
|
if (node->isa<CNode>()) { |
|
|
|
if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) { |
|
|
|
ResetAssignInputFeaatureMapFlag(node->cast<CNodePtr>()); |
|
|
|
} |
|
|
|
std::vector<size_t> feature_map_input_indexs; |
|
|
|
kernel_info->SetFeatureMapFlag(false); |
|
|
|
kernel_info->set_feature_map_flag(false); |
|
|
|
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { |
|
|
|
if (AnfAlgo::IsFeatureMapInput(node, index)) { |
|
|
|
kernel_info->SetFeatureMapFlag(true); |
|
|
|
kernel_info->set_feature_map_flag(true); |
|
|
|
feature_map_input_indexs.push_back(index); |
|
|
|
} |
|
|
|
} |
|
|
|
if (AnfAlgo::GetInputTensorNum(node) == 0) { |
|
|
|
kernel_info->SetFeatureMapFlag(true); |
|
|
|
kernel_info->set_feature_map_flag(true); |
|
|
|
} |
|
|
|
if (AnfAlgo::IsRealKernel(node)) { |
|
|
|
// if the node only has the primitive(such as getNext) or the node's input has a feature map input |
|
|
|
@@ -446,7 +468,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { |
|
|
|
std::vector<TypeId> types; |
|
|
|
std::vector<std::string> formats = {kOpFormat_DEFAULT}; |
|
|
|
if (node->isa<ValueNode>()) { |
|
|
|
kernel_info->SetFeatureMapFlag(false); |
|
|
|
kernel_info->set_feature_map_flag(false); |
|
|
|
types.emplace_back(kTypeUnknown); |
|
|
|
auto value_node = node->cast<ValueNodePtr>(); |
|
|
|
SyncDeviceInfoToValueNode(value_node, &formats, &types); |
|
|
|
@@ -455,7 +477,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const { |
|
|
|
auto parameter = node->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(parameter); |
|
|
|
bool is_weight = AnfAlgo ::IsParameterWeight(parameter); |
|
|
|
kernel_info->SetFeatureMapFlag(!is_weight); |
|
|
|
kernel_info->set_feature_map_flag(!is_weight); |
|
|
|
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0)); |
|
|
|
} |
|
|
|
// set parameter initaial device data type |
|
|
|
|