|
|
|
@@ -19,13 +19,16 @@ |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
const size_t kUseBucketUniqueSize = 100000; |
|
|
|
constexpr size_t kBucketSortThreshold = 100000; |
|
|
|
void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
node_wpt_ = kernel_node; |
|
|
|
CheckParam(kernel_node); |
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
input_size_ = input_shape[0]; |
|
|
|
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); |
|
|
|
if (AnfAlgo::HasNodeAttr(SORTED, kernel_node)) { |
|
|
|
sorted_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, SORTED); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UniqueCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { |
|
|
|
@@ -41,9 +44,11 @@ bool UniqueCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
if (dtype_ == kNumberTypeInt32) { |
|
|
|
LaunchKernel<int, int>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeInt64) { |
|
|
|
LaunchKernel<int64_t, int>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32) { |
|
|
|
LaunchKernel<int64_t, int64_t>(inputs, workspace, outputs); |
|
|
|
} else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16) { |
|
|
|
LaunchKernel<float, int>(inputs, workspace, outputs); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Not support type: " << dtype_; |
|
|
|
} |
|
|
|
if (!node_wpt_.expired()) { |
|
|
|
auto node_ = node_wpt_.lock(); |
|
|
|
@@ -86,12 +91,18 @@ void UniqueCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const |
|
|
|
params->inverse_idx_ = reinterpret_cast<IndexType *>(outputs[1]->addr); |
|
|
|
params->input_size_ = input_size_; |
|
|
|
params->output_size_ = 0; |
|
|
|
params->need_sort_ = true; |
|
|
|
|
|
|
|
params->thread_num_ = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); |
|
|
|
if (input_size_ < kUseBucketUniqueSize) { |
|
|
|
Unique(params); |
|
|
|
if (sorted_) { |
|
|
|
params->need_sort_ = true; |
|
|
|
if (input_size_ < kBucketSortThreshold) { |
|
|
|
Unique(params); |
|
|
|
} else { |
|
|
|
BucketUnique(params); |
|
|
|
} |
|
|
|
} else { |
|
|
|
BucketUnique(params); |
|
|
|
params->need_sort_ = false; |
|
|
|
Unique(params); |
|
|
|
} |
|
|
|
output_size_ = params->output_size_; |
|
|
|
} |
|
|
|
|