From: @zhaosida_hw Reviewed-by: @zhoufeng54,@jjfeing Signed-off-by: @jjfeingpull/14451/MERGE
| @@ -567,59 +567,46 @@ void Somas::InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel) | |||
| auto stream = node->GetStream(); | |||
| MS_EXCEPTION_IF_NULL(stream); | |||
| MS_EXCEPTION_IF_NULL(kernel->inputs()[1]); | |||
| auto pre_node = (kernel->inputs()[1])->cast<CNodePtr>(); | |||
| auto iter = nodes_map_.find(pre_node.get()); | |||
| if (iter == nodes_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" << pre_node->fullname_with_scope() | |||
| << "] is not init."; | |||
| } | |||
| auto pre_somas_node = iter->second; | |||
| // set clean output tensors | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||
| for (auto index : clean_output_indexs) { | |||
| if (index > pre_somas_node->output_tensors_.size()) { | |||
| MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope() | |||
| << "]'s outputs size " << pre_somas_node->output_tensors_.size(); | |||
| } | |||
| auto input_somas_tensor = pre_somas_node->output_tensors_[index]; | |||
| MS_EXCEPTION_IF_NULL(input_somas_tensor); | |||
| node->input_tensors_.push_back(input_somas_tensor); | |||
| input_somas_tensor->destinations_.insert(node); | |||
| input_somas_tensor->destinationStreams_.insert(stream); | |||
| if (input_somas_tensor->lifetime_.start_ > node->GetId()) { | |||
| input_somas_tensor->lifetime_.start_ = node->GetId(); | |||
| } | |||
| node->ancestor_nodes_.insert(pre_somas_node); | |||
| auto input_tensor_stream = input_somas_tensor->GetSourceStream(); | |||
| if (input_tensor_stream != stream) { | |||
| stream->ancestor_streams_.insert(input_tensor_stream); | |||
| input_somas_tensor->between_streams_ = true; | |||
| } | |||
| auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| for (size_t i = 0; i < input_tensor_num; i++) { | |||
| MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]); | |||
| auto pre_node = kernel->input(i + 1)->cast<CNodePtr>(); | |||
| auto iter = nodes_map_.find(pre_node.get()); | |||
| if (iter == nodes_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" | |||
| << pre_node->fullname_with_scope() << "] is not init."; | |||
| } | |||
| } | |||
| // set clean workspace tensors | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||
| auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||
| for (const auto &index : clean_workspace_indexs) { | |||
| if (index > pre_somas_node->output_tensors_.size()) { | |||
| MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope() | |||
| << "]'s Workspace size " << pre_somas_node->workspace_tensors_.size(); | |||
| } | |||
| auto input_somas_tensor = pre_somas_node->workspace_tensors_[index]; | |||
| MS_EXCEPTION_IF_NULL(input_somas_tensor); | |||
| node->input_tensors_.push_back(input_somas_tensor); | |||
| input_somas_tensor->destinations_.insert(node); | |||
| input_somas_tensor->destinationStreams_.insert(stream); | |||
| if (input_somas_tensor->lifetime_.start_ > node->GetId()) { | |||
| input_somas_tensor->lifetime_.start_ = node->GetId(); | |||
| auto pre_somas_node = iter->second; | |||
| // set clean output tensors | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||
| for (auto index : clean_output_indexs) { | |||
| if (index > pre_somas_node->output_tensors_.size()) { | |||
| MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope() | |||
| << "]'s outputs size " << pre_somas_node->output_tensors_.size(); | |||
| } | |||
| auto input_somas_tensor = pre_somas_node->output_tensors_[index]; | |||
| MS_EXCEPTION_IF_NULL(input_somas_tensor); | |||
| node->input_tensors_.push_back(input_somas_tensor); | |||
| input_somas_tensor->lifelong_value_ = kLifeLongGraphAll; | |||
| MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_ | |||
| << " 's output" << index << " to lifelong"; | |||
| } | |||
| node->ancestor_nodes_.insert(pre_somas_node); | |||
| auto input_tensor_stream = input_somas_tensor->GetSourceStream(); | |||
| if (input_tensor_stream != stream) { | |||
| stream->ancestor_streams_.insert(input_tensor_stream); | |||
| input_somas_tensor->between_streams_ = true; | |||
| } | |||
| // set clean workspace tensors | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||
| auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||
| for (const auto &index : clean_workspace_indexs) { | |||
| if (index > pre_somas_node->output_tensors_.size()) { | |||
| MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope() | |||
| << "]'s Workspace size " << pre_somas_node->workspace_tensors_.size(); | |||
| } | |||
| auto input_somas_tensor = pre_somas_node->workspace_tensors_[index]; | |||
| MS_EXCEPTION_IF_NULL(input_somas_tensor); | |||
| node->input_tensors_.push_back(input_somas_tensor); | |||
| input_somas_tensor->lifelong_value_ = kLifeLongGraphAll; | |||
| MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_ | |||
| << " 's workspace" << index << " to lifelong"; | |||
| } | |||
| } | |||
| } | |||
| @@ -40,6 +40,8 @@ namespace device { | |||
| namespace ascend { | |||
| using mindspore::kernel::tbe::TbeUtils; | |||
| using std::make_shared; | |||
| constexpr size_t kMaxAttrMemListSize = 192; | |||
| static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { | |||
| kernel::KernelModPtr kernel_mod_ptr = nullptr; | |||
| KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); | |||
| @@ -159,6 +161,30 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr | |||
| new_nodes->push_back(clear_zero); | |||
| } | |||
| static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph, | |||
| const mindspore::CNodePtr &stream_node, | |||
| const std::vector<AnfNodePtr> &fusion_clear_inputs, | |||
| const std::vector<size_t> &clean_size_list, | |||
| std::vector<mindspore::CNodePtr> *new_nodes) { | |||
| auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName); | |||
| MS_EXCEPTION_IF_NULL(clear_zero_prim); | |||
| auto new_value_node = NewValueNode(clear_zero_prim); | |||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||
| std::vector<AnfNodePtr> inputs = {new_value_node}; | |||
| inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end()); | |||
| CNodePtr clear_zero = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(clear_zero); | |||
| AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| clear_zero->set_abstract(abstract); | |||
| auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); | |||
| builder->SetKernelType(KernelType::TBE_KERNEL); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); | |||
| AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero); | |||
| AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get()); | |||
| new_nodes->insert(new_nodes->begin(), clear_zero); | |||
| } | |||
| static bool IsAtomicNode(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); | |||
| @@ -264,23 +290,23 @@ std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo( | |||
| return comm_input_info_map; | |||
| } | |||
| void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) { | |||
| std::vector<CNodePtr> new_nodes; | |||
| std::vector<size_t> clean_size_list; | |||
| std::vector<AnfNodePtr> fusion_clear_inputs; | |||
| std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); | |||
| for (const auto &anf_node : kernel_graph->execution_order()) { | |||
| std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); | |||
| bool is_comm_input = false; | |||
| // set communication input output index attr | |||
| if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { | |||
| auto indexes = comm_input_info_map[anf_node]; | |||
| AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); | |||
| is_comm_input = true; | |||
| } | |||
| if (is_comm_input) { | |||
| AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); | |||
| } else if (apply_function_name == prim::kPrimMaxPoolGrad->name() && | |||
| AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { | |||
| if (apply_function_name == prim::kPrimMaxPoolGrad->name() && | |||
| AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { | |||
| auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); | |||
| MS_EXCEPTION_IF_NULL(clear_zero_prim); | |||
| auto new_value_node = NewValueNode(clear_zero_prim); | |||
| @@ -299,15 +325,85 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { | |||
| // set the distinction label of clear same with anf | |||
| AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); | |||
| new_nodes.push_back(clear_zero); | |||
| } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { | |||
| if (IsAtomicNode(anf_node)) { | |||
| AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); | |||
| } else if (is_comm_input || | |||
| (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) { | |||
| auto clean_sizes = CalCleanZerosSize(anf_node); | |||
| if (!clean_sizes.empty()) { | |||
| auto clean_total_num = clean_size_list.size() + clean_sizes.size(); | |||
| if (clean_total_num >= kMaxAttrMemListSize) { | |||
| // create clean node | |||
| auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front(); | |||
| AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes); | |||
| clean_size_list.clear(); | |||
| fusion_clear_inputs.clear(); | |||
| } | |||
| clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end()); | |||
| fusion_clear_inputs.emplace_back(anf_node); | |||
| MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size() | |||
| << ", clean_size_list: " << clean_size_list.size(); | |||
| } | |||
| } | |||
| new_nodes.push_back(anf_node); | |||
| new_nodes.emplace_back(anf_node); | |||
| } | |||
| if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) { | |||
| // create clean node | |||
| auto stream_node = new_nodes.front(); | |||
| AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes); | |||
| } | |||
| kernel_graph->set_execution_order(new_nodes); | |||
| } | |||
| void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1"); | |||
| bool is_dynamic_graph = kernel_graph->is_dynamic_shape(); | |||
| if (!is_dynamic_graph && enable_fusion_clear) { | |||
| TbeClearZeroNodeFusion(kernel_graph); | |||
| } else { | |||
| std::vector<CNodePtr> new_nodes; | |||
| std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); | |||
| for (const auto &anf_node : kernel_graph->execution_order()) { | |||
| std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); | |||
| bool is_comm_input = false; | |||
| if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { | |||
| auto indexes = comm_input_info_map[anf_node]; | |||
| AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); | |||
| is_comm_input = true; | |||
| } | |||
| if (is_comm_input) { | |||
| AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); | |||
| } else if (apply_function_name == prim::kPrimMaxPoolGrad->name() && | |||
| AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { | |||
| auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); | |||
| MS_EXCEPTION_IF_NULL(clear_zero_prim); | |||
| auto new_value_node = NewValueNode(clear_zero_prim); | |||
| MS_EXCEPTION_IF_NULL(new_value_node); | |||
| std::vector<AnfNodePtr> inputs = {new_value_node}; | |||
| inputs.push_back(anf_node); | |||
| CNodePtr clear_zero = kernel_graph->NewCNode(inputs); | |||
| MS_EXCEPTION_IF_NULL(clear_zero); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| clear_zero->set_kernel_info(kernel_info); | |||
| AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>(); | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero); | |||
| SelectKernelInfo(clear_zero); | |||
| // set the distinction label of clear same with anf | |||
| AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); | |||
| new_nodes.push_back(clear_zero); | |||
| } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { | |||
| if (IsAtomicNode(anf_node)) { | |||
| AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); | |||
| } | |||
| } | |||
| new_nodes.push_back(anf_node); | |||
| } | |||
| kernel_graph->set_execution_order(new_nodes); | |||
| } | |||
| } | |||
| } // namespace ascend | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -92,42 +92,48 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre | |||
| void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { | |||
| MS_EXCEPTION_IF_NULL(anf_node_ptr); | |||
| MS_EXCEPTION_IF_NULL(kernel_inputs); | |||
| if (anf_node_ptr->inputs().size() != 2) { | |||
| // akg process | |||
| if (AnfAlgo::GetKernelType(anf_node_ptr) == KernelType::AKG_KERNEL) { | |||
| LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs); | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); | |||
| auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); | |||
| // set clean output addr | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||
| for (auto index : clean_output_indexs) { | |||
| auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); | |||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| input->addr = device_address->ptr_; | |||
| MS_EXCEPTION_IF_NULL(input->addr); | |||
| input->size = device_address->size_; | |||
| kernel_inputs->push_back(input); | |||
| // tbe process | |||
| auto input_tensor_num = AnfAlgo::GetInputTensorNum(anf_node_ptr); | |||
| for (size_t i = 0; i < input_tensor_num; i++) { | |||
| // set clean output addr | |||
| MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[i + 1]); | |||
| auto pre_node = anf_node_ptr->input(i + 1)->cast<CNodePtr>(); | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||
| for (auto index : clean_output_indexs) { | |||
| auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); | |||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| input->addr = device_address->ptr_; | |||
| MS_EXCEPTION_IF_NULL(input->addr); | |||
| input->size = device_address->size_; | |||
| kernel_inputs->push_back(input); | |||
| } | |||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | |||
| } | |||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | |||
| } | |||
| // set clean workspace address | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||
| auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||
| for (const auto &index : clean_workspace_indexs) { | |||
| auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); | |||
| kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(workspace); | |||
| workspace->addr = device_address->ptr_; | |||
| MS_EXCEPTION_IF_NULL(workspace->addr); | |||
| workspace->size = device_address->size_; | |||
| kernel_inputs->push_back(workspace); | |||
| // set clean workspace address | |||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||
| auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||
| for (const auto &index : clean_workspace_indexs) { | |||
| auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); | |||
| kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | |||
| MS_EXCEPTION_IF_NULL(workspace); | |||
| workspace->addr = device_address->ptr_; | |||
| MS_EXCEPTION_IF_NULL(workspace->addr); | |||
| workspace->size = device_address->size_; | |||
| kernel_inputs->push_back(workspace); | |||
| } | |||
| MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size(); | |||
| } | |||
| } | |||
| auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicAddMemSize); | |||
| if (kernel_inputs->size() != clear_mems.size()) { | |||
| MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:" | |||
| MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:" | |||
| << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); | |||
| } | |||
| } | |||