From d6dc3d6da638010adaeb943dc7aac550d5b0fc29 Mon Sep 17 00:00:00 2001 From: zengbin93 Date: Sat, 3 Aug 2019 20:34:35 +0800 Subject: [PATCH] add text cluster --- jiagu/__init__.py | 1 + jiagu/cluster/__init__.py | 2 +- jiagu/cluster/text.py | 57 +++++++++++++++++++++++++++++++++++++++ test/test_cluster.py | 30 +++++++++++++++++++++ 4 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 jiagu/cluster/text.py diff --git a/jiagu/__init__.py b/jiagu/__init__.py index 826a4d2..84e336e 100644 --- a/jiagu/__init__.py +++ b/jiagu/__init__.py @@ -9,6 +9,7 @@ * Description : """ from jiagu import analyze +from jiagu.cluster.text import text_cluster any = analyze.Analyze() diff --git a/jiagu/cluster/__init__.py b/jiagu/cluster/__init__.py index 31147e6..94e7efc 100644 --- a/jiagu/cluster/__init__.py +++ b/jiagu/cluster/__init__.py @@ -2,4 +2,4 @@ from .kmeans import KMeans from .dbscan import DBSCAN - +from .base import count_features diff --git a/jiagu/cluster/text.py b/jiagu/cluster/text.py new file mode 100644 index 0000000..60ea935 --- /dev/null +++ b/jiagu/cluster/text.py @@ -0,0 +1,57 @@ +# coding: utf-8 +from collections import OrderedDict + +from .base import count_features +from .dbscan import DBSCAN +from .kmeans import KMeans + + +def text_cluster(docs, method="k-means", k=None, max_iter=100, eps=None, min_pts=None): + """文本聚类,目前支持 K-Means 和 DBSCAN 两种方法 + + :param docs: list of str + 输入的文本列表,如 ['k-means', 'dbscan'] + :param method: str + 指定使用的方法,默认为 k-means,可选 'k-means', 'dbscan' + :param k: int + k-means 参数,类簇数量 + :param max_iter: int + k-means 参数,最大迭代次数,确保模型不收敛的情况下可以退出循环 + :param eps: float + dbscan 参数,邻域距离 + :param min_pts: + dbscan 参数,核心对象中的最少样本数量 + :return: OrderedDict + 聚类结果 + """ + features, names = count_features(docs) + + # feature to doc + f2d = {k: v.tolist() for k, v in zip(docs, features)} + + if method == 'k-means': + km = KMeans(k=k, max_iter=max_iter) + clusters = km.train(features) + + elif method == 'dbscan': + ds = DBSCAN(eps=eps, min_pts=min_pts) + clusters = ds.train(features) + + else: + raise ValueError("method invalid, please use 'k-means' or 'dbscan'") + + clusters_out = OrderedDict() + + for label, examples in clusters.items(): + c_docs = [] + for example in examples: + doc = [d for d, f in f2d.items() if list(example) == f] + c_docs.extend(doc) + + clusters_out[label] = list(set(c_docs)) + + return clusters_out + + + + diff --git a/test/test_cluster.py b/test/test_cluster.py index 06c3f1c..3e8b5bf 100644 --- a/test/test_cluster.py +++ b/test/test_cluster.py @@ -6,6 +6,7 @@ from pprint import pprint from jiagu.cluster.kmeans import KMeans from jiagu.cluster.dbscan import DBSCAN +from jiagu.cluster.text import text_cluster def load_dataset(): @@ -57,6 +58,20 @@ def show_dataset(): plt.show() +def load_docs(): + docs = [ + "百度深度学习中文情感分析工具Senta试用及在线测试", + "情感分析是自然语言处理里面一个热门话题", + "AI Challenger 2018 文本挖掘类竞赛相关解决方案及代码汇总", + "深度学习实践:从零开始做电影评论文本情感分析", + "BERT相关论文、文章和代码资源汇总", + "将不同长度的句子用BERT预训练模型编码,映射到一个固定长度的向量上", + "自然语言处理工具包spaCy介绍", + "现在可以快速测试一下spaCy的相关功能,我们以英文数据为例,spaCy目前主要支持英文和德文" + ] + return docs + + class TestCluster(unittest.TestCase): def test_a_kmeans(self): print("=" * 68, '\n') @@ -83,6 +98,21 @@ class TestCluster(unittest.TestCase): # self.assertEqual(len(clusters), 6) pprint({k: len(v) for k, v in clusters.items()}) + def test_c_text_cluster_by_kmeans(self): + print("=" * 68, '\n') + print("text_cluster_by_kmeans ... ") + docs = load_docs() + clusters = text_cluster(docs, method='k-means', k=3, max_iter=100) + self.assertTrue(len(clusters) == 3) + + def test_c_text_cluster_by_dbscan(self): + print("=" * 68, '\n') + print("text_cluster_by_dbscan ... ") + docs = load_docs() + clusters = text_cluster(docs, method='dbscan', eps=5, min_pts=1) + self.assertTrue(len(clusters) == 3) + if __name__ == '__main__': unittest.main() +