| @@ -23,7 +23,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| namespace { | namespace { | ||||
| template <typename T> | template <typename T> | ||||
| void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, float *output_max_addr, | |||||
| void LookUpTableTask(const float *input_addr, const T *indices_addr, const float *output_max_addr, float *output_addr, | |||||
| size_t indices_lens, size_t outer_dim_size, T offset, size_t first_dim_size) { | size_t indices_lens, size_t outer_dim_size, T offset, size_t first_dim_size) { | ||||
| size_t lens = outer_dim_size * sizeof(float); | size_t lens = outer_dim_size * sizeof(float); | ||||
| for (size_t i = 0; i < indices_lens; ++i) { | for (size_t i = 0; i < indices_lens; ++i) { | ||||
| @@ -82,9 +82,9 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr | |||||
| break; | break; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; | MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; | ||||
| threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset, | |||||
| output_addr + task_offset * outer_dim_size_, output_addr + outputs[0]->size, | |||||
| task_proc_lens, outer_dim_size_, offset_, first_dim_size_); | |||||
| threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset, output_addr + outputs[0]->size, | |||||
| output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_, | |||||
| first_dim_size_); | |||||
| task_offset += task_proc_lens; | task_offset += task_proc_lens; | ||||
| if (task_offset + task_proc_lens > indices_lens_) { | if (task_offset + task_proc_lens > indices_lens_) { | ||||
| task_proc_lens = indices_lens_ - task_offset; | task_proc_lens = indices_lens_ - task_offset; | ||||
| @@ -30,12 +30,10 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw | |||||
| const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | MS_EXCEPTION_IF_NULL(eltwise_input); | ||||
| if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | |||||
| if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| } else { | } else { | ||||
| return; | return; | ||||
| @@ -30,12 +30,10 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con | |||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | MS_EXCEPTION_IF_NULL(eltwise_input); | ||||
| if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { | |||||
| if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| } else { | } else { | ||||
| return; | return; | ||||
| @@ -30,11 +30,9 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con | |||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| while (CheckEltWiseNode(manager.get(), eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | auto input_cnode = eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||
| @@ -30,12 +30,10 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne | |||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | MS_EXCEPTION_IF_NULL(eltwise_input); | ||||
| while (CheckEltWiseNode(manager.get(), eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| if (record.size() == MAX_ELTWISE_SIZE) { | if (record.size() == MAX_ELTWISE_SIZE) { | ||||
| break; | break; | ||||
| @@ -23,7 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||||
| bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) { | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | ||||
| @@ -37,7 +38,8 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt | |||||
| cnode->inputs().size() == ELTWISE_INPUT_SIZE; | cnode->inputs().size() == ELTWISE_INPUT_SIZE; | ||||
| } | } | ||||
| bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||||
| bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) { | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | ||||
| @@ -51,7 +53,8 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const A | |||||
| cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; | cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; | ||||
| } | } | ||||
| bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { | |||||
| bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) { | |||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { | ||||
| @@ -61,9 +61,9 @@ class FusionBasePass : public Pass { | |||||
| virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, | ||||
| FusedNodeRecord *candidate_fusion) = 0; | FusedNodeRecord *candidate_fusion) = 0; | ||||
| void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record); | void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record); | ||||
| bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||||
| bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||||
| bool CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); | |||||
| bool CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node); | |||||
| bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node); | |||||
| bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node); | |||||
| FusionIdAllocatorPtr fusion_id_allocator; | FusionIdAllocatorPtr fusion_id_allocator; | ||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -35,7 +35,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| MS_EXCEPTION_IF_NULL(eltwise_input); | MS_EXCEPTION_IF_NULL(eltwise_input); | ||||
| if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { | |||||
| if (CheckMultiOutputEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; | std::vector<int> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; | ||||
| AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); | AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); | ||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| @@ -45,7 +45,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const | |||||
| } else { | } else { | ||||
| return; | return; | ||||
| } | } | ||||
| while (CheckEltWiseNode(manager.get(), eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| if (record.size() == MULTI_ELTWISE_SIZE) { | if (record.size() == MULTI_ELTWISE_SIZE) { | ||||
| break; | break; | ||||
| @@ -31,11 +31,9 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se | |||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| while (CheckEltWiseNode(manager.get(), eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | auto input_cnode = eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||
| @@ -56,7 +54,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se | |||||
| MS_EXCEPTION_IF_NULL(previous_input_cnode); | MS_EXCEPTION_IF_NULL(previous_input_cnode); | ||||
| auto previous_eltwise_input = previous_input_cnode->input(1); | auto previous_eltwise_input = previous_input_cnode->input(1); | ||||
| auto previous_size = record.size(); | auto previous_size = record.size(); | ||||
| while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) { | |||||
| (void)record.insert(previous_eltwise_input); | (void)record.insert(previous_eltwise_input); | ||||
| auto previous_node = previous_eltwise_input->cast<CNodePtr>(); | auto previous_node = previous_eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(previous_node); | MS_EXCEPTION_IF_NULL(previous_node); | ||||
| @@ -30,11 +30,9 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const | |||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto eltwise_input = cnode->input(1); | auto eltwise_input = cnode->input(1); | ||||
| while (CheckEltWiseNode(manager.get(), eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, eltwise_input)) { | |||||
| (void)record.insert(eltwise_input); | (void)record.insert(eltwise_input); | ||||
| auto input_cnode = eltwise_input->cast<CNodePtr>(); | auto input_cnode = eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||
| @@ -55,7 +53,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const | |||||
| MS_EXCEPTION_IF_NULL(previous_input_cnode); | MS_EXCEPTION_IF_NULL(previous_input_cnode); | ||||
| auto previous_eltwise_input = previous_input_cnode->input(1); | auto previous_eltwise_input = previous_input_cnode->input(1); | ||||
| auto previous_size = record.size(); | auto previous_size = record.size(); | ||||
| while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { | |||||
| while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) { | |||||
| (void)record.insert(previous_eltwise_input); | (void)record.insert(previous_eltwise_input); | ||||
| auto previous_node = previous_eltwise_input->cast<CNodePtr>(); | auto previous_node = previous_eltwise_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(previous_node); | MS_EXCEPTION_IF_NULL(previous_node); | ||||
| @@ -33,11 +33,9 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con | |||||
| FusedNodeRecord *candidate_fusion) { | FusedNodeRecord *candidate_fusion) { | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(candidate_fusion); | MS_EXCEPTION_IF_NULL(candidate_fusion); | ||||
| auto manager = kernel_graph.manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| std::unordered_set<AnfNodePtr> record{cnode}; | std::unordered_set<AnfNodePtr> record{cnode}; | ||||
| auto write_input = cnode->input(1); | auto write_input = cnode->input(1); | ||||
| if (CheckEltWiseNode(manager.get(), write_input)) { | |||||
| if (CheckEltWiseNode(kernel_graph, write_input)) { | |||||
| (void)record.insert(write_input); | (void)record.insert(write_input); | ||||
| auto input_cnode = write_input->cast<CNodePtr>(); | auto input_cnode = write_input->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(input_cnode); | MS_EXCEPTION_IF_NULL(input_cnode); | ||||