| @@ -130,28 +130,30 @@ static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *ke | |||||
| return tbe_ret && akg_ret; | return tbe_ret && akg_ret; | ||||
| } | } | ||||
| static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) { | |||||
| static std::vector<size_t> CalCleanZerosSize(const CNodePtr &pre_node) { | |||||
| MS_EXCEPTION_IF_NULL(pre_node); | MS_EXCEPTION_IF_NULL(pre_node); | ||||
| std::vector<int> clean_size_list; | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| std::vector<size_t> clean_size_list; | |||||
| // clean output | // clean output | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs); | |||||
| for (auto index : clean_output_indexs) { | |||||
| TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(pre_node, index); | |||||
| size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); | |||||
| std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(pre_node, index); | |||||
| auto size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); | |||||
| clean_size_list.push_back((size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||||
| auto output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||||
| auto output_men_size = kernel_mod->GetOutputSizeList(); | |||||
| for (auto index : output_indexs) { | |||||
| auto clean_item = (output_men_size.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; | |||||
| clean_size_list.emplace_back(clean_item); | |||||
| } | } | ||||
| } | } | ||||
| // clean workspace | // clean workspace | ||||
| auto workspaces_size = 0; | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { | |||||
| workspaces_size = AnfAlgo::GetNodeAttr<int>(pre_node, kAttrAutomicWorkspaceSize); | |||||
| clean_size_list.push_back(workspaces_size); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||||
| auto workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||||
| auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList(); | |||||
| for (const auto &index : workspace_indexs) { | |||||
| auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; | |||||
| clean_size_list.emplace_back(clean_item); | |||||
| } | |||||
| } | } | ||||
| MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ", workspace size:" << workspaces_size | |||||
| << ",pre_node:" << pre_node->fullname_with_scope(); | |||||
| MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope(); | |||||
| return clean_size_list; | return clean_size_list; | ||||
| } | } | ||||
| @@ -175,12 +177,12 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr | |||||
| builder->SetKernelType(KernelType::TBE_KERNEL); | builder->SetKernelType(KernelType::TBE_KERNEL); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); | ||||
| auto clean_size = CalCleanZerosSize(pre_node); | auto clean_size = CalCleanZerosSize(pre_node); | ||||
| AnfAlgo::SetNodeAttr(kAttrAutomicAddMemSize, MakeValue(clean_size), clear_zero); | |||||
| AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero); | |||||
| AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get()); | AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get()); | ||||
| new_nodes->push_back(clear_zero); | new_nodes->push_back(clear_zero); | ||||
| } | } | ||||
| bool IsAtomicNode(const CNodePtr &kernel_node) { | |||||
| static bool IsAtomicNode(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); | auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| @@ -188,40 +190,44 @@ bool IsAtomicNode(const CNodePtr &kernel_node) { | |||||
| if (parameters_indexs.empty()) { | if (parameters_indexs.empty()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto atomic_flag = false; | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | ||||
| auto workspace_size_list = kernel_mod->GetWorkspaceSizeList(); | |||||
| size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); | size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); | ||||
| if (input_num + workspace_num + output_num > parameters_indexs.size()) { | |||||
| size_t lossNum = (input_num + workspace_num + output_num) - parameters_indexs.size(); | |||||
| for (size_t i = 0; i < lossNum; i++) { | |||||
| parameters_indexs.push_back(0); | |||||
| } | |||||
| size_t param_num = parameters_indexs.size(); | |||||
| size_t total_num = input_num + workspace_num + output_num; | |||||
| MS_LOG(INFO) << "parameters size: " << param_num << ", input & workspace & output num: " << total_num; | |||||
| size_t pad_index = param_num; | |||||
| for (; pad_index < total_num; ++pad_index) { | |||||
| parameters_indexs.emplace_back(0); | |||||
| } | } | ||||
| std::vector<size_t> clean_output_indexs; | |||||
| // in parameters data sort as input->workspace->output | |||||
| size_t index = 0; | |||||
| while (index < output_num) { | |||||
| if (parameters_indexs[input_num + workspace_num + index] == 1) { | |||||
| atomic_flag = true; | |||||
| clean_output_indexs.push_back(index); | |||||
| // process input | |||||
| for (size_t j = 0; j < input_num; ++j) { | |||||
| if (parameters_indexs.at(j) == 1) { | |||||
| MS_LOG(EXCEPTION) << "Atomic addr clean does't support clean input address, input index: " << j; | |||||
| } | } | ||||
| index++; | |||||
| } | } | ||||
| if (atomic_flag) { | |||||
| AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(clean_output_indexs), kernel_node); | |||||
| // process output | |||||
| std::vector<size_t> output_indexs; | |||||
| for (size_t i = 0; i < output_num; ++i) { | |||||
| auto param_output = parameters_indexs.at(input_num + workspace_num + i); | |||||
| if (param_output == 1) { | |||||
| output_indexs.emplace_back(i); | |||||
| MS_LOG(INFO) << "Atomic clear output index: " << i; | |||||
| } | |||||
| } | } | ||||
| for (size_t i = 0; i < workspace_num; ++i) { | |||||
| if (parameters_indexs[input_num + i] == 1) { | |||||
| atomic_flag = true; | |||||
| AnfAlgo::SetNodeAttr(kAttrAutomicWorkspaceSize, | |||||
| MakeValue(std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), 0)), | |||||
| kernel_node); | |||||
| break; | |||||
| AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node); | |||||
| // process workspace | |||||
| std::vector<size_t> workspace_indexs; | |||||
| for (size_t k = 0; k < workspace_num; ++k) { | |||||
| auto param_workspace = parameters_indexs.at(input_num + k); | |||||
| if (param_workspace == 1) { | |||||
| workspace_indexs.emplace_back(k); | |||||
| MS_LOG(INFO) << "Atomic clear workspace index: " << k; | |||||
| } | } | ||||
| } | } | ||||
| return atomic_flag; | |||||
| AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node); | |||||
| return !(workspace_indexs.empty() && output_indexs.empty()); | |||||
| } | } | ||||
| bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { | bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { | ||||
| @@ -45,8 +45,8 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP | |||||
| if (anf_node_ptr->inputs().size() != 2) { | if (anf_node_ptr->inputs().size() != 2) { | ||||
| // akg process | // akg process | ||||
| // set atomic clean addr | // set atomic clean addr | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, anf_node_ptr)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAutomicOutputIndexs); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, anf_node_ptr)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicOutputIndexs); | |||||
| auto graph = anf_node_ptr->func_graph(); | auto graph = anf_node_ptr->func_graph(); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto manager = graph->manager(); | auto manager = graph->manager(); | ||||
| @@ -78,8 +78,8 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP | |||||
| MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); | MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); | ||||
| auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); | auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); | ||||
| // set clean output addr | // set clean output addr | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||||
| for (auto index : clean_output_indexs) { | for (auto index : clean_output_indexs) { | ||||
| auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); | auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); | ||||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | kernel::AddressPtr input = std::make_shared<kernel::Address>(); | ||||
| @@ -92,10 +92,10 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP | |||||
| MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | ||||
| } | } | ||||
| // set clean workspace address | // set clean workspace address | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { | |||||
| auto clean_workspaces = AnfAlgo::GetNodeAttr<int>(pre_node, kAttrAutomicWorkspaceSize); | |||||
| if (clean_workspaces != 0) { | |||||
| auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, 0); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||||
| auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||||
| for (const auto &index : clean_workspace_indexs) { | |||||
| auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); | |||||
| kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | ||||
| MS_EXCEPTION_IF_NULL(workspace); | MS_EXCEPTION_IF_NULL(workspace); | ||||
| workspace->addr = device_address->ptr_; | workspace->addr = device_address->ptr_; | ||||
| @@ -103,9 +103,8 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP | |||||
| workspace->size = device_address->size_; | workspace->size = device_address->size_; | ||||
| kernel_inputs->push_back(workspace); | kernel_inputs->push_back(workspace); | ||||
| } | } | ||||
| MS_LOG(INFO) << "AtomicAddClean clean workspace size" << clean_workspaces; | |||||
| } | } | ||||
| auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<int>>(anf_node_ptr, kAttrAutomicAddMemSize); | |||||
| auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicAddMemSize); | |||||
| if (kernel_inputs->size() != clear_mems.size()) { | if (kernel_inputs->size() != clear_mems.size()) { | ||||
| MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:" | MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:" | ||||
| << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); | << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); | ||||
| @@ -676,8 +676,8 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||||
| MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); | MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); | ||||
| auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>(); | auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>(); | ||||
| // set clean output address | // set clean output address | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { | |||||
| auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); | |||||
| for (auto index : clean_output_indexs) { | for (auto index : clean_output_indexs) { | ||||
| auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); | auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); | ||||
| kernel::AddressPtr input = std::make_shared<kernel::Address>(); | kernel::AddressPtr input = std::make_shared<kernel::Address>(); | ||||
| @@ -690,10 +690,10 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||||
| MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); | ||||
| } | } | ||||
| // set clean workspace address | // set clean workspace address | ||||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) { | |||||
| auto clean_workspaces = AnfAlgo::GetNodeAttr<int>(pre_node, kAttrAutomicWorkspaceSize); | |||||
| if (clean_workspaces != 0) { | |||||
| auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, 0); | |||||
| if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { | |||||
| auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); | |||||
| for (const auto &index : clean_workspaces_indexs) { | |||||
| auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); | |||||
| kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); | ||||
| MS_EXCEPTION_IF_NULL(workspace); | MS_EXCEPTION_IF_NULL(workspace); | ||||
| workspace->addr = device_address->ptr_; | workspace->addr = device_address->ptr_; | ||||
| @@ -701,7 +701,6 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||||
| workspace->size = device_address->size_; | workspace->size = device_address->size_; | ||||
| kernel_inputs->emplace_back(workspace); | kernel_inputs->emplace_back(workspace); | ||||
| } | } | ||||
| MS_LOG(INFO) << "AtomicAddClean clean workspace size" << clean_workspaces; | |||||
| } | } | ||||
| } | } | ||||
| @@ -70,50 +70,6 @@ const std::unordered_map<std::string, FusionType> fusion_type_maps = { | |||||
| {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, | {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, | ||||
| }; | }; | ||||
| bool IsAtomicNode(const CNodePtr &kernel_node) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| auto parameters_indexs = kernel_mod->GenParameters(); | |||||
| if (parameters_indexs.empty()) { | |||||
| return false; | |||||
| } | |||||
| auto atomic_flag = false; | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| auto workspace_size_list = kernel_mod->GetWorkspaceSizeList(); | |||||
| size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); | |||||
| if (input_num + workspace_num + output_num > parameters_indexs.size()) { | |||||
| size_t lossNum = (input_num + workspace_num + output_num) - parameters_indexs.size(); | |||||
| for (size_t i = 0; i < lossNum; i++) { | |||||
| parameters_indexs.push_back(0); | |||||
| } | |||||
| } | |||||
| std::vector<int> clean_output_indexs; | |||||
| // in parameters data sort as input->workspace->output | |||||
| size_t index = 0; | |||||
| while (index < output_num) { | |||||
| if (parameters_indexs[input_num + workspace_num + index] == 1) { | |||||
| atomic_flag = true; | |||||
| clean_output_indexs.push_back(SizeToInt(index)); | |||||
| } | |||||
| index++; | |||||
| } | |||||
| if (atomic_flag) { | |||||
| AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(clean_output_indexs), kernel_node); | |||||
| } | |||||
| for (size_t i = 0; i < workspace_num; ++i) { | |||||
| if (parameters_indexs[input_num + i] == 1) { | |||||
| atomic_flag = true; | |||||
| AnfAlgo::SetNodeAttr(kAttrAutomicWorkspaceSize, | |||||
| MakeValue(std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), 0)), | |||||
| kernel_node); | |||||
| break; | |||||
| } | |||||
| } | |||||
| return atomic_flag; | |||||
| } | |||||
| void KernelMeta::Initialize() { | void KernelMeta::Initialize() { | ||||
| kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; | kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; | ||||
| // remove old kernel cache | // remove old kernel cache | ||||
| @@ -65,6 +65,7 @@ constexpr auto kVTypeBool = "bool"; | |||||
| constexpr auto kVTypeFloat = "float"; | constexpr auto kVTypeFloat = "float"; | ||||
| constexpr auto kVTypeListInt = "listInt"; | constexpr auto kVTypeListInt = "listInt"; | ||||
| constexpr auto kVTypeInt32 = "Int32"; | constexpr auto kVTypeInt32 = "Int32"; | ||||
| constexpr auto kVTypeListUInt64 = "listUInt64"; | |||||
| constexpr auto kVTypeListFloat = "listFloat"; | constexpr auto kVTypeListFloat = "listFloat"; | ||||
| constexpr auto kVTypeListListInt = "listListInt"; | constexpr auto kVTypeListListInt = "listListInt"; | ||||
| constexpr auto kJValue = "value"; | constexpr auto kJValue = "value"; | ||||
| @@ -443,6 +444,9 @@ void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspo | |||||
| attr_value = GetValue<std::vector<float>>(value); | attr_value = GetValue<std::vector<float>>(value); | ||||
| } | } | ||||
| (*attr_obj)[kJValue] = attr_value; | (*attr_obj)[kJValue] = attr_value; | ||||
| } else if (type == kVTypeListUInt64) { | |||||
| auto attr_value = GetValue<std::vector<size_t>>(value); | |||||
| (*attr_obj)[kJValue] = attr_value; | |||||
| } else if (type == kVTypeListListInt) { | } else if (type == kVTypeListListInt) { | ||||
| auto attr_value = GetValue<std::vector<std::vector<int>>>(value); | auto attr_value = GetValue<std::vector<std::vector<int>>>(value); | ||||
| (*attr_obj)[kJValue] = attr_value; | (*attr_obj)[kJValue] = attr_value; | ||||
| @@ -70,8 +70,8 @@ CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr<session::KernelGraph> &k | |||||
| builder->SetKernelType(KernelType::TBE_KERNEL); | builder->SetKernelType(KernelType::TBE_KERNEL); | ||||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clean_zero.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clean_zero.get()); | ||||
| auto clean_size = CalCleanSize(pre_node); | auto clean_size = CalCleanSize(pre_node); | ||||
| AnfAlgo::SetNodeAttr(kAttrAutomicAddMemSize, MakeValue(clean_size), clean_zero); | |||||
| AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(g_output_idx), clean_zero); | |||||
| AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clean_zero); | |||||
| AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(g_output_idx), clean_zero); | |||||
| AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clean_zero.get()); | AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clean_zero.get()); | ||||
| return clean_zero; | return clean_zero; | ||||
| } | } | ||||
| @@ -180,9 +180,9 @@ constexpr auto kAttrKeepDims = "keep_dims"; | |||||
| constexpr auto kAttrShapeGamma = "shape_gamma"; | constexpr auto kAttrShapeGamma = "shape_gamma"; | ||||
| constexpr auto kAttrPerm = "perm"; | constexpr auto kAttrPerm = "perm"; | ||||
| constexpr auto kAttrTransposeFirst = "transpose_first"; | constexpr auto kAttrTransposeFirst = "transpose_first"; | ||||
| constexpr auto kAttrAutomicAddMemSize = "automic_add_mem_size"; | |||||
| constexpr auto kAttrAutomicOutputIndexs = "atomic_output_clean_indexs"; | |||||
| constexpr auto kAttrAutomicWorkspaceSize = "atomic_workspace_clean_size"; | |||||
| constexpr auto kAttrAtomicAddMemSize = "automic_add_mem_size"; | |||||
| constexpr auto kAttrAtomicOutputIndexs = "atomic_output_clean_indexs"; | |||||
| constexpr auto kAttrAtomicWorkspaceIndexs = "atomic_workspace_clean_indexs"; | |||||
| constexpr auto kAttrSwitchCondition = "switch_condition"; | constexpr auto kAttrSwitchCondition = "switch_condition"; | ||||
| constexpr auto kAttrDataType = "data_type"; | constexpr auto kAttrDataType = "data_type"; | ||||
| constexpr auto kAttrActiveTarget = "active_target"; | constexpr auto kAttrActiveTarget = "active_target"; | ||||
| @@ -23,7 +23,7 @@ atomic_addr_clean_op_info = TBERegOp("AtomicAddrClean") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("atomic_addr_clean") \ | .kernel_name("atomic_addr_clean") \ | ||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .attr("automic_add_mem_size", "required", "listInt", "all") \ | |||||
| .attr("automic_add_mem_size", "required", "listUInt64", "all") \ | |||||
| .get_op_info() | .get_op_info() | ||||