diff --git a/requirements.txt b/requirements.txt index 3dfd400..73f5ab3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,4 +21,12 @@ metaphor_python==0.1.23 langchainhub==0.1.20 langchain-experimental==0.0.62 arxiv==2.1.3 -spark_ai_python~=0.4.4 \ No newline at end of file +spark_ai_python~=0.4.4 +colorama~=0.4.6 +tqdm~=4.66.4 +numpy~=1.26.4 +chardet~=5.2.0 +elasticsearch~=8.15.0 +tenacity~=8.5.0 +pymilvus~=2.4.5 +python-dateutil~=2.9.0post0 \ No newline at end of file diff --git a/src/mindpilot/app/api/api_server.py b/src/mindpilot/app/api/api_server.py index 560d2ea..ab092e6 100644 --- a/src/mindpilot/app/api/api_server.py +++ b/src/mindpilot/app/api/api_server.py @@ -7,6 +7,7 @@ from .tool_routes import tool_router from .agent_routes import agent_router from .config_routes import config_router from .conversation_routes import conversation_router +from .kb_routes import kb_router def create_app(run_mode: str = None): @@ -28,5 +29,6 @@ def create_app(run_mode: str = None): app.include_router(agent_router) app.include_router(config_router) app.include_router(conversation_router) + app.include_router(kb_router) return app diff --git a/src/mindpilot/app/api/kb_routes.py b/src/mindpilot/app/api/kb_routes.py new file mode 100644 index 0000000..999b0ec --- /dev/null +++ b/src/mindpilot/app/api/kb_routes.py @@ -0,0 +1,92 @@ +from __future__ import annotations +from typing import List +from fastapi import APIRouter, Request + +# from chatchat.server.chat.file_chat import upload_temp_docs +from ..knowledge_base.kb_api import create_kb, delete_kb, list_kbs +from ..knowledge_base.kb_doc_api import ( + delete_docs, + download_doc, + list_files, + recreate_vector_store, + search_docs, + update_docs, + update_info, + upload_docs, +) +# from chatchat.server.knowledge_base.kb_summary_api import ( +# recreate_summary_vector_store, +# summary_doc_ids_to_vector_store, +# summary_file_to_vector_store, +# ) +from ..utils.system_utils import BaseResponse, ListResponse + +kb_router = APIRouter(prefix="/knowledge_base", tags=["知识库管理"]) + + +kb_router.get( + "/list_knowledge_bases", response_model=ListResponse, summary="获取知识库列表" +)(list_kbs) + +kb_router.post( + "/create_knowledge_base", response_model=BaseResponse, summary="创建知识库" +)(create_kb) + +kb_router.post( + "/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库" +)(delete_kb) + +kb_router.get( + "/list_files", response_model=ListResponse, summary="获取知识库内的文件列表" +)(list_files) + +kb_router.post( + "/search_docs", response_model=List[dict], summary="搜索知识库" +)(search_docs) + +kb_router.post( + "/upload_docs", + response_model=BaseResponse, + summary="上传文件到知识库,并/或进行向量化", +)(upload_docs) + +kb_router.post( + "/delete_docs", response_model=BaseResponse, summary="删除知识库内指定文件" +)(delete_docs) + +kb_router.post( + "/update_info", response_model=BaseResponse, summary="更新知识库介绍" +)(update_info) + +kb_router.post( + "/update_docs", response_model=BaseResponse, summary="更新现有文件到知识库" +)(update_docs) + +kb_router.get( + "/download_doc", summary="下载对应的知识文件" +)(download_doc) + +kb_router.post( + "/recreate_vector_store", summary="根据content中文档重建向量库,流式输出处理进度。" +)(recreate_vector_store) + +# kb_router.post( +# "/upload_temp_docs", summary="上传文件到临时目录,用于文件对话。")(upload_temp_docs) +# +# +# summary_router = APIRouter(prefix="/kb_summary_api") +# summary_router.post( +# "/summary_file_to_vector_store", summary="单个知识库根据文件名称摘要" +# )(summary_file_to_vector_store) +# +# summary_router.post( +# "/summary_doc_ids_to_vector_store", +# summary="单个知识库根据doc_ids摘要", +# 
response_model=BaseResponse, +# )(summary_doc_ids_to_vector_store) +# +# summary_router.post( +# "/recreate_summary_vector_store", summary="重建单个知识库文件摘要" +# )(recreate_summary_vector_store) +# +# kb_router.include_router(summary_router) diff --git a/src/mindpilot/app/configs/__init__.py b/src/mindpilot/app/configs/__init__.py index f03d84c..a3328df 100644 --- a/src/mindpilot/app/configs/__init__.py +++ b/src/mindpilot/app/configs/__init__.py @@ -2,6 +2,7 @@ from .system_config import HOST, PORT from .model_config import MODEL_CONFIG from .prompt_config import OPENAI_PROMPT, PROMPT_TEMPLATES from .tool_config import TOOL_CONFIG +from .kb_config import * __all__ = [ "HOST", @@ -9,5 +10,27 @@ __all__ = [ "MODEL_CONFIG", "OPENAI_PROMPT", "TOOL_CONFIG", - "PROMPT_TEMPLATES" -] + "PROMPT_TEMPLATES", + "DEFAULT_KNOWLEDGE_BASE", + "DEFAULT_VS_TYPE", + "CACHED_VS_NUM", + "CACHED_MEMO_VS_NUM", + "CHUNK_SIZE", + "OVERLAP_SIZE", + "VECTOR_SEARCH_TOP_K", + "SCORE_THRESHOLD", + "DEFAULT_SEARCH_ENGINE", + "SEARCH_ENGINE_TOP_K", + "ZH_TITLE_ENHANCE", + # "PDF_OCR_THRESHOLD", + "KB_INFO", + "CHATCHAT_ROOT", + "KB_ROOT_PATH", + "DB_ROOT_PATH", + "SQLALCHEMY_DATABASE_URI", + "kbs_config", + "text_splitter_dict", + "TEXT_SPLITTER_NAME", + "EMBEDDING_KEYWORD_FILE", + "DEFAULT_EMBEDDING_MODEL", +] \ No newline at end of file diff --git a/src/mindpilot/app/configs/kb_config.py b/src/mindpilot/app/configs/kb_config.py new file mode 100644 index 0000000..ee74dc3 --- /dev/null +++ b/src/mindpilot/app/configs/kb_config.py @@ -0,0 +1,113 @@ +import os +from pathlib import Path + +DEFAULT_KNOWLEDGE_BASE = "samples" + +# 默认向量库/全文检索引擎类型。 +DEFAULT_VS_TYPE = "faiss" + +# 缓存向量库数量(针对FAISS) +CACHED_VS_NUM = 1 + +# 缓存临时向量库数量(针对FAISS),用于文件对话 +CACHED_MEMO_VS_NUM = 10 + +# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) +CHUNK_SIZE = 250 + +# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter) +OVERLAP_SIZE = 50 + +# 知识库匹配向量数量 +VECTOR_SEARCH_TOP_K = 3 + +# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右 +SCORE_THRESHOLD = 1 + +# 默认搜索引擎。可选:bing, duckduckgo, metaphor +DEFAULT_SEARCH_ENGINE = "bing" + +# 搜索引擎匹配结题数量 +SEARCH_ENGINE_TOP_K = 3 + +# 是否开启中文标题加强,以及标题增强的相关配置 +# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; +# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 +ZH_TITLE_ENHANCE = False + +# PDF OCR 控制:只对宽高超过页面一定比例(图片宽/页面宽,图片高/页面高)的图片进行 OCR。 +# 这样可以避免 PDF 中一些小图片的干扰,提高非扫描版 PDF 处理速度 +PDF_OCR_THRESHOLD = (0.6, 0.6) + +# 每个知识库的初始化介绍,用于在初始化知识库时显示和Agent调用,没写则没有介绍,不会被Agent调用。 +KB_INFO = { + "samples": "关于本项目issue的解答", +} + +CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent.parent) + +KB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "knowledge_base") + +# 数据库默认存储路径。 + +DB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "mindpilot.db") +SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}" + +# 可选向量库类型及对应配置 +kbs_config = { + "faiss": {}, + "milvus": { + "host": "127.0.0.1", + "port": "19530", + "user": "", + "password": "", + "secure": False, + }, + "es": { + "host": "127.0.0.1", + "port": "9200", + "index_name": "test_index", + "user": "", + "password": "", + }, + "milvus_kwargs": { + "search_params": {"metric_type": "L2"}, # 在此处增加search_params + "index_params": { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + }, # 在此处增加index_params + }, +} + +# TextSplitter配置项,如果你不明白其中的含义,就不要修改。 +text_splitter_dict = { + "ChineseRecursiveTextSplitter": { + "source": "", # 选择tiktoken则使用openai的方法 "huggingface" + "tokenizer_name_or_path": "", + }, + "SpacyTextSplitter": { + "source": "huggingface", + "tokenizer_name_or_path": "gpt2", + }, + 
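+    # NOTE (editor's assumption): "source" selects how chunk length is measured
+    # when the splitter is built: "tiktoken" uses an OpenAI tiktoken encoding,
+    # "huggingface" loads the HF tokenizer named in "tokenizer_name_or_path",
+    # and "" falls back to plain character counting. The factory that consumes
+    # this dict is not part of this diff, so treat this mapping as a sketch.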
"RecursiveCharacterTextSplitter": { + "source": "tiktoken", + "tokenizer_name_or_path": "cl100k_base", + }, + "MarkdownHeaderTextSplitter": { + "headers_to_split_on": [ + ("#", "head1"), + ("##", "head2"), + ("###", "head3"), + ("####", "head4"), + ] + }, +} + +# TEXT_SPLITTER 名称 +TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter" + +# Embedding模型定制词语的词表文件 +EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt" + +DEFAULT_EMBEDDING_MODEL = "bce-embedding-base_v1" diff --git a/src/mindpilot/app/knowledge_base/__init__.py b/src/mindpilot/app/knowledge_base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/knowledge_base/db/__init__.py b/src/mindpilot/app/knowledge_base/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/knowledge_base/db/base.py b/src/mindpilot/app/knowledge_base/db/base.py new file mode 100644 index 0000000..8ff1b67 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/base.py @@ -0,0 +1,17 @@ +import json + +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base +from sqlalchemy.orm import sessionmaker + +from ...configs.kb_config import SQLALCHEMY_DATABASE_URI + + +engine = create_engine( + SQLALCHEMY_DATABASE_URI, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base: DeclarativeMeta = declarative_base() diff --git a/src/mindpilot/app/knowledge_base/db/models/__init__.py b/src/mindpilot/app/knowledge_base/db/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/knowledge_base/db/models/base.py b/src/mindpilot/app/knowledge_base/db/models/base.py new file mode 100644 index 0000000..17e96a4 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/models/base.py @@ -0,0 +1,17 @@ +from datetime import datetime + +from sqlalchemy import Column, DateTime, Integer, String + + +class BaseModel: + """ + 基础模型 + """ + + id = Column(Integer, primary_key=True, index=True, comment="主键ID") + create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间") + update_time = Column( + DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间" + ) + create_by = Column(String, default=None, comment="创建者") + update_by = Column(String, default=None, comment="更新者") diff --git a/src/mindpilot/app/knowledge_base/db/models/knowledge_base_model.py b/src/mindpilot/app/knowledge_base/db/models/knowledge_base_model.py new file mode 100644 index 0000000..fd7f790 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/models/knowledge_base_model.py @@ -0,0 +1,39 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel +from sqlalchemy import Column, DateTime, Integer, String, func + +from ..base import Base + + +class KnowledgeBaseModel(Base): + """ + 知识库模型 + """ + + __tablename__ = "knowledge_base" + id = Column(Integer, primary_key=True, autoincrement=True, comment="知识库ID") + kb_name = Column(String(50), comment="知识库名称") + kb_info = Column(String(200), comment="知识库简介(用于Agent)") + vs_type = Column(String(50), comment="向量库类型") + embed_model = Column(String(50), comment="嵌入模型名称") + file_count = Column(Integer, default=0, comment="文件数量") + create_time = Column(DateTime, default=func.now(), comment="创建时间") + + def __repr__(self): + return f"" + + +# 创建一个对应的 Pydantic 模型 +class KnowledgeBaseSchema(BaseModel): + id: int + kb_name: str + kb_info: Optional[str] + vs_type: Optional[str] + 
embed_model: Optional[str] + file_count: Optional[int] + create_time: Optional[datetime] + + class Config: + from_attributes = True # 确保可以从 ORM 实例进行验证 diff --git a/src/mindpilot/app/knowledge_base/db/models/knowledge_file_model.py b/src/mindpilot/app/knowledge_base/db/models/knowledge_file_model.py new file mode 100644 index 0000000..2b4d3ec --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/models/knowledge_file_model.py @@ -0,0 +1,42 @@ +from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func + +from ..base import Base + + +class KnowledgeFileModel(Base): + """ + 知识文件模型 + """ + + __tablename__ = "knowledge_file" + id = Column(Integer, primary_key=True, autoincrement=True, comment="知识文件ID") + file_name = Column(String(255), comment="文件名") + file_ext = Column(String(10), comment="文件扩展名") + kb_name = Column(String(50), comment="所属知识库名称") + document_loader_name = Column(String(50), comment="文档加载器名称") + text_splitter_name = Column(String(50), comment="文本分割器名称") + file_version = Column(Integer, default=1, comment="文件版本") + file_mtime = Column(Float, default=0.0, comment="文件修改时间") + file_size = Column(Integer, default=0, comment="文件大小") + custom_docs = Column(Boolean, default=False, comment="是否自定义docs") + docs_count = Column(Integer, default=0, comment="切分文档数量") + create_time = Column(DateTime, default=func.now(), comment="创建时间") + + def __repr__(self): + return f"" + + +class FileDocModel(Base): + """ + 文件-向量库文档模型 + """ + + __tablename__ = "file_doc" + id = Column(Integer, primary_key=True, autoincrement=True, comment="ID") + kb_name = Column(String(50), comment="知识库名称") + file_name = Column(String(255), comment="文件名称") + doc_id = Column(String(50), comment="向量库文档ID") + meta_data = Column(JSON, default={}) + + def __repr__(self): + return f"" diff --git a/src/mindpilot/app/knowledge_base/db/models/knowledge_metadata_model.py b/src/mindpilot/app/knowledge_base/db/models/knowledge_metadata_model.py new file mode 100644 index 0000000..5372e3d --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/models/knowledge_metadata_model.py @@ -0,0 +1,31 @@ +from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func + +from ..base import Base + + +class SummaryChunkModel(Base): + """ + chunk summary模型,用于存储file_doc中每个doc_id的chunk 片段, + 数据来源: + 用户输入: 用户上传文件,可填写文件的描述,生成的file_doc中的doc_id,存入summary_chunk中 + 程序自动切分 对file_doc表meta_data字段信息中存储的页码信息,按每页的页码切分,自定义prompt生成总结文本,将对应页码关联的doc_id存入summary_chunk中 + 后续任务: + 矢量库构建: 对数据库表summary_chunk中summary_context创建索引,构建矢量库,meta_data为矢量库的元数据(doc_ids) + 语义关联: 通过用户输入的描述,自动切分的总结文本,计算 + 语义相似度 + + """ + + __tablename__ = "summary_chunk" + id = Column(Integer, primary_key=True, autoincrement=True, comment="ID") + kb_name = Column(String(50), comment="知识库名称") + summary_context = Column(String(255), comment="总结文本") + summary_id = Column(String(255), comment="总结矢量id") + doc_ids = Column(String(1024), comment="向量库id关联列表") + meta_data = Column(JSON, default={}) + + def __repr__(self): + return ( + f"" + ) diff --git a/src/mindpilot/app/knowledge_base/db/repository/__init__.py b/src/mindpilot/app/knowledge_base/db/repository/__init__.py new file mode 100644 index 0000000..e188d9e --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/repository/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_base_repository import * +from .knowledge_file_repository import * + diff --git a/src/mindpilot/app/knowledge_base/db/repository/knowledge_base_repository.py b/src/mindpilot/app/knowledge_base/db/repository/knowledge_base_repository.py new 
file mode 100644 index 0000000..376c140 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/repository/knowledge_base_repository.py @@ -0,0 +1,93 @@ +from ...db.models.knowledge_base_model import ( + KnowledgeBaseModel, + KnowledgeBaseSchema, +) +from ...db.session import with_session + + +@with_session +def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): + # 创建知识库实例 + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) + if not kb: + kb = KnowledgeBaseModel( + kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model + ) + session.add(kb) + else: # update kb with new vs_type and embed_model + kb.kb_info = kb_info + kb.vs_type = vs_type + kb.embed_model = embed_model + return True + + +@with_session +def list_kbs_from_db(session, min_file_count: int = -1): + kbs = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.file_count > min_file_count) + .all() + ) + kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs] + return kbs + + +@with_session +def kb_exists(session, kb_name): + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) + status = True if kb else False + return status + + +@with_session +def load_kb_from_db(session, kb_name): + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) + if kb: + kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model + else: + kb_name, vs_type, embed_model = None, None, None + return kb_name, vs_type, embed_model + + +@with_session +def delete_kb_from_db(session, kb_name): + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) + if kb: + session.delete(kb) + return True + + +@with_session +def get_kb_detail(session, kb_name: str) -> dict: + kb: KnowledgeBaseModel = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_name)) + .first() + ) + if kb: + return { + "kb_name": kb.kb_name, + "kb_info": kb.kb_info, + "vs_type": kb.vs_type, + "embed_model": kb.embed_model, + "file_count": kb.file_count, + "create_time": kb.create_time, + } + else: + return {} diff --git a/src/mindpilot/app/knowledge_base/db/repository/knowledge_file_repository.py b/src/mindpilot/app/knowledge_base/db/repository/knowledge_file_repository.py new file mode 100644 index 0000000..198b620 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/repository/knowledge_file_repository.py @@ -0,0 +1,245 @@ +from typing import Dict, List + +from ...db.models.knowledge_base_model import KnowledgeBaseModel +from ...db.models.knowledge_file_model import ( + FileDocModel, + KnowledgeFileModel, +) +from ...db.session import with_session +from ...utils import KnowledgeFile + + +@with_session +def list_file_num_docs_id_by_kb_name_and_file_name( + session, + kb_name: str, + file_name: str, +) -> List[int]: + """ + 列出某知识库某文件对应的所有Document的id。 + 返回形式:[str, ...] + """ + doc_ids = ( + session.query(FileDocModel.doc_id) + .filter_by(kb_name=kb_name, file_name=file_name) + .all() + ) + return [int(_id[0]) for _id in doc_ids] + + +@with_session +def list_docs_from_db( + session, + kb_name: str, + file_name: str = None, + metadata: Dict = {}, +) -> List[Dict]: + """ + 列出某知识库某文件对应的所有Document。 + 返回形式:[{"id": str, "metadata": dict}, ...] 
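+    Illustrative call (editor's note): list_docs_from_db(kb_name="samples",
+    file_name="test.md", metadata={"source": "a"}) filters on keys stored in the
+    JSON column FileDocModel.meta_data; the session argument is injected by the
+    @with_session decorator, so callers do not pass it.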
+ """ + docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name)) + if file_name: + docs = docs.filter(FileDocModel.file_name.ilike(file_name)) + for k, v in metadata.items(): + docs = docs.filter(FileDocModel.meta_data[k].as_string() == str(v)) + + return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()] + + +@with_session +def delete_docs_from_db( + session, + kb_name: str, + file_name: str = None, +) -> List[Dict]: + """ + 删除某知识库某文件对应的所有Document,并返回被删除的Document。 + 返回形式:[{"id": str, "metadata": dict}, ...] + """ + docs = list_docs_from_db(kb_name=kb_name, file_name=file_name) + query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name)) + if file_name: + query = query.filter(FileDocModel.file_name.ilike(file_name)) + query.delete(synchronize_session=False) + session.commit() + return docs + + +@with_session +def add_docs_to_db(session, kb_name: str, file_name: str, doc_infos: List[Dict]): + """ + 将某知识库某文件对应的所有Document信息添加到数据库。 + doc_infos形式:[{"id": str, "metadata": dict}, ...] + """ + # ! 这里会出现doc_infos为None的情况,需要进一步排查 + if doc_infos is None: + print( + "输入的server.db.repository.knowledge_file_repository.add_docs_to_db的doc_infos参数为None" + ) + return False + for d in doc_infos: + obj = FileDocModel( + kb_name=kb_name, + file_name=file_name, + doc_id=d["id"], + meta_data=d["metadata"], + ) + session.add(obj) + return True + + +@with_session +def count_files_from_db(session, kb_name: str) -> int: + return ( + session.query(KnowledgeFileModel) + .filter(KnowledgeFileModel.kb_name.ilike(kb_name)) + .count() + ) + + +@with_session +def list_files_from_db(session, kb_name): + files = ( + session.query(KnowledgeFileModel) + .filter(KnowledgeFileModel.kb_name.ilike(kb_name)) + .all() + ) + docs = [f.file_name for f in files] + return docs + + +@with_session +def add_file_to_db( + session, + kb_file: KnowledgeFile, + docs_count: int = 0, + custom_docs: bool = False, + doc_infos: List[Dict] = [], # 形式:[{"id": str, "metadata": dict}, ...] 
+): + kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() + if kb: + # 如果已经存在该文件,则更新文件信息与版本号 + existing_file: KnowledgeFileModel = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + KnowledgeFileModel.file_name.ilike(kb_file.filename), + ) + .first() + ) + mtime = kb_file.get_mtime() + size = kb_file.get_size() + + if existing_file: + existing_file.file_mtime = mtime + existing_file.file_size = size + existing_file.docs_count = docs_count + existing_file.custom_docs = custom_docs + existing_file.file_version += 1 + # 否则,添加新文件 + else: + new_file = KnowledgeFileModel( + file_name=kb_file.filename, + file_ext=kb_file.ext, + kb_name=kb_file.kb_name, + document_loader_name=kb_file.document_loader_name, + text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter", + file_mtime=mtime, + file_size=size, + docs_count=docs_count, + custom_docs=custom_docs, + ) + kb.file_count += 1 + session.add(new_file) + add_docs_to_db( + kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos + ) + return True + + +@with_session +def delete_file_from_db(session, kb_file: KnowledgeFile): + existing_file = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.file_name.ilike(kb_file.filename), + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + ) + .first() + ) + if existing_file: + session.delete(existing_file) + delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename) + session.commit() + + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name)) + .first() + ) + if kb: + kb.file_count -= 1 + session.commit() + return True + + +@with_session +def delete_files_from_db(session, knowledge_base_name: str): + session.query(KnowledgeFileModel).filter( + KnowledgeFileModel.kb_name.ilike(knowledge_base_name) + ).delete(synchronize_session=False) + session.query(FileDocModel).filter( + FileDocModel.kb_name.ilike(knowledge_base_name) + ).delete(synchronize_session=False) + kb = ( + session.query(KnowledgeBaseModel) + .filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name)) + .first() + ) + if kb: + kb.file_count = 0 + + session.commit() + return True + + +@with_session +def file_exists_in_db(session, kb_file: KnowledgeFile): + existing_file = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.file_name.ilike(kb_file.filename), + KnowledgeFileModel.kb_name.ilike(kb_file.kb_name), + ) + .first() + ) + return True if existing_file else False + + +@with_session +def get_file_detail(session, kb_name: str, filename: str) -> dict: + file: KnowledgeFileModel = ( + session.query(KnowledgeFileModel) + .filter( + KnowledgeFileModel.file_name.ilike(filename), + KnowledgeFileModel.kb_name.ilike(kb_name), + ) + .first() + ) + if file: + return { + "kb_name": file.kb_name, + "file_name": file.file_name, + "file_ext": file.file_ext, + "file_version": file.file_version, + "document_loader": file.document_loader_name, + "text_splitter": file.text_splitter_name, + "create_time": file.create_time, + "file_mtime": file.file_mtime, + "file_size": file.file_size, + "custom_docs": file.custom_docs, + "docs_count": file.docs_count, + } + else: + return {} diff --git a/src/mindpilot/app/knowledge_base/db/repository/knowledge_metadata_repository.py b/src/mindpilot/app/knowledge_base/db/repository/knowledge_metadata_repository.py new file mode 100644 index 0000000..35731b9 --- /dev/null +++ 
b/src/mindpilot/app/knowledge_base/db/repository/knowledge_metadata_repository.py @@ -0,0 +1,77 @@ +from typing import Dict, List + +from ...db.models.knowledge_metadata_model import SummaryChunkModel +from ...db.session import with_session + + +@with_session +def list_summary_from_db( + session, + kb_name: str, + metadata: Dict = {}, +) -> List[Dict]: + """ + 列出某知识库chunk summary。 + 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] + """ + docs = session.query(SummaryChunkModel).filter( + SummaryChunkModel.kb_name.ilike(kb_name) + ) + + for k, v in metadata.items(): + docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v)) + + return [ + { + "id": x.id, + "summary_context": x.summary_context, + "summary_id": x.summary_id, + "doc_ids": x.doc_ids, + "metadata": x.metadata, + } + for x in docs.all() + ] + + +@with_session +def delete_summary_from_db(session, kb_name: str) -> List[Dict]: + """ + 删除知识库chunk summary,并返回被删除的Dchunk summary。 + 返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...] + """ + docs = list_summary_from_db(kb_name=kb_name) + query = session.query(SummaryChunkModel).filter( + SummaryChunkModel.kb_name.ilike(kb_name) + ) + query.delete(synchronize_session=False) + session.commit() + return docs + + +@with_session +def add_summary_to_db(session, kb_name: str, summary_infos: List[Dict]): + """ + 将总结信息添加到数据库。 + summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...] + """ + for summary in summary_infos: + obj = SummaryChunkModel( + kb_name=kb_name, + summary_context=summary["summary_context"], + summary_id=summary["summary_id"], + doc_ids=summary["doc_ids"], + meta_data=summary["metadata"], + ) + session.add(obj) + + session.commit() + return True + + +@with_session +def count_summary_from_db(session, kb_name: str) -> int: + return ( + session.query(SummaryChunkModel) + .filter(SummaryChunkModel.kb_name.ilike(kb_name)) + .count() + ) diff --git a/src/mindpilot/app/knowledge_base/db/session.py b/src/mindpilot/app/knowledge_base/db/session.py new file mode 100644 index 0000000..2c97765 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/db/session.py @@ -0,0 +1,48 @@ +from contextlib import contextmanager +from functools import wraps + +from sqlalchemy.orm import Session + +from .base import SessionLocal + + +@contextmanager +def session_scope() -> Session: + """上下文管理器用于自动获取 Session, 避免错误""" + session = SessionLocal() + try: + yield session + session.commit() + except: + session.rollback() + raise + finally: + session.close() + + +def with_session(f): + @wraps(f) + def wrapper(*args, **kwargs): + with session_scope() as session: + try: + result = f(session, *args, **kwargs) + session.commit() + return result + except: + session.rollback() + raise + + return wrapper + + +def get_db() -> SessionLocal: + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def get_db0() -> SessionLocal: + db = SessionLocal() + return db diff --git a/src/mindpilot/app/knowledge_base/embedding/localai_embeddings.py b/src/mindpilot/app/knowledge_base/embedding/localai_embeddings.py new file mode 100644 index 0000000..360c10e --- /dev/null +++ b/src/mindpilot/app/knowledge_base/embedding/localai_embeddings.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +import logging +import warnings +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +from langchain_community.utils.openai import is_openai_v1 +from langchain_core.embeddings import Embeddings +from 
langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names +from tenacity import ( + AsyncRetrying, + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from ...utils.system_utils import run_in_thread_pool + +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]: + import openai + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.Timeout) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.InternalServerError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any: + import openai + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + async_retrying = AsyncRetrying( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.Timeout) + | retry_if_exception_type(openai.APIError) + | retry_if_exception_type(openai.APIConnectionError) + | retry_if_exception_type(openai.RateLimitError) + | retry_if_exception_type(openai.InternalServerError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + def wrap(func: Callable) -> Callable: + async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: + async for _ in async_retrying: + return await func(*args, **kwargs) + raise AssertionError("this is unreachable") + + return wrapped_f + + return wrap + + +# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings +def _check_response(response: dict) -> dict: + if any([len(d.embedding) == 1 for d in response.data]): + import openai + + raise openai.APIError("LocalAI API returned an empty embedding") + return response + + +def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the embedding call.""" + retry_decorator = _create_retry_decorator(embeddings) + + @retry_decorator + def _embed_with_retry(**kwargs: Any) -> Any: + response = embeddings.client.create(**kwargs) + return _check_response(response) + + return _embed_with_retry(**kwargs) + + +async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the embedding call.""" + + @_async_retry_decorator(embeddings) + async def _async_embed_with_retry(**kwargs: Any) -> Any: + response = await embeddings.async_client.create(**kwargs) + return _check_response(response) + + return await _async_embed_with_retry(**kwargs) + + +class LocalAIEmbeddings(BaseModel, Embeddings): + """LocalAI embedding models. + + Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class + uses the ``openai`` Python package's ``openai.Embedding`` as its client. 
+ Thus, you should have the ``openai`` python package installed, and defeat + the environment variable ``OPENAI_API_KEY`` by setting to a random string. + You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI + service endpoint. + + Example: + .. code-block:: python + + from langchain_community.embeddings import LocalAIEmbeddings + openai = LocalAIEmbeddings( + openai_api_key="random-string", + openai_api_base="http://localhost:8080" + ) + + """ + + client: Any = Field(default=None, exclude=True) #: :meta private: + async_client: Any = Field(default=None, exclude=True) #: :meta private: + model: str = "text-embedding-ada-002" + deployment: str = model + openai_api_version: Optional[str] = None + openai_api_base: Optional[str] = Field(default=None, alias="base_url") + # to support explicit proxy for LocalAI + openai_proxy: Optional[str] = None + embedding_ctx_length: int = 8191 + """The maximum number of tokens to embed at once.""" + openai_api_key: Optional[str] = Field(default=None, alias="api_key") + openai_organization: Optional[str] = Field(default=None, alias="organization") + allowed_special: Union[Literal["all"], Set[str]] = set() + disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all" + chunk_size: int = 1000 + """Maximum number of texts to embed in each batch""" + max_retries: int = 6 + """Maximum number of retries to make when generating.""" + request_timeout: Union[float, Tuple[float, float], Any, None] = Field( + default=None, alias="timeout" + ) + """Timeout in seconds for the LocalAI request.""" + headers: Any = None + show_progress_bar: bool = False + """Whether to show a progress bar when embedding.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = get_pydantic_field_names(cls) + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + warnings.warn( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." 
+ ) + + values["model_kwargs"] = extra + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["openai_api_key"] = get_from_dict_or_env( + values, "openai_api_key", "OPENAI_API_KEY" + ) + values["openai_api_base"] = get_from_dict_or_env( + values, + "openai_api_base", + "OPENAI_API_BASE", + default="", + ) + values["openai_proxy"] = get_from_dict_or_env( + values, + "openai_proxy", + "OPENAI_PROXY", + default="", + ) + + default_api_version = "" + values["openai_api_version"] = get_from_dict_or_env( + values, + "openai_api_version", + "OPENAI_API_VERSION", + default=default_api_version, + ) + values["openai_organization"] = get_from_dict_or_env( + values, + "openai_organization", + "OPENAI_ORGANIZATION", + default="", + ) + try: + import openai + + if is_openai_v1(): + client_params = { + "api_key": values["openai_api_key"], + "organization": values["openai_organization"], + "base_url": values["openai_api_base"], + "timeout": values["request_timeout"], + "max_retries": values["max_retries"], + } + + if not values.get("client"): + values["client"] = openai.OpenAI(**client_params).embeddings + if not values.get("async_client"): + values["async_client"] = openai.AsyncOpenAI( + **client_params + ).embeddings + elif not values.get("client"): + values["client"] = openai.Embedding + else: + pass + except ImportError: + raise ImportError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + return values + + @property + def _invocation_params(self) -> Dict: + openai_args = { + "model": self.model, + "timeout": self.request_timeout, + "extra_headers": self.headers, + **self.model_kwargs, + } + if self.openai_proxy: + import openai + + openai.proxy = { + "http": self.openai_proxy, + "https": self.openai_proxy, + } # type: ignore[assignment] # noqa: E501 + return openai_args + + def _embedding_func(self, text: str, *, engine: str) -> List[float]: + """Call out to LocalAI's embedding endpoint.""" + # handle large input text + if self.model.endswith("001"): + # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + return ( + embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) + .data[0] + .embedding + ) + + async def _aembedding_func(self, text: str, *, engine: str) -> List[float]: + """Call out to LocalAI's embedding endpoint.""" + # handle large input text + if self.model.endswith("001"): + # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500 + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + return ( + ( + await async_embed_with_retry( + self, + input=[text], + **self._invocation_params, + ) + ) + .data[0] + .embedding + ) + + def embed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + """Call out to LocalAI's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. 
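+
+            Example (editor's sketch, assumes a reachable LocalAI-compatible
+            endpoint): LocalAIEmbeddings(openai_api_key="x",
+            openai_api_base="http://localhost:8080").embed_documents(["a", "b"])
+            returns two vectors; each text is embedded in a worker thread via
+            run_in_thread_pool and results are re-ordered by input position.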
+ """ + + # call _embedding_func for each text with multithreads + def task(seq, text): + return (seq, self._embedding_func(text, engine=self.deployment)) + + params = [{"seq": i, "text": text} for i, text in enumerate(texts)] + result = list(run_in_thread_pool(func=task, params=params)) + result = sorted(result, key=lambda x: x[0]) + return [x[1] for x in result] + + async def aembed_documents( + self, texts: List[str], chunk_size: Optional[int] = 0 + ) -> List[List[float]]: + """Call out to LocalAI's embedding endpoint async for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + embeddings = [] + for text in texts: + response = await self._aembedding_func(text, engine=self.deployment) + embeddings.append(response) + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Call out to LocalAI's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embedding = self._embedding_func(text, engine=self.deployment) + return embedding + + async def aembed_query(self, text: str) -> List[float]: + """Call out to LocalAI's embedding endpoint async for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embedding = await self._aembedding_func(text, engine=self.deployment) + return embedding diff --git a/src/mindpilot/app/knowledge_base/file_rag/__init__.py b/src/mindpilot/app/knowledge_base/file_rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/FilteredCSVloader.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/FilteredCSVloader.py new file mode 100644 index 0000000..c1cc6b6 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/FilteredCSVloader.py @@ -0,0 +1,87 @@ +## 指定制定列的csv文件加载器 + +import csv +from io import TextIOWrapper +from typing import Dict, List, Optional + +from langchain.docstore.document import Document +from langchain_community.document_loaders import CSVLoader +from langchain_community.document_loaders.helpers import detect_file_encodings + + +class FilteredCSVLoader(CSVLoader): + def __init__( + self, + file_path: str, + columns_to_read: List[str], + source_column: Optional[str] = None, + metadata_columns: List[str] = [], + csv_args: Optional[Dict] = None, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + ): + super().__init__( + file_path=file_path, + source_column=source_column, + metadata_columns=metadata_columns, + csv_args=csv_args, + encoding=encoding, + autodetect_encoding=autodetect_encoding, + ) + self.columns_to_read = columns_to_read + + def load(self) -> List[Document]: + """Load data into document objects.""" + + docs = [] + try: + with open(self.file_path, newline="", encoding=self.encoding) as csvfile: + docs = self.__read_file(csvfile) + except UnicodeDecodeError as e: + if self.autodetect_encoding: + detected_encodings = detect_file_encodings(self.file_path) + for encoding in detected_encodings: + try: + with open( + self.file_path, newline="", encoding=encoding.encoding + ) as csvfile: + docs = self.__read_file(csvfile) + break + except UnicodeDecodeError: + continue + else: + raise RuntimeError(f"Error loading {self.file_path}") from e + except Exception as e: + raise RuntimeError(f"Error loading 
{self.file_path}") from e + + return docs + + def __read_file(self, csvfile: TextIOWrapper) -> List[Document]: + docs = [] + csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore + for i, row in enumerate(csv_reader): + content = [] + for col in self.columns_to_read: + if col in row: + content.append(f"{col}:{str(row[col])}") + else: + raise ValueError( + f"Column '{self.columns_to_read[0]}' not found in CSV file." + ) + content = "\n".join(content) + # Extract the source if available + source = ( + row.get(self.source_column, None) + if self.source_column is not None + else self.file_path + ) + metadata = {"source": source, "row": i} + + for col in self.metadata_columns: + if col in row: + metadata[col] = row[col] + + doc = Document(page_content=content, metadata=metadata) + docs.append(doc) + + return docs diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/__init__.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/__init__.py new file mode 100644 index 0000000..eb99c82 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/__init__.py @@ -0,0 +1,4 @@ +from .mydocloader import RapidOCRDocLoader +from .myimgloader import RapidOCRLoader +from .mypdfloader import RapidOCRPDFLoader +from .mypptloader import RapidOCRPPTLoader diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mydocloader.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mydocloader.py new file mode 100644 index 0000000..82d71ed --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mydocloader.py @@ -0,0 +1,79 @@ +from typing import List + +import tqdm +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRDocLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def doc2text(filepath): + from io import BytesIO + + import numpy as np + from docx import Document, ImagePart + from docx.oxml.table import CT_Tbl + from docx.oxml.text.paragraph import CT_P + from docx.table import Table, _Cell + from docx.text.paragraph import Paragraph + from PIL import Image + from rapidocr_onnxruntime import RapidOCR + + ocr = RapidOCR() + doc = Document(filepath) + resp = "" + + def iter_block_items(parent): + from docx.document import Document + + if isinstance(parent, Document): + parent_elm = parent.element.body + elif isinstance(parent, _Cell): + parent_elm = parent._tc + else: + raise ValueError("RapidOCRDocLoader parse fail") + + for child in parent_elm.iterchildren(): + if isinstance(child, CT_P): + yield Paragraph(child, parent) + elif isinstance(child, CT_Tbl): + yield Table(child, parent) + + b_unit = tqdm.tqdm( + total=len(doc.paragraphs) + len(doc.tables), + desc="RapidOCRDocLoader block index: 0", + ) + for i, block in enumerate(iter_block_items(doc)): + b_unit.set_description("RapidOCRDocLoader block index: {}".format(i)) + b_unit.refresh() + if isinstance(block, Paragraph): + resp += block.text.strip() + "\n" + images = block._element.xpath(".//pic:pic") # 获取所有图片 + for image in images: + for img_id in image.xpath(".//a:blip/@r:embed"): # 获取图片id + part = doc.part.related_parts[ + img_id + ] # 根据图片id获取对应的图片 + if isinstance(part, ImagePart): + image = Image.open(BytesIO(part._blob)) + result, _ = ocr(np.array(image)) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + elif isinstance(block, Table): + for row in block.rows: + for cell in row.cells: + for paragraph in cell.paragraphs: + resp += 
paragraph.text.strip() + "\n" + b_unit.update(1) + return resp + + text = doc2text(self.file_path) + from unstructured.partition.text import partition_text + + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx") + docs = loader.load() + print(docs) diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/myimgloader.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/myimgloader.py new file mode 100644 index 0000000..f11b6c5 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/myimgloader.py @@ -0,0 +1,28 @@ +from typing import List + +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader + +from chatchat.server.file_rag.document_loaders.ocr import get_ocr + + +class RapidOCRLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def img2text(filepath): + resp = "" + ocr = get_ocr() + result, _ = ocr(filepath) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + return resp + + text = img2text(self.file_path) + from unstructured.partition.text import partition_text + + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg") + docs = loader.load() + print(docs) diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypdfloader.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypdfloader.py new file mode 100644 index 0000000..aa981a2 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypdfloader.py @@ -0,0 +1,102 @@ +from typing import List + +import cv2 +import numpy as np +import tqdm +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader +from PIL import Image + +from chatchat.configs import PDF_OCR_THRESHOLD +from chatchat.server.file_rag.document_loaders.ocr import get_ocr + + +class RapidOCRPDFLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def rotate_img(img, angle): + """ + img --image + angle --rotation angle + return--rotated img + """ + + h, w = img.shape[:2] + rotate_center = (w / 2, h / 2) + # 获取旋转矩阵 + # 参数1为旋转中心点; + # 参数2为旋转角度,正值-逆时针旋转;负值-顺时针旋转 + # 参数3为各向同性的比例因子,1.0原图,2.0变成原来的2倍,0.5变成原来的0.5倍 + M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0) + # 计算图像新边界 + new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0])) + new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1])) + # 调整旋转矩阵以考虑平移 + M[0, 2] += (new_w - w) / 2 + M[1, 2] += (new_h - h) / 2 + + rotated_img = cv2.warpAffine(img, M, (new_w, new_h)) + return rotated_img + + def pdf2text(filepath): + import fitz # pyMuPDF里面的fitz包,不要与pip install fitz混淆 + import numpy as np + + ocr = get_ocr() + doc = fitz.open(filepath) + resp = "" + + b_unit = tqdm.tqdm( + total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0" + ) + for i, page in enumerate(doc): + b_unit.set_description( + "RapidOCRPDFLoader context page index: {}".format(i) + ) + b_unit.refresh() + text = page.get_text("") + resp += text + "\n" + + img_list = page.get_image_info(xrefs=True) + for img in img_list: + if xref := img.get("xref"): + bbox = img["bbox"] + # 检查图片尺寸是否超过设定的阈值 + if (bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[ + 0 + ] or (bbox[3] - bbox[1]) / ( + page.rect.height + ) < PDF_OCR_THRESHOLD[1]: + continue + pix = fitz.Pixmap(doc, xref) + samples = pix.samples + if 
int(page.rotation) != 0: # 如果Page有旋转角度,则旋转图片 + img_array = np.frombuffer( + pix.samples, dtype=np.uint8 + ).reshape(pix.height, pix.width, -1) + tmp_img = Image.fromarray(img_array) + ori_img = cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR) + rot_img = rotate_img(img=ori_img, angle=360 - page.rotation) + img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR) + else: + img_array = np.frombuffer( + pix.samples, dtype=np.uint8 + ).reshape(pix.height, pix.width, -1) + + result, _ = ocr(img_array) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + + # 更新进度 + b_unit.update(1) + return resp + + text = pdf2text(self.file_path) + from unstructured.partition.text import partition_text + + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRPDFLoader(file_path="/Users/tonysong/Desktop/test.pdf") + docs = loader.load() + print(docs) diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypptloader.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypptloader.py new file mode 100644 index 0000000..7b00df0 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypptloader.py @@ -0,0 +1,66 @@ +from typing import List + +import tqdm +from langchain_community.document_loaders.unstructured import UnstructuredFileLoader + + +class RapidOCRPPTLoader(UnstructuredFileLoader): + def _get_elements(self) -> List: + def ppt2text(filepath): + from io import BytesIO + + import numpy as np + from PIL import Image + from pptx import Presentation + from rapidocr_onnxruntime import RapidOCR + + ocr = RapidOCR() + prs = Presentation(filepath) + resp = "" + + def extract_text(shape): + nonlocal resp + if shape.has_text_frame: + resp += shape.text.strip() + "\n" + if shape.has_table: + for row in shape.table.rows: + for cell in row.cells: + for paragraph in cell.text_frame.paragraphs: + resp += paragraph.text.strip() + "\n" + if shape.shape_type == 13: # 13 表示图片 + image = Image.open(BytesIO(shape.image.blob)) + result, _ = ocr(np.array(image)) + if result: + ocr_result = [line[1] for line in result] + resp += "\n".join(ocr_result) + elif shape.shape_type == 6: # 6 表示组合 + for child_shape in shape.shapes: + extract_text(child_shape) + + b_unit = tqdm.tqdm( + total=len(prs.slides), desc="RapidOCRPPTLoader slide index: 1" + ) + # 遍历所有幻灯片 + for slide_number, slide in enumerate(prs.slides, start=1): + b_unit.set_description( + "RapidOCRPPTLoader slide index: {}".format(slide_number) + ) + b_unit.refresh() + sorted_shapes = sorted( + slide.shapes, key=lambda x: (x.top, x.left) + ) # 从上到下、从左到右遍历 + for shape in sorted_shapes: + extract_text(shape) + b_unit.update(1) + return resp + + text = ppt2text(self.file_path) + from unstructured.partition.text import partition_text + + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx") + docs = loader.load() + print(docs) diff --git a/src/mindpilot/app/knowledge_base/file_rag/document_loaders/ocr.py b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/ocr.py new file mode 100644 index 0000000..4916028 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/document_loaders/ocr.py @@ -0,0 +1,21 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + try: + from rapidocr_paddle import RapidOCR + except ImportError: + from rapidocr_onnxruntime import RapidOCR + + +def get_ocr(use_cuda: bool = True) -> 
"RapidOCR": + try: + from rapidocr_paddle import RapidOCR + + ocr = RapidOCR( + det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda + ) + except ImportError: + from rapidocr_onnxruntime import RapidOCR + + ocr = RapidOCR() + return ocr diff --git a/src/mindpilot/app/knowledge_base/file_rag/retrievers/__init__.py b/src/mindpilot/app/knowledge_base/file_rag/retrievers/__init__.py new file mode 100644 index 0000000..e23ff11 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/retrievers/__init__.py @@ -0,0 +1,3 @@ +from ..retrievers.base import BaseRetrieverService +from ..retrievers.ensemble import EnsembleRetrieverService +from ..retrievers.vectorstore import VectorstoreRetrieverService diff --git a/src/mindpilot/app/knowledge_base/file_rag/retrievers/base.py b/src/mindpilot/app/knowledge_base/file_rag/retrievers/base.py new file mode 100644 index 0000000..6cda595 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/retrievers/base.py @@ -0,0 +1,24 @@ +from abc import ABCMeta, abstractmethod + +from langchain.vectorstores import VectorStore + + +class BaseRetrieverService(metaclass=ABCMeta): + def __init__(self, **kwargs): + self.do_init(**kwargs) + + @abstractmethod + def do_init(self, **kwargs): + pass + + @abstractmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + pass + + @abstractmethod + def get_relevant_documents(self, query: str): + pass diff --git a/src/mindpilot/app/knowledge_base/file_rag/retrievers/ensemble.py b/src/mindpilot/app/knowledge_base/file_rag/retrievers/ensemble.py new file mode 100644 index 0000000..00d2ae4 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/retrievers/ensemble.py @@ -0,0 +1,46 @@ +from langchain.retrievers import EnsembleRetriever +from langchain.vectorstores import VectorStore +from langchain_community.retrievers import BM25Retriever +from langchain_core.retrievers import BaseRetriever + +from .base import BaseRetrieverService + + +class EnsembleRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + faiss_retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"score_threshold": score_threshold, "k": top_k}, + ) + # TODO: 换个不用torch的实现方式 + # from cutword.cutword import Cutter + import jieba + + # cutter = Cutter() + docs = list(vectorstore.docstore._dict.values()) + bm25_retriever = BM25Retriever.from_documents( + docs, + preprocess_func=jieba.lcut_for_search, + ) + bm25_retriever.k = top_k + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5] + ) + return EnsembleRetrieverService(retriever=ensemble_retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[: self.top_k] diff --git a/src/mindpilot/app/knowledge_base/file_rag/retrievers/vectorstore.py b/src/mindpilot/app/knowledge_base/file_rag/retrievers/vectorstore.py new file mode 100644 index 0000000..c971ee3 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/retrievers/vectorstore.py @@ -0,0 +1,30 @@ +from langchain.vectorstores import VectorStore +from langchain_core.retrievers import BaseRetriever + +from .base import BaseRetrieverService + + +class 
VectorstoreRetrieverService(BaseRetrieverService): + def do_init( + self, + retriever: BaseRetriever = None, + top_k: int = 5, + ): + self.vs = None + self.top_k = top_k + self.retriever = retriever + + @staticmethod + def from_vectorstore( + vectorstore: VectorStore, + top_k: int, + score_threshold: int or float, + ): + retriever = vectorstore.as_retriever( + search_type="similarity_score_threshold", + search_kwargs={"score_threshold": score_threshold, "k": top_k}, + ) + return VectorstoreRetrieverService(retriever=retriever) + + def get_relevant_documents(self, query: str): + return self.retriever.get_relevant_documents(query)[: self.top_k] diff --git a/src/mindpilot/app/knowledge_base/file_rag/text_splitter/__init__.py b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/__init__.py new file mode 100644 index 0000000..c0e418a --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/__init__.py @@ -0,0 +1,4 @@ +from .ali_text_splitter import AliTextSplitter +from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter +from .chinese_text_splitter import ChineseTextSplitter +from .zh_title_enhance import zh_title_enhance diff --git a/src/mindpilot/app/knowledge_base/file_rag/text_splitter/ali_text_splitter.py b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/ali_text_splitter.py new file mode 100644 index 0000000..9def31a --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/ali_text_splitter.py @@ -0,0 +1,35 @@ +import re +from typing import List + +from langchain.text_splitter import CharacterTextSplitter + + +class AliTextSplitter(CharacterTextSplitter): + def __init__(self, pdf: bool = False, **kwargs): + super().__init__(**kwargs) + self.pdf = pdf + + def split_text(self, text: str) -> List[str]: + # use_document_segmentation参数指定是否用语义切分文档,此处采取的文档语义分割模型为达摩院开源的nlp_bert_document-segmentation_chinese-base,论文见https://arxiv.org/abs/2107.09278 + # 如果使用模型进行文档语义切分,那么需要安装modelscope[nlp]:pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html + # 考虑到使用了三个模型,可能对于低配置gpu不太友好,因此这里将模型load进cpu计算,有需要的话可以替换device为自己的显卡id + if self.pdf: + text = re.sub(r"\n{3,}", r"\n", text) + text = re.sub("\s", " ", text) + text = re.sub("\n\n", "", text) + try: + from modelscope.pipelines import pipeline + except ImportError: + raise ImportError( + "Could not import modelscope python package. " + "Please install modelscope with `pip install modelscope`. " + ) + + p = pipeline( + task="document-segmentation", + model="damo/nlp_bert_document-segmentation_chinese-base", + device="cpu", + ) + result = p(documents=text) + sent_list = [i for i in result["text"].split("\n\t") if i] + return sent_list diff --git a/src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_recursive_text_splitter.py b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_recursive_text_splitter.py new file mode 100644 index 0000000..cedbb3b --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_recursive_text_splitter.py @@ -0,0 +1,106 @@ +import logging +import re +from typing import Any, List, Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +logger = logging.getLogger(__name__) + + +def _split_text_with_regex_from_end( + text: str, separator: str, keep_separator: bool +) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. 
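+            # Illustrative example (editor's note): with the sentence separator
+            # "。|!|?" used by ChineseRecursiveTextSplitter below,
+            # re.split("(。|!|?)", "天气好。出门吧!") yields
+            # ["天气好", "。", "出门吧", "!", ""]; the zip below glues each piece
+            # back onto its trailing delimiter, and the odd-length check keeps any
+            # unpaired tail (empty tails are filtered out at the end).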
+ _splits = re.split(f"({separator})", text) + splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] + if len(_splits) % 2 == 1: + splits += _splits[-1:] + # splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = True, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or [ + "\n\n", + "\n", + "。|!|?", + "\.\s|\!\s|\?\s", + ";|;\s", + ",|,\s", + ] + self._is_separator_regex = is_separator_regex + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return [ + re.sub(r"\n{2,}", "\n", chunk.strip()) + for chunk in final_chunks + if chunk.strip() != "" + ] + + +if __name__ == "__main__": + text_splitter = ChineseRecursiveTextSplitter( + keep_separator=True, is_separator_regex=True, chunk_size=50, chunk_overlap=0 + ) + ls = [ + """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""", + ] + # text = """""" + for inum, text in enumerate(ls): + print(inum) + chunks = text_splitter.split_text(text) + for chunk in chunks: 
+ print(chunk) diff --git a/src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_text_splitter.py b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_text_splitter.py new file mode 100644 index 0000000..9d4e2e2 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_text_splitter.py @@ -0,0 +1,77 @@ +import re +from typing import List + +from langchain.text_splitter import CharacterTextSplitter + + +class ChineseTextSplitter(CharacterTextSplitter): + def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs): + super().__init__(**kwargs) + self.pdf = pdf + self.sentence_size = sentence_size + + def split_text1(self, text: str) -> List[str]: + if self.pdf: + text = re.sub(r"\n{3,}", "\n", text) + text = re.sub("\s", " ", text) + text = text.replace("\n\n", "") + sent_sep_pattern = re.compile( + '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))' + ) # del :; + sent_list = [] + for ele in sent_sep_pattern.split(text): + if sent_sep_pattern.match(ele) and sent_list: + sent_list[-1] += ele + elif ele: + sent_list.append(ele) + return sent_list + + def split_text(self, text: str) -> List[str]: ##此处需要进一步优化逻辑 + if self.pdf: + text = re.sub(r"\n{3,}", r"\n", text) + text = re.sub("\s", " ", text) + text = re.sub("\n\n", "", text) + + text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text) # 单字符断句符 + text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号 + text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号 + text = re.sub( + r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text + ) + # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号 + text = text.rstrip() # 段尾如果有多余的\n就去掉它 + # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。 + ls = [i for i in text.split("\n") if i] + for ele in ls: + if len(ele) > self.sentence_size: + ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele) + ele1_ls = ele1.split("\n") + for ele_ele1 in ele1_ls: + if len(ele_ele1) > self.sentence_size: + ele_ele2 = re.sub( + r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', + r"\1\n\2", + ele_ele1, + ) + ele2_ls = ele_ele2.split("\n") + for ele_ele2 in ele2_ls: + if len(ele_ele2) > self.sentence_size: + ele_ele3 = re.sub( + '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2 + ) + ele2_id = ele2_ls.index(ele_ele2) + ele2_ls = ( + ele2_ls[:ele2_id] + + [i for i in ele_ele3.split("\n") if i] + + ele2_ls[ele2_id + 1 :] + ) + ele_id = ele1_ls.index(ele_ele1) + ele1_ls = ( + ele1_ls[:ele_id] + + [i for i in ele2_ls if i] + + ele1_ls[ele_id + 1 :] + ) + + id = ls.index(ele) + ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :] + return ls diff --git a/src/mindpilot/app/knowledge_base/file_rag/text_splitter/zh_title_enhance.py b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/zh_title_enhance.py new file mode 100644 index 0000000..793e0ba --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/text_splitter/zh_title_enhance.py @@ -0,0 +1,100 @@ +import re + +from langchain.docstore.document import Document + + +def under_non_alpha_ratio(text: str, threshold: float = 0.5): + """Checks if the proportion of non-alpha characters in the text snippet exceeds a given + threshold. This helps prevent text like "-----------BREAK---------" from being tagged + as a title or narrative text. The ratio does not count spaces. 
+ + Parameters + ---------- + text + The input string to test + threshold + If the proportion of non-alpha characters exceeds this threshold, the function + returns False + """ + if len(text) == 0: + return False + + alpha_count = len([char for char in text if char.strip() and char.isalpha()]) + total_count = len([char for char in text if char.strip()]) + try: + ratio = alpha_count / total_count + return ratio < threshold + except: + return False + + +def is_possible_title( + text: str, + title_max_word_length: int = 20, + non_alpha_threshold: float = 0.5, +) -> bool: + """Checks to see if the text passes all of the checks for a valid title. + + Parameters + ---------- + text + The input text to check + title_max_word_length + The maximum number of words a title can contain + non_alpha_threshold + The minimum number of alpha characters the text needs to be considered a title + """ + + # 文本长度为0的话,肯定不是title + if len(text) == 0: + print("Not a title. Text is empty.") + return False + + # 文本中有标点符号,就不是title + ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z" + ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN) + if ENDS_IN_PUNCT_RE.search(text) is not None: + return False + + # 文本长度不能超过设定值,默认20 + # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it + # is less expensive and actual tokenization doesn't add much value for the length check + if len(text) > title_max_word_length: + return False + + # 文本中数字的占比不能太高,否则不是title + if under_non_alpha_ratio(text, threshold=non_alpha_threshold): + return False + + # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles + if text.endswith((",", ".", ",", "。")): + return False + + if text.isnumeric(): + print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore + return False + + # 开头的字符内应该有数字,默认5个字符内 + if len(text) < 5: + text_5 = text + else: + text_5 = text[:5] + alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5)))) + if not alpha_in_text_5: + return False + + return True + + +def zh_title_enhance(docs: Document) -> Document: + title = None + if len(docs) > 0: + for doc in docs: + if is_possible_title(doc.page_content): + doc.metadata["category"] = "cn_Title" + title = doc.page_content + elif title: + doc.page_content = f"下文与({title})有关。{doc.page_content}" + return docs + else: + print("文件不存在") diff --git a/src/mindpilot/app/knowledge_base/file_rag/utils.py b/src/mindpilot/app/knowledge_base/file_rag/utils.py new file mode 100644 index 0000000..9079f22 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/file_rag/utils.py @@ -0,0 +1,14 @@ +from .retrievers import ( + BaseRetrieverService, + EnsembleRetrieverService, + VectorstoreRetrieverService, +) + +Retrivals = { + "vectorstore": VectorstoreRetrieverService, + "ensemble": EnsembleRetrieverService, +} + + +def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService: + return Retrivals[type] diff --git a/src/mindpilot/app/knowledge_base/kb_api.py b/src/mindpilot/app/knowledge_base/kb_api.py new file mode 100644 index 0000000..5b9ad38 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_api.py @@ -0,0 +1,72 @@ +import urllib + +from fastapi import Body + +from .db.repository.knowledge_base_repository import list_kbs_from_db +from .kb_service.base import KBServiceFactory +from .utils import validate_kb_name +from ..utils.system_utils import BaseResponse, ListResponse, logger + + +def list_kbs(): + # Get List of Knowledge Base + return ListResponse(data=list_kbs_from_db()) + + +def create_kb( + knowledge_base_name: str 
= Body(..., examples=["samples"]), + vector_store_type: str = Body("faiss"), + kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"), + embed_model: str = Body(..., examples=["bce-embedding-base_v1", "bge-large-zh-v1.5"]), +) -> BaseResponse: + # Create selected knowledge base + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + if knowledge_base_name is None or knowledge_base_name.strip() == "": + return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is not None: + return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") + + kb = KBServiceFactory.get_service( + knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info + ) + try: + kb.create_kb() + except Exception as e: + msg = f"创建知识库出错: {e}" + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e + ) + return BaseResponse(code=500, msg=msg) + + return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") + + +def delete_kb( + knowledge_base_name: str = Body(..., examples=["samples"]), +) -> BaseResponse: + # Delete selected knowledge base + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + knowledge_base_name = urllib.parse.unquote(knowledge_base_name) + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + + try: + status = kb.clear_vs() + status = kb.drop_kb() + if status: + return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}") + except Exception as e: + msg = f"删除知识库时出现意外: {e}" + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e + ) + return BaseResponse(code=500, msg=msg) + + return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}") diff --git a/src/mindpilot/app/knowledge_base/kb_cache/base.py b/src/mindpilot/app/knowledge_base/kb_cache/base.py new file mode 100644 index 0000000..5cd54b6 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_cache/base.py @@ -0,0 +1,94 @@ +import threading +from collections import OrderedDict +from contextlib import contextmanager +from typing import Any, List, Tuple, Union + +from langchain_community.vectorstores import FAISS + + +class ThreadSafeObject: + def __init__( + self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None + ): + self._obj = obj + self._key = key + self._pool = pool + self._lock = threading.RLock() + self._loaded = threading.Event() + + def __repr__(self) -> str: + cls = type(self).__name__ + return f"<{cls}: key: {self.key}, obj: {self._obj}>" + + @property + def key(self): + return self._key + + @contextmanager + def acquire(self, owner: str = "", msg: str = "") -> FAISS: + owner = owner or f"thread {threading.get_native_id()}" + try: + self._lock.acquire() + if self._pool is not None: + self._pool._cache.move_to_end(self.key) + yield self._obj + finally: + self._lock.release() + + def start_loading(self): + self._loaded.clear() + + def finish_loading(self): + self._loaded.set() + + def wait_for_loading(self): + self._loaded.wait() + + @property + def obj(self): + return self._obj + + @obj.setter + def obj(self, val: Any): + self._obj = val + + +class CachePool: + def __init__(self, cache_num: int = -1): + self._cache_num = cache_num + self._cache = OrderedDict() + self.atomic = threading.RLock() + + def keys(self) -> List[str]: + return list(self._cache.keys()) + + def _check_count(self): + if 
isinstance(self._cache_num, int) and self._cache_num > 0: + while len(self._cache) > self._cache_num: + self._cache.popitem(last=False) + + def get(self, key: str) -> ThreadSafeObject: + if cache := self._cache.get(key): + cache.wait_for_loading() + return cache + + def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject: + self._cache[key] = obj + self._check_count() + return obj + + def pop(self, key: str = None) -> ThreadSafeObject: + if key is None: + return self._cache.popitem(last=False) + else: + return self._cache.pop(key, None) + + def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""): + cache = self.get(key) + if cache is None: + raise RuntimeError(f"请求的资源 {key} 不存在") + elif isinstance(cache, ThreadSafeObject): + self._cache.move_to_end(key) + return cache.acquire(owner=owner, msg=msg) + else: + return cache diff --git a/src/mindpilot/app/knowledge_base/kb_cache/faiss_cache.py b/src/mindpilot/app/knowledge_base/kb_cache/faiss_cache.py new file mode 100644 index 0000000..75ae859 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_cache/faiss_cache.py @@ -0,0 +1,211 @@ +import logging +import os + +from langchain_community.docstore.in_memory import InMemoryDocstore +from langchain.schema import Document + +from ...configs import CACHED_MEMO_VS_NUM, CACHED_VS_NUM, DEFAULT_EMBEDDING_MODEL +from .base import * +from ..utils import get_vs_path +from ...utils.system_utils import get_Embeddings + +logger = logging.getLogger() + +# patch FAISS to include doc id in Document.metadata +def _new_ds_search(self, search: str) -> Union[str, Document]: + if search not in self._dict: + return f"ID {search} not found." + else: + doc = self._dict[search] + if isinstance(doc, Document): + doc.metadata["id"] = search + return doc + + +InMemoryDocstore.search = _new_ds_search + + +class ThreadSafeFaiss(ThreadSafeObject): + def __repr__(self) -> str: + cls = type(self).__name__ + return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>" + + def docs_count(self) -> int: + return len(self._obj.docstore._dict) + + def save(self, path: str, create_path: bool = True): + with self.acquire(): + if not os.path.isdir(path) and create_path: + os.makedirs(path) + ret = self._obj.save_local(path) + logger.info(f"已将向量库 {self.key} 保存到磁盘") + return ret + + def clear(self): + ret = [] + with self.acquire(): + ids = list(self._obj.docstore._dict.keys()) + if ids: + ret = self._obj.delete(ids) + assert len(self._obj.docstore._dict) == 0 + logger.info(f"已将向量库 {self.key} 清空") + return ret + + +class _FaissPool(CachePool): + def new_vector_store( + self, + kb_name: str, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + ) -> FAISS: + # create an empty vector store + embeddings = get_Embeddings(embed_model=embed_model) + doc = Document(page_content="init", metadata={}) + vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = list(vector_store.docstore._dict.keys()) + vector_store.delete(ids) + return vector_store + + def new_temp_vector_store( + self, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + ) -> FAISS: + # create an empty vector store + embeddings = get_Embeddings(embed_model=embed_model) + doc = Document(page_content="init", metadata={}) + vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = list(vector_store.docstore._dict.keys()) + vector_store.delete(ids) + return vector_store + + def save_vector_store(self, kb_name: str, path: str = None): + if cache := self.get(kb_name): + return cache.save(path) 
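# Editor's note (illustrative sketch, not part of this diff): the CachePool defined above acts
# as a small LRU -- set() appends, acquire() moves the key to the end, and _check_count() pops
# the oldest entries once the pool exceeds cache_num. Assuming a pool capped at two entries:
#
#     pool = CachePool(cache_num=2)
#     pool.set("a", ThreadSafeObject("a", obj=1, pool=pool))
#     pool.set("b", ThreadSafeObject("b", obj=2, pool=pool))
#     pool.set("c", ThreadSafeObject("c", obj=3, pool=pool))  # "a" is evicted
#     pool.keys()  # -> ["b", "c"]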
+ + def unload_vector_store(self, kb_name: str): + if cache := self.get(kb_name): + self.pop(kb_name) + logger.info(f"成功释放向量库:{kb_name}") + + +class KBFaissPool(_FaissPool): + def load_vector_store( + self, + kb_name: str, + vector_name: str = None, + create: bool = True, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + ) -> ThreadSafeFaiss: + self.atomic.acquire() + locked = True + vector_name = vector_name or embed_model + cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些 + try: + if cache is None: + item = ThreadSafeFaiss((kb_name, vector_name), pool=self) + self.set((kb_name, vector_name), item) + with item.acquire(msg="初始化"): + self.atomic.release() + locked = False + logger.info( + f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk." + ) + vs_path = get_vs_path(kb_name, vector_name) + + if os.path.isfile(os.path.join(vs_path, "index.faiss")): + embeddings = get_Embeddings(embed_model=embed_model) + vector_store = FAISS.load_local( + vs_path, + embeddings, + normalize_L2=True, + allow_dangerous_deserialization=True, + ) + elif create: + # create an empty vector store + if not os.path.exists(vs_path): + os.makedirs(vs_path) + vector_store = self.new_vector_store( + kb_name=kb_name, embed_model=embed_model + ) + vector_store.save_local(vs_path) + else: + raise RuntimeError(f"knowledge base {kb_name} not exist.") + item.obj = vector_store + item.finish_loading() + else: + self.atomic.release() + locked = False + except Exception as e: + if locked: # we don't know exception raised before or after atomic.release + self.atomic.release() + logger.error(e, exc_info=True) + raise RuntimeError(f"向量库 {kb_name} 加载失败。") + return self.get((kb_name, vector_name)) + + +class MemoFaissPool(_FaissPool): + r""" + 临时向量库的缓存池 + """ + + def load_vector_store( + self, + kb_name: str, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + ) -> ThreadSafeFaiss: + self.atomic.acquire() + cache = self.get(kb_name) + if cache is None: + item = ThreadSafeFaiss(kb_name, pool=self) + self.set(kb_name, item) + with item.acquire(msg="初始化"): + self.atomic.release() + logger.info(f"loading vector store in '{kb_name}' to memory.") + # create an empty vector store + vector_store = self.new_temp_vector_store(embed_model=embed_model) + item.obj = vector_store + item.finish_loading() + else: + self.atomic.release() + return self.get(kb_name) + + +kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM) +memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM) +# +# +# if __name__ == "__main__": +# import time, random +# from pprint import pprint +# +# kb_names = ["vs1", "vs2", "vs3"] +# # for name in kb_names: +# # memo_faiss_pool.load_vector_store(name) +# +# def worker(vs_name: str, name: str): +# vs_name = "samples" +# time.sleep(random.randint(1, 5)) +# embeddings = load_local_embeddings() +# r = random.randint(1, 3) +# +# with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs: +# if r == 1: # add docs +# ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings) +# pprint(ids) +# elif r == 2: # search docs +# docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0) +# pprint(docs) +# if r == 3: # delete docs +# logger.warning(f"清除 {vs_name} by {name}") +# kb_faiss_pool.get(vs_name).clear() +# +# threads = [] +# for n in range(1, 30): +# t = threading.Thread(target=worker, +# kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"}, +# daemon=True) +# t.start() +# threads.append(t) +# +# for t in threads: +# t.join() diff --git 
a/src/mindpilot/app/knowledge_base/kb_doc_api.py b/src/mindpilot/app/knowledge_base/kb_doc_api.py new file mode 100644 index 0000000..26338fd --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_doc_api.py @@ -0,0 +1,451 @@ +import json +import os +import urllib +from typing import Dict, List + +from fastapi import Body, File, Form, Query, UploadFile +from fastapi.responses import FileResponse +from langchain.docstore.document import Document +from sse_starlette import EventSourceResponse + +from ..configs import ( + CHUNK_SIZE, + DEFAULT_VS_TYPE, + OVERLAP_SIZE, + SCORE_THRESHOLD, + VECTOR_SEARCH_TOP_K, + ZH_TITLE_ENHANCE, +) +from .db.repository.knowledge_file_repository import get_file_detail +from .kb_service.base import ( + KBServiceFactory, + get_kb_file_details, +) +from .model.kb_document_model import DocumentWithVSId +from .utils import ( + KnowledgeFile, + files2docs_in_thread, + get_file_path, + list_files_from_folder, + validate_kb_name, +) +from ..utils.system_utils import ( + BaseResponse, + ListResponse, + run_in_thread_pool, +) + + +def search_docs( + query: str = Body("", description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body( + ..., description="知识库名称", examples=["samples"] + ), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body( + SCORE_THRESHOLD, + description="知识库匹配相关度阈值,取值范围在0-1之间," + "SCORE越小,相关度越高," + "取到1相当于不筛选,建议设置在0.5左右", + ge=0.0, + le=1.0, + ), + file_name: str = Body("", description="文件名称,支持 sql 通配符"), + metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"), +) -> List[Dict]: + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + data = [] + if kb is not None: + if query: + docs = kb.search_docs(query, top_k, score_threshold) + # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs] + elif file_name or metadata: + data = kb.list_docs(file_name=file_name, metadata=metadata) + for d in data: + if "vector" in d.metadata: + del d.metadata["vector"] + return [x.dict() for x in data] + + +def list_files(knowledge_base_name: str) -> ListResponse: + if not validate_kb_name(knowledge_base_name): + return ListResponse(code=403, msg="Don't attack me", data=[]) + + knowledge_base_name = urllib.parse.unquote(knowledge_base_name) + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return ListResponse( + code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[] + ) + else: + all_docs = get_kb_file_details(knowledge_base_name) + return ListResponse(data=all_docs) + + +def _save_files_in_thread( + files: List[UploadFile], knowledge_base_name: str, override: bool +): + """ + 通过多线程将上传的文件保存到对应知识库目录内。 + 生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}} + """ + + def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict: + """ + 保存单个文件。 + """ + try: + filename = file.filename + file_path = get_file_path( + knowledge_base_name=knowledge_base_name, doc_name=filename + ) + data = {"knowledge_base_name": knowledge_base_name, "file_name": filename} + + file_content = file.file.read() # 读取上传文件的内容 + if ( + os.path.isfile(file_path) + and not override + and os.path.getsize(file_path) == len(file_content) + ): + file_status = f"文件 {filename} 已存在。" + return dict(code=404, msg=file_status, data=data) + + if not os.path.isdir(os.path.dirname(file_path)): + 
os.makedirs(os.path.dirname(file_path)) + with open(file_path, "wb") as f: + f.write(file_content) + return dict(code=200, msg=f"成功上传文件 {filename}", data=data) + except Exception as e: + msg = f"{filename} 文件上传失败,报错信息为: {e}" + return dict(code=500, msg=msg, data=data) + + params = [ + {"file": file, "knowledge_base_name": knowledge_base_name, "override": override} + for file in files + ] + for result in run_in_thread_pool(save_file, params=params): + yield result + + +# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), +# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), +# override: bool = Form(False, description="覆盖已有文件"), +# save: bool = Form(True, description="是否将文件保存到知识库目录")): +# def save_files(files, knowledge_base_name, override): +# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): +# yield json.dumps(result, ensure_ascii=False) + +# def files_to_docs(files): +# for result in files2docs_in_thread(files): +# yield json.dumps(result, ensure_ascii=False) + + +def upload_docs( + files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + knowledge_base_name: str = Form( + ..., description="知识库名称", examples=["samples"] + ), + override: bool = Form(False, description="覆盖已有文件"), + to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), + chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + docs: str = Form("", description="自定义的docs,需要转为json字符串"), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), +) -> BaseResponse: + """ + API接口:上传文件,并/或向量化 + """ + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + + docs = json.loads(docs) if docs else {} + failed_files = {} + file_names = list(docs.keys()) + + # 先将上传的文件保存到磁盘 + for result in _save_files_in_thread( + files, knowledge_base_name=knowledge_base_name, override=override + ): + filename = result["data"]["file_name"] + if result["code"] != 200: + failed_files[filename] = result["msg"] + + if filename not in file_names: + file_names.append(filename) + + # 对保存的文件进行向量化 + if to_vector_store: + result = update_docs( + knowledge_base_name=knowledge_base_name, + file_names=file_names, + override_custom_docs=True, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + docs=docs, + not_refresh_vs_cache=True, + ) + failed_files.update(result.data["failed_files"]) + if not not_refresh_vs_cache: + kb.save_vector_store() + + return BaseResponse( + code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files} + ) + + +def delete_docs( + knowledge_base_name: str = Body(..., examples=["samples"]), + file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), + delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +) -> BaseResponse: + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + knowledge_base_name = urllib.parse.unquote(knowledge_base_name) + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 
{knowledge_base_name}") + + failed_files = {} + for file_name in file_names: + if not kb.exist_doc(file_name): + failed_files[file_name] = f"未找到文件 {file_name}" + + try: + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) + kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True) + except Exception as e: + msg = f"{file_name} 文件删除失败,错误信息:{e}" + failed_files[file_name] = msg + + if not not_refresh_vs_cache: + kb.save_vector_store() + + return BaseResponse( + code=200, msg=f"文件删除完成", data={"failed_files": failed_files} + ) + + +def update_info( + knowledge_base_name: str = Body( + ..., description="知识库名称", examples=["samples"] + ), + kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]), +): + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + kb.update_info(kb_info) + + return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info}) + + +def update_docs( + knowledge_base_name: str = Body( + ..., description="知识库名称", examples=["samples"] + ), + file_names: List[str] = Body( + ..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]] + ), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), + docs: str = Body("", description="自定义的docs,需要转为json字符串"), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +) -> BaseResponse: + """ + 更新知识库文档 + """ + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + + failed_files = {} + kb_files = [] + docs = json.loads(docs) if docs else {} + + # 生成需要加载docs的文件列表 + for file_name in file_names: + file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name) + # 如果该文件之前使用了自定义docs,则根据参数决定略过或覆盖 + if file_detail.get("custom_docs") and not override_custom_docs: + continue + if file_name not in docs: + try: + kb_files.append( + KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) + ) + except Exception as e: + msg = f"加载文档 {file_name} 时出错:{e}" + failed_files[file_name] = msg + + # 从文件生成docs,并进行向量化。 + # 这里利用了KnowledgeFile的缓存功能,在多线程中加载Document,然后传给KnowledgeFile + for status, result in files2docs_in_thread( + kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): + if status: + kb_name, file_name, new_docs = result + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) + kb_file.splited_docs = new_docs + kb.update_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + failed_files[file_name] = error + + # 将自定义的docs进行向量化 + for file_name, v in docs.items(): + try: + v = [x if isinstance(x, Document) else Document(**x) for x in v] + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) + kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True) + except Exception as e: + msg = f"为 {file_name} 添加自定义docs时出错:{e}" + failed_files[file_name] 
= msg + + if not not_refresh_vs_cache: + kb.save_vector_store() + + return BaseResponse( + code=200, msg=f"更新文档完成", data={"failed_files": failed_files} + ) + + +def download_doc( + knowledge_base_name: str = Query( + ..., description="知识库名称", examples=["samples"] + ), + file_name: str = Query(..., description="文件名称", examples=["test.txt"]), + preview: bool = Query(False, description="是:浏览器内预览;否:下载"), +): + """ + 下载知识库文档 + """ + if not validate_kb_name(knowledge_base_name): + return BaseResponse(code=403, msg="Don't attack me") + + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + + if preview: + content_disposition_type = "inline" + else: + content_disposition_type = None + + try: + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=knowledge_base_name + ) + + if os.path.exists(kb_file.filepath): + return FileResponse( + path=kb_file.filepath, + filename=kb_file.filename, + media_type="multipart/form-data", + content_disposition_type=content_disposition_type, + ) + except Exception as e: + msg = f"{kb_file.filename} 读取文件失败,错误信息是:{e}" + return BaseResponse(code=500, msg=msg) + + return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") + + +def recreate_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(...), + chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"), + chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"), + zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +): + """ + recreate vector store from the content. + this is usefull when user can copy files to content folder directly instead of upload through network. + by default, get_service_by_name only return knowledge base in the info.db and having document files in it. + set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. + """ + + def output(): + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists() and not allow_empty_kb: + yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} + else: + error_msg = ( + f"could not recreate vector store because failed to access embed model." 
+ ) + if not kb.check_embed_model(error_msg): + yield {"code": 404, "msg": error_msg} + else: + if kb.exists(): + kb.clear_vs() + kb.create_kb() + files = list_files_from_folder(knowledge_base_name) + kb_files = [(file, knowledge_base_name) for file in files] + i = 0 + for status, result in files2docs_in_thread( + kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): + if status: + kb_name, file_name, docs = result + kb_file = KnowledgeFile( + filename=file_name, knowledge_base_name=kb_name + ) + kb_file.splited_docs = docs + yield json.dumps( + { + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, + ensure_ascii=False, + ) + kb.add_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" + yield json.dumps( + { + "code": 500, + "msg": msg, + } + ) + i += 1 + if not not_refresh_vs_cache: + kb.save_vector_store() + + return EventSourceResponse(output()) \ No newline at end of file diff --git a/src/mindpilot/app/knowledge_base/kb_service/__init__.py b/src/mindpilot/app/knowledge_base/kb_service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mindpilot/app/knowledge_base/kb_service/base.py b/src/mindpilot/app/knowledge_base/kb_service/base.py new file mode 100644 index 0000000..bb98687 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_service/base.py @@ -0,0 +1,500 @@ +import operator +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +from langchain.docstore.document import Document + +from ...configs import ( + KB_INFO, + SCORE_THRESHOLD, + VECTOR_SEARCH_TOP_K, + kbs_config, + DEFAULT_EMBEDDING_MODEL +) +from ..db.models.knowledge_base_model import KnowledgeBaseSchema +from ..db.repository.knowledge_base_repository import ( + add_kb_to_db, + delete_kb_from_db, + kb_exists, + list_kbs_from_db, + load_kb_from_db, +) +from ..db.repository.knowledge_file_repository import ( + add_file_to_db, + count_files_from_db, + delete_file_from_db, + delete_files_from_db, + file_exists_in_db, + get_file_detail, + list_docs_from_db, + list_files_from_db, +) +from ..model.kb_document_model import DocumentWithVSId +from ..utils import ( + KnowledgeFile, + get_doc_path, + get_kb_path, + list_files_from_folder, + list_kbs_from_folder, +) +from ...utils.system_utils import check_embed_model as _check_embed_model + +class SupportedVSType: + FAISS = "faiss" + MILVUS = "milvus" + DEFAULT = "default" + ES = "es" + + +class KBService(ABC): + def __init__( + self, + knowledge_base_name: str, + kb_info: str = None, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + ): + self.kb_name = knowledge_base_name + self.kb_info = kb_info or KB_INFO.get( + knowledge_base_name, f"关于{knowledge_base_name}的知识库" + ) + self.embed_model = embed_model + self.kb_path = get_kb_path(self.kb_name) + self.doc_path = get_doc_path(self.kb_name) + self.do_init() + + def __repr__(self) -> str: + return f"{self.kb_name} @ {self.embed_model}" + + def save_vector_store(self): + """ + 保存向量库:FAISS保存到磁盘,milvus保存到数据库。PGVector暂未支持 + """ + pass + + def check_embed_model(self, error_msg: str) -> bool: + if not _check_embed_model(self.embed_model): + return False + else: + return True + + def create_kb(self): + """ + 创建知识库 + """ + if not os.path.exists(self.doc_path): + os.makedirs(self.doc_path) + + status = 
add_kb_to_db( + self.kb_name, self.kb_info, self.vs_type(), self.embed_model + ) + + if status: + self.do_create_kb() + return status + + def clear_vs(self): + """ + 删除向量库中所有内容 + """ + self.do_clear_vs() + status = delete_files_from_db(self.kb_name) + return status + + def drop_kb(self): + """ + 删除知识库 + """ + self.do_drop_kb() + status = delete_kb_from_db(self.kb_name) + return status + + def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): + """ + 向知识库添加文件 + 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True + """ + if not self.check_embed_model( + f"could not add docs because failed to access embed model." + ): + return False + + if docs: + custom_docs = True + else: + docs = kb_file.file2text() + custom_docs = False + + if docs: + # 将 metadata["source"] 改为相对路径 + for doc in docs: + try: + doc.metadata.setdefault("source", kb_file.filename) + source = doc.metadata.get("source", "") + if os.path.isabs(source): + rel_path = Path(source).relative_to(self.doc_path) + doc.metadata["source"] = str(rel_path.as_posix().strip("/")) + except Exception as e: + print( + f"cannot convert absolute path ({source}) to relative path. error is : {e}" + ) + self.delete_doc(kb_file) + doc_infos = self.do_add_doc(docs, **kwargs) + status = add_file_to_db( + kb_file, + custom_docs=custom_docs, + docs_count=len(docs), + doc_infos=doc_infos, + ) + else: + status = False + return status + + def delete_doc( + self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs + ): + """ + 从知识库删除文件 + """ + self.do_delete_doc(kb_file, **kwargs) + status = delete_file_from_db(kb_file) + if delete_content and os.path.exists(kb_file.filepath): + os.remove(kb_file.filepath) + return status + + def update_info(self, kb_info: str): + """ + 更新知识库介绍 + """ + self.kb_info = kb_info + status = add_kb_to_db( + self.kb_name, self.kb_info, self.vs_type(), self.embed_model + ) + return status + + def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): + """ + 使用content中的文件更新向量库 + 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True + """ + if not self.check_embed_model( + f"could not update docs because failed to access embed model." + ): + return False + + if os.path.exists(kb_file.filepath): + self.delete_doc(kb_file, **kwargs) + return self.add_doc(kb_file, docs=docs, **kwargs) + + def exist_doc(self, file_name: str): + return file_exists_in_db( + KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name) + ) + + def list_files(self): + return list_files_from_db(self.kb_name) + + def count_files(self): + return count_files_from_db(self.kb_name) + + def search_docs( + self, + query: str, + top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, + ) -> List[Document]: + if not self.check_embed_model( + f"could not search docs because failed to access embed model." + ): + return [] + docs = self.do_search(query, top_k, score_threshold) + return docs + + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + return [] + + def del_doc_by_ids(self, ids: List[str]) -> bool: + raise NotImplementedError + + def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool: + """ + 传入参数为: {doc_id: Document, ...} + 如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档 + """ + if not self.check_embed_model( + f"could not update docs because failed to access embed model." 
+ ): + return False + + self.del_doc_by_ids(list(docs.keys())) + pending_docs = [] + ids = [] + for _id, doc in docs.items(): + if not doc or not doc.page_content.strip(): + continue + ids.append(_id) + pending_docs.append(doc) + self.do_add_doc(docs=pending_docs, ids=ids) + return True + + def list_docs( + self, file_name: str = None, metadata: Dict = {} + ) -> List[DocumentWithVSId]: + """ + 通过file_name或metadata检索Document + """ + doc_infos = list_docs_from_db( + kb_name=self.kb_name, file_name=file_name, metadata=metadata + ) + docs = [] + for x in doc_infos: + doc_info = self.get_doc_by_ids([x["id"]])[0] + if doc_info is not None: + # 处理非空的情况 + doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"]) + docs.append(doc_with_id) + else: + # 处理空的情况 + # 可以选择跳过当前循环迭代或执行其他操作 + pass + return docs + + def get_relative_source_path(self, filepath: str): + """ + 将文件路径转化为相对路径,保证查询时一致 + """ + relative_path = filepath + if os.path.isabs(relative_path): + try: + relative_path = Path(filepath).relative_to(self.doc_path) + except Exception as e: + print( + f"cannot convert absolute path ({relative_path}) to relative path. error is : {e}" + ) + + relative_path = str(relative_path.as_posix().strip("/")) + return relative_path + + @abstractmethod + def do_create_kb(self): + """ + 创建知识库子类实自己逻辑 + """ + pass + + @staticmethod + def list_kbs_type(): + return list(kbs_config.keys()) + + @classmethod + def list_kbs(cls): + return list_kbs_from_db() + + def exists(self, kb_name: str = None): + kb_name = kb_name or self.kb_name + return kb_exists(kb_name) + + @abstractmethod + def vs_type(self) -> str: + pass + + @abstractmethod + def do_init(self): + pass + + @abstractmethod + def do_drop_kb(self): + """ + 删除知识库子类实自己逻辑 + """ + pass + + @abstractmethod + def do_search( + self, + query: str, + top_k: int, + score_threshold: float, + ) -> List[Tuple[Document, float]]: + """ + 搜索知识库子类实自己逻辑 + """ + pass + + @abstractmethod + def do_add_doc( + self, + docs: List[Document], + **kwargs, + ) -> List[Dict]: + """ + 向知识库添加文档子类实自己逻辑 + """ + pass + + @abstractmethod + def do_delete_doc(self, kb_file: KnowledgeFile): + """ + 从知识库删除文档子类实自己逻辑 + """ + pass + + @abstractmethod + def do_clear_vs(self): + """ + 从知识库删除全部向量子类实自己逻辑 + """ + pass + + +class KBServiceFactory: + @staticmethod + def get_service( + kb_name: str, + vector_store_type: Union[str, SupportedVSType], + embed_model: str = DEFAULT_EMBEDDING_MODEL, + kb_info: str = None, + ) -> KBService: + if isinstance(vector_store_type, str): + vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) + params = { + "knowledge_base_name": kb_name, + "embed_model": embed_model, + "kb_info": kb_info, + } + if SupportedVSType.FAISS == vector_store_type: + from ..kb_service.faiss_kb_service import ( + FaissKBService, + ) + + return FaissKBService(**params) + + elif SupportedVSType.MILVUS == vector_store_type: + from ..kb_service.milvus_kb_service import ( + MilvusKBService, + ) + + return MilvusKBService(**params) + + elif SupportedVSType.DEFAULT == vector_store_type: + from ..kb_service.milvus_kb_service import ( + MilvusKBService, + ) + + return MilvusKBService( + **params + ) # other milvus parameters are set in model_config.kbs_config + elif SupportedVSType.ES == vector_store_type: + from ..kb_service.es_kb_service import ( + ESKBService, + ) + + return ESKBService(**params) + + elif ( + SupportedVSType.DEFAULT == vector_store_type + ): + from ..kb_service.default_kb_service import ( + DefaultKBService, + ) + + return DefaultKBService(kb_name) + + @staticmethod 
+ def get_service_by_name(kb_name: str) -> KBService: + _, vs_type, embed_model = load_kb_from_db(kb_name) + if _ is None: # kb not in db, just return None + return None + return KBServiceFactory.get_service(kb_name, vs_type, embed_model) + + @staticmethod + def get_default(): + return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) + + +def get_kb_details() -> List[Dict]: + kbs_in_folder = list_kbs_from_folder() + kbs_in_db: List[KnowledgeBaseSchema] = KBService.list_kbs() + result = {} + + for kb in kbs_in_folder: + result[kb] = { + "kb_name": kb, + "vs_type": "", + "kb_info": "", + "embed_model": "", + "file_count": 0, + "create_time": None, + "in_folder": True, + "in_db": False, + } + + for kb_detail in kbs_in_db: + kb_detail = kb_detail.model_dump() + kb_name = kb_detail["kb_name"] + kb_detail["in_db"] = True + if kb_name in result: + result[kb_name].update(kb_detail) + else: + kb_detail["in_folder"] = False + result[kb_name] = kb_detail + + data = [] + for i, v in enumerate(result.values()): + v["No"] = i + 1 + data.append(v) + + return data + + +def get_kb_file_details(kb_name: str) -> List[Dict]: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb is None: + return [] + + files_in_folder = list_files_from_folder(kb_name) + files_in_db = kb.list_files() + result = {} + + for doc in files_in_folder: + result[doc] = { + "kb_name": kb_name, + "file_name": doc, + "file_ext": os.path.splitext(doc)[-1], + "file_version": 0, + "document_loader": "", + "docs_count": 0, + "text_splitter": "", + "create_time": None, + "in_folder": True, + "in_db": False, + } + lower_names = {x.lower(): x for x in result} + for doc in files_in_db: + doc_detail = get_file_detail(kb_name, doc) + if doc_detail: + doc_detail["in_db"] = True + if doc.lower() in lower_names: + result[lower_names[doc.lower()]].update(doc_detail) + else: + doc_detail["in_folder"] = False + result[doc] = doc_detail + + data = [] + for i, v in enumerate(result.values()): + v["No"] = i + 1 + data.append(v) + + return data + + +def score_threshold_process(score_threshold, k, docs): + if score_threshold is not None: + cmp = operator.le + docs = [ + (doc, similarity) + for doc, similarity in docs + if cmp(similarity, score_threshold) + ] + return docs[:k] diff --git a/src/mindpilot/app/knowledge_base/kb_service/default_kb_service.py b/src/mindpilot/app/knowledge_base/kb_service/default_kb_service.py new file mode 100644 index 0000000..a060a42 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_service/default_kb_service.py @@ -0,0 +1,38 @@ +from typing import List + +from langchain.embeddings.base import Embeddings +from langchain.schema import Document + +from ..kb_service.base import KBService + + +class DefaultKBService(KBService): + def do_create_kb(self): + pass + + def do_drop_kb(self): + pass + + def do_add_doc(self, docs: List[Document]): + pass + + def do_clear_vs(self): + pass + + def vs_type(self) -> str: + return "default" + + def do_init(self): + pass + + def do_search(self): + pass + + def do_insert_multi_knowledge(self): + pass + + def do_insert_one_knowledge(self): + pass + + def do_delete_doc(self): + pass diff --git a/src/mindpilot/app/knowledge_base/kb_service/es_kb_service.py b/src/mindpilot/app/knowledge_base/kb_service/es_kb_service.py new file mode 100644 index 0000000..7e7e301 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_service/es_kb_service.py @@ -0,0 +1,226 @@ +import logging +import os +import shutil +from typing import List + +from elasticsearch import BadRequestError, 
Elasticsearch +from langchain.schema import Document +from langchain_community.vectorstores.elasticsearch import ( + ApproxRetrievalStrategy, + ElasticsearchStore, +) + +from src.mindpilot.app.utils.system_utils import get_Embeddings +from ...configs import KB_ROOT_PATH, kbs_config +from ..file_rag.utils import get_Retriever +from ..kb_service.base import KBService, SupportedVSType +from ..utils import KnowledgeFile + + +logger = logging.getLogger() + + +class ESKBService(KBService): + def do_init(self): + self.kb_path = self.get_kb_path(self.kb_name) + self.index_name = os.path.split(self.kb_path)[-1] + self.IP = kbs_config[self.vs_type()]["host"] + self.PORT = kbs_config[self.vs_type()]["port"] + self.user = kbs_config[self.vs_type()].get("user", "") + self.password = kbs_config[self.vs_type()].get("password", "") + self.dims_length = kbs_config[self.vs_type()].get("dims_length", None) + self.embeddings_model = get_Embeddings(self.embed_model) + try: + # ES python客户端连接(仅连接) + if self.user != "" and self.password != "": + self.es_client_python = Elasticsearch( + f"http://{self.IP}:{self.PORT}", + basic_auth=(self.user, self.password), + ) + else: + logger.warning("ES未配置用户名和密码") + self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}") + except ConnectionError: + logger.error("连接到 Elasticsearch 失败!") + raise ConnectionError + except Exception as e: + logger.error(f"Error 发生 : {e}") + raise e + try: + # 首先尝试通过es_client_python创建 + mappings = { + "properties": { + "dense_vector": { + "type": "dense_vector", + "dims": self.dims_length, + "index": True, + } + } + } + self.es_client_python.indices.create( + index=self.index_name, mappings=mappings + ) + except BadRequestError as e: + logger.error("创建索引失败,重新") + logger.error(e) + + try: + # langchain ES 连接、创建索引 + params = dict( + es_url=f"http://{self.IP}:{self.PORT}", + index_name=self.index_name, + query_field="context", + vector_query_field="dense_vector", + embedding=self.embeddings_model, + strategy=ApproxRetrievalStrategy(), + es_params={ + "timeout": 60, + }, + ) + if self.user != "" and self.password != "": + params.update(es_user=self.user, es_password=self.password) + self.db = ElasticsearchStore(**params) + except ConnectionError: + logger.error("### 初始化 Elasticsearch 失败!") + raise ConnectionError + except Exception as e: + logger.error(f"Error 发生 : {e}") + raise e + try: + # 尝试通过db_init创建索引 + self.db._create_index_if_not_exists( + index_name=self.index_name, dims_length=self.dims_length + ) + except Exception as e: + logger.error("创建索引失败...") + logger.error(e) + # raise e + + @staticmethod + def get_kb_path(knowledge_base_name: str): + return os.path.join(KB_ROOT_PATH, knowledge_base_name) + + @staticmethod + def get_vs_path(knowledge_base_name: str): + return os.path.join( + ESKBService.get_kb_path(knowledge_base_name), "vector_store" + ) + + def do_create_kb(self): ... 
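# Editor's note (illustrative, not part of this diff): do_init() above only reads host, port,
# user, password and dims_length from the "es" entry of kbs_config; a minimal configuration
# could look like the sketch below -- the concrete values are assumptions, not taken from
# this repository.
#
#     kbs_config["es"] = {
#         "host": "127.0.0.1",
#         "port": "9200",
#         "user": "",
#         "password": "",
#         "dims_length": 768,  # should match the embedding model's output dimension
#     }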
+ + def vs_type(self) -> str: + return SupportedVSType.ES + + def do_search(self, query: str, top_k: int, score_threshold: float): + # 文本相似性检索 + retriever = get_Retriever("vectorstore").from_vectorstore( + self.db, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs + + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + results = [] + for doc_id in ids: + try: + response = self.es_client_python.get(index=self.index_name, id=doc_id) + source = response["_source"] + # Assuming your document has "text" and "metadata" fields + text = source.get("context", "") + metadata = source.get("metadata", {}) + results.append(Document(page_content=text, metadata=metadata)) + except Exception as e: + logger.error(f"Error retrieving document from Elasticsearch! {e}") + return results + + def del_doc_by_ids(self, ids: List[str]) -> bool: + for doc_id in ids: + try: + self.es_client_python.delete( + index=self.index_name, id=doc_id, refresh=True + ) + except Exception as e: + logger.error(f"ES Docs Delete Error! {e}") + + def do_delete_doc(self, kb_file, **kwargs): + if self.es_client_python.indices.exists(index=self.index_name): + # 从向量数据库中删除索引(文档名称是Keyword) + query = { + "query": { + "term": { + "metadata.source.keyword": self.get_relative_source_path( + kb_file.filepath + ) + } + }, + "track_total_hits": True, + } + # 注意设置size,默认返回10个,es检索设置track_total_hits为True返回数据库中真实的size。 + size = self.es_client_python.search(body=query)["hits"]["total"]["value"] + search_results = self.es_client_python.search(body=query, size=size) + delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]] + if len(delete_list) == 0: + return None + else: + for doc_id in delete_list: + try: + self.es_client_python.delete( + index=self.index_name, id=doc_id, refresh=True + ) + except Exception as e: + logger.error(f"ES Docs Delete Error! {e}") + + # self.db.delete(ids=delete_list) + # self.es_client_python.indices.refresh(index=self.index_name) + + def do_add_doc(self, docs: List[Document], **kwargs): + """向知识库添加文件""" + + print( + f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}" + ) + print("*" * 100) + + self.db.add_documents(documents=docs) + # 获取 id 和 source , 格式:[{"id": str, "metadata": dict}, ...] 
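# Editor's note (not part of this diff): rather than using the ids returned by
# self.db.add_documents(), the code below re-queries the index and builds
# [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]}, ...] from the hits, raising
# ValueError when nothing is recalled. Also note that the query dict literal repeats the
# "term" key, so in Python only the last one ({"_index": ...}) survives and the filter on
# metadata.source.keyword is effectively dropped.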
+ print("写入数据成功.") + print("*" * 100) + + if self.es_client_python.indices.exists(index=self.index_name): + file_path = docs[0].metadata.get("source") + query = { + "query": { + "term": {"metadata.source.keyword": file_path}, + "term": {"_index": self.index_name}, + } + } + # 注意设置size,默认返回10个。 + search_results = self.es_client_python.search(body=query, size=50) + if len(search_results["hits"]["hits"]) == 0: + raise ValueError("召回元素个数为0") + info_docs = [ + {"id": hit["_id"], "metadata": hit["_source"]["metadata"]} + for hit in search_results["hits"]["hits"] + ] + return info_docs + + def do_clear_vs(self): + """从知识库删除全部向量""" + if self.es_client_python.indices.exists(index=self.kb_name): + self.es_client_python.indices.delete(index=self.kb_name) + + def do_drop_kb(self): + """删除知识库""" + # self.kb_file: 知识库路径 + if os.path.exists(self.kb_path): + shutil.rmtree(self.kb_path) + + +if __name__ == "__main__": + esKBService = ESKBService("test") + # esKBService.clear_vs() + # esKBService.create_kb() + esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test")) + print(esKBService.search_docs("如何启动api服务")) diff --git a/src/mindpilot/app/knowledge_base/kb_service/faiss_kb_service.py b/src/mindpilot/app/knowledge_base/kb_service/faiss_kb_service.py new file mode 100644 index 0000000..1ec8e63 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_service/faiss_kb_service.py @@ -0,0 +1,136 @@ +import os +import shutil +from typing import Dict, List, Tuple + +from langchain.docstore.document import Document + +from ...configs import SCORE_THRESHOLD +from ..file_rag.utils import get_Retriever +from ..kb_cache.faiss_cache import ( + ThreadSafeFaiss, + kb_faiss_pool, +) +from ..kb_service.base import KBService, SupportedVSType +from ..utils import KnowledgeFile, get_kb_path, get_vs_path + + +class FaissKBService(KBService): + vs_path: str + kb_path: str + vector_name: str = None + + def vs_type(self) -> str: + return SupportedVSType.FAISS + + def get_vs_path(self): + return get_vs_path(self.kb_name, self.vector_name) + + def get_kb_path(self): + return get_kb_path(self.kb_name) + + def load_vector_store(self) -> ThreadSafeFaiss: + return kb_faiss_pool.load_vector_store( + kb_name=self.kb_name, + vector_name=self.vector_name, + embed_model=self.embed_model, + ) + + def save_vector_store(self): + self.load_vector_store().save(self.vs_path) + + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + with self.load_vector_store().acquire() as vs: + return [vs.docstore._dict.get(id) for id in ids] + + def del_doc_by_ids(self, ids: List[str]) -> bool: + with self.load_vector_store().acquire() as vs: + vs.delete(ids) + + def do_init(self): + self.vector_name = self.vector_name or self.embed_model + self.kb_path = self.get_kb_path() + self.vs_path = self.get_vs_path() + + def do_create_kb(self): + if not os.path.exists(self.vs_path): + os.makedirs(self.vs_path) + self.load_vector_store() + + def do_drop_kb(self): + self.clear_vs() + try: + shutil.rmtree(self.kb_path) + except Exception: + pass + + def do_search( + self, + query: str, + top_k: int, + score_threshold: float = SCORE_THRESHOLD, + ) -> List[Tuple[Document, float]]: + with self.load_vector_store().acquire() as vs: + retriever = get_Retriever("ensemble").from_vectorstore( + vs, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs + + def do_add_doc( + self, + docs: List[Document], + **kwargs, + ) -> List[Dict]: + texts = [x.page_content for x in docs] + 
metadatas = [x.metadata for x in docs] + with self.load_vector_store().acquire() as vs: + embeddings = vs.embeddings.embed_documents(texts) + ids = vs.add_embeddings( + text_embeddings=zip(texts, embeddings), metadatas=metadatas + ) + if not kwargs.get("not_refresh_vs_cache"): + vs.save_local(self.vs_path) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos + + def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): + with self.load_vector_store().acquire() as vs: + ids = [ + k + for k, v in vs.docstore._dict.items() + if v.metadata.get("source").lower() == kb_file.filename.lower() + ] + if len(ids) > 0: + vs.delete(ids) + if not kwargs.get("not_refresh_vs_cache"): + vs.save_local(self.vs_path) + return ids + + def do_clear_vs(self): + with kb_faiss_pool.atomic: + kb_faiss_pool.pop((self.kb_name, self.vector_name)) + try: + shutil.rmtree(self.vs_path) + except Exception: + ... + os.makedirs(self.vs_path, exist_ok=True) + + def exist_doc(self, file_name: str): + if super().exist_doc(file_name): + return "in_db" + + content_path = os.path.join(self.kb_path, "content") + if os.path.isfile(os.path.join(content_path, file_name)): + return "in_folder" + else: + return False + + +if __name__ == "__main__": + faissService = FaissKBService("test") + faissService.add_doc(KnowledgeFile("README.md", "test")) + faissService.delete_doc(KnowledgeFile("README.md", "test")) + faissService.do_drop_kb() + print(faissService.search_docs("如何启动api服务")) diff --git a/src/mindpilot/app/knowledge_base/kb_service/milvus_kb_service.py b/src/mindpilot/app/knowledge_base/kb_service/milvus_kb_service.py new file mode 100644 index 0000000..6fcf59a --- /dev/null +++ b/src/mindpilot/app/knowledge_base/kb_service/milvus_kb_service.py @@ -0,0 +1,125 @@ +import os +from typing import Dict, List, Optional + +from langchain.schema import Document +from langchain.vectorstores.milvus import Milvus + +from ...configs import kbs_config +from ..db.repository import list_file_num_docs_id_by_kb_name_and_file_name +from ...utils.system_utils import get_Embeddings +from ..file_rag.utils import get_Retriever +from ..kb_service.base import ( + KBService, + SupportedVSType, + score_threshold_process, +) +from ..utils import KnowledgeFile + + +class MilvusKBService(KBService): + milvus: Milvus + + @staticmethod + def get_collection(milvus_name): + from pymilvus import Collection + + return Collection(milvus_name) + + def get_doc_by_ids(self, ids: List[str]) -> List[Document]: + result = [] + if self.milvus.col: + # ids = [int(id) for id in ids] # for milvus if needed #pr 2725 + data_list = self.milvus.col.query( + expr=f"pk in {[int(_id) for _id in ids]}", output_fields=["*"] + ) + for data in data_list: + text = data.pop("text") + result.append(Document(page_content=text, metadata=data)) + return result + + def del_doc_by_ids(self, ids: List[str]) -> bool: + self.milvus.col.delete(expr=f"pk in {ids}") + + @staticmethod + def search(milvus_name, content, limit=3): + search_params = { + "metric_type": "L2", + "params": {"nprobe": 10}, + } + c = MilvusKBService.get_collection(milvus_name) + return c.search( + content, "embeddings", search_params, limit=limit, output_fields=["content"] + ) + + def do_create_kb(self): + pass + + def vs_type(self) -> str: + return SupportedVSType.MILVUS + + def _load_milvus(self): + self.milvus = Milvus( + embedding_function=get_Embeddings(self.embed_model), + collection_name=self.kb_name, + connection_args=kbs_config.get("milvus"), + 
index_params=kbs_config.get("milvus_kwargs")["index_params"], + search_params=kbs_config.get("milvus_kwargs")["search_params"], + auto_id=True, + ) + + def do_init(self): + self._load_milvus() + + def do_drop_kb(self): + if self.milvus.col: + self.milvus.col.release() + self.milvus.col.drop() + + def do_search(self, query: str, top_k: int, score_threshold: float): + self._load_milvus() + # embed_func = get_Embeddings(self.embed_model) + # embeddings = embed_func.embed_query(query) + # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) + retriever = get_Retriever("vectorstore").from_vectorstore( + self.milvus, + top_k=top_k, + score_threshold=score_threshold, + ) + docs = retriever.get_relevant_documents(query) + return docs + + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + for doc in docs: + for k, v in doc.metadata.items(): + doc.metadata[k] = str(v) + for field in self.milvus.fields: + doc.metadata.setdefault(field, "") + doc.metadata.pop(self.milvus._text_field, None) + doc.metadata.pop(self.milvus._vector_field, None) + + ids = self.milvus.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos + + def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): + id_list = list_file_num_docs_id_by_kb_name_and_file_name( + kb_file.kb_name, kb_file.filename + ) + if self.milvus.col: + self.milvus.col.delete(expr=f"pk in {id_list}") + + # Issue 2846, for windows + # if self.milvus.col: + # file_path = kb_file.filepath.replace("\\", "\\\\") + # file_name = os.path.basename(file_path) + # id_list = [item.get("pk") for item in + # self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])] + # self.milvus.col.delete(expr=f'pk in {id_list}') + + def do_clear_vs(self): + if self.milvus.col: + self.do_drop_kb() + self.do_init() + + + diff --git a/src/mindpilot/app/knowledge_base/migrate.py b/src/mindpilot/app/knowledge_base/migrate.py new file mode 100644 index 0000000..460b50c --- /dev/null +++ b/src/mindpilot/app/knowledge_base/migrate.py @@ -0,0 +1,234 @@ +import os +from datetime import datetime +from typing import List, Literal + +from dateutil.parser import parse + +from ..configs import ( + CHUNK_SIZE, + DEFAULT_EMBEDDING_MODEL, + DEFAULT_VS_TYPE, + OVERLAP_SIZE, + ZH_TITLE_ENHANCE, +) + +from .db.base import Base, engine +from .db.repository.knowledge_file_repository import ( + add_file_to_db, +) + +from .db.session import session_scope +from .kb_service.base import ( + KBServiceFactory, + SupportedVSType, +) +from .utils import ( + KnowledgeFile, + files2docs_in_thread, + get_file_path, + list_files_from_folder, + list_kbs_from_folder, +) + + +def create_tables(): + Base.metadata.create_all(bind=engine) + + +def reset_tables(): + Base.metadata.drop_all(bind=engine) + create_tables() + + +def import_from_db( + sqlite_path: str = None, + # csv_path: str = None, +) -> bool: + """ + 在知识库与向量库无变化的情况下,从备份数据库中导入数据到 info.db。 + 适用于版本升级时,info.db 结构变化,但无需重新向量化的情况。 + 请确保两边数据库表名一致,需要导入的字段名一致 + 当前仅支持 sqlite + """ + import sqlite3 as sql + from pprint import pprint + + models = list(Base.registry.mappers) + + try: + con = sql.connect(sqlite_path) + con.row_factory = sql.Row + cur = con.cursor() + tables = [ + x["name"] + for x in cur.execute( + "select name from sqlite_master where type='table'" + ).fetchall() + ] + for model in models: + table = model.local_table.fullname + if table not in tables: + continue + print(f"processing table: {table}") + with session_scope() as 
session: + for row in cur.execute(f"select * from {table}").fetchall(): + data = {k: row[k] for k in row.keys() if k in model.columns} + if "create_time" in data: + data["create_time"] = parse(data["create_time"]) + pprint(data) + session.add(model.class_(**data)) + con.close() + return True + except Exception as e: + print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}") + return False + + +def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]: + kb_files = [] + for file in files: + try: + kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name) + kb_files.append(kb_file) + except Exception as e: + msg = f"{e},已跳过" + return kb_files + + +def folder2db( + kb_names: List[str], + mode: Literal["recreate_vs", "update_in_db", "increment"], + vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, + embed_model: str = DEFAULT_EMBEDDING_MODEL, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, +): + """ + use existed files in local folder to populate database and/or vector store. + set parameter `mode` to: + recreate_vs: recreate all vector store and fill info to database using existed files in local folder + fill_info_only(disabled): do not create vector store, fill info to db using existed files only + update_in_db: update vector store and database info using local files that existed in database only + increment: create vector store and database info for local files that not existed in database only + """ + + def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List: + result = [] + for success, res in files2docs_in_thread( + kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance, + ): + if success: + _, filename, docs = res + print( + f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档" + ) + kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + kb_file.splited_docs = docs + kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True) + result.append({"kb_name": kb_name, "file": filename, "docs": docs}) + else: + print(res) + return result + + kb_names = kb_names or list_kbs_from_folder() + for kb_name in kb_names: + start = datetime.now() + kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model) + if not kb.exists(): + kb.create_kb() + + # 清除向量库,从本地文件重建 + if mode == "recreate_vs": + kb.clear_vs() + kb.create_kb() + kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name)) + result = files2vs(kb_name, kb_files) + kb.save_vector_store() + # # 不做文件内容的向量化,仅将文件元信息存到数据库 + # # 由于现在数据库存了很多与文本切分相关的信息,单纯存储文件信息意义不大,该功能取消。 + # elif mode == "fill_info_only": + # files = list_files_from_folder(kb_name) + # kb_files = file_to_kbfile(kb_name, files) + # for kb_file in kb_files: + # add_file_to_db(kb_file) + # print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库") + # 以数据库中文件列表为基准,利用本地文件更新向量库 + elif mode == "update_in_db": + files = kb.list_files() + kb_files = file_to_kbfile(kb_name, files) + result = files2vs(kb_name, kb_files) + kb.save_vector_store() + # 对比本地目录与数据库中的文件列表,进行增量向量化 + elif mode == "increment": + db_files = kb.list_files() + folder_files = list_files_from_folder(kb_name) + files = list(set(folder_files) - set(db_files)) + kb_files = file_to_kbfile(kb_name, files) + result = files2vs(kb_name, kb_files) + kb.save_vector_store() + else: + print(f"unsupported migrate mode: {mode}") + end = datetime.now() + kb_path = ( + f"知识库路径\t:{kb.kb_path}\n" + if kb.vs_type() == SupportedVSType.FAISS + else "" + ) + file_count = 
len(kb_files) + success_count = len(result) + docs_count = sum([len(x["docs"]) for x in result]) + print("\n" + "-" * 100) + print( + ( + f"知识库名称\t:{kb_name}\n" + f"知识库类型\t:{kb.vs_type()}\n" + f"向量模型:\t:{kb.embed_model}\n" + ) + + kb_path + + ( + f"文件总数量\t:{file_count}\n" + f"入库文件数\t:{success_count}\n" + f"知识条目数\t:{docs_count}\n" + f"用时\t\t:{end-start}" + ) + ) + print("-" * 100 + "\n") + return result + + +def prune_db_docs(kb_names: List[str]): + """ + delete docs in database that not existed in local folder. + it is used to delete database docs after user deleted some doc files in file browser + """ + for kb_name in kb_names: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb is not None: + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_db) - set(files_in_folder)) + kb_files = file_to_kbfile(kb_name, files) + for kb_file in kb_files: + kb.delete_doc(kb_file, not_refresh_vs_cache=True) + print(f"success to delete docs for file: {kb_name}/{kb_file.filename}") + kb.save_vector_store() + + +def prune_folder_files(kb_names: List[str]): + """ + delete doc files in local folder that not existed in database. + it is used to free local disk space by delete unused doc files. + """ + for kb_name in kb_names: + kb = KBServiceFactory.get_service_by_name(kb_name) + if kb is not None: + files_in_db = kb.list_files() + files_in_folder = list_files_from_folder(kb_name) + files = list(set(files_in_folder) - set(files_in_db)) + for file in files: + os.remove(get_file_path(kb_name, file)) + print(f"success to delete file: {kb_name}/{file}") diff --git a/src/mindpilot/app/knowledge_base/model/kb_document_model.py b/src/mindpilot/app/knowledge_base/model/kb_document_model.py new file mode 100644 index 0000000..78c9567 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/model/kb_document_model.py @@ -0,0 +1,10 @@ +from langchain.docstore.document import Document + + +class DocumentWithVSId(Document): + """ + 矢量化后的文档 + """ + + id: str = None + score: float = 3.0 diff --git a/src/mindpilot/app/knowledge_base/utils.py b/src/mindpilot/app/knowledge_base/utils.py new file mode 100644 index 0000000..2f40240 --- /dev/null +++ b/src/mindpilot/app/knowledge_base/utils.py @@ -0,0 +1,483 @@ +import importlib +import json +import logging +import os +from functools import lru_cache +from pathlib import Path +from typing import Dict, Generator, List, Tuple, Union + +import chardet +import langchain_community.document_loaders +from langchain.docstore.document import Document +from langchain.text_splitter import MarkdownHeaderTextSplitter, TextSplitter +from langchain_community.document_loaders import JSONLoader, TextLoader + +from ..configs import ( + CHUNK_SIZE, + KB_ROOT_PATH, + OVERLAP_SIZE, + TEXT_SPLITTER_NAME, + ZH_TITLE_ENHANCE, + text_splitter_dict, +) +from .file_rag.text_splitter import ( + zh_title_enhance as func_zh_title_enhance, +) +from ..utils.system_utils import run_in_thread_pool + +logger = logging.getLogger() + + +def validate_kb_name(knowledge_base_id: str) -> bool: + # 检查是否包含预期外的字符或路径攻击关键字 + if "../" in knowledge_base_id: + return False + return True + + +def get_kb_path(knowledge_base_name: str): + return os.path.join(KB_ROOT_PATH, knowledge_base_name) + + +def get_doc_path(knowledge_base_name: str): + return os.path.join(get_kb_path(knowledge_base_name), "content") + + +def get_vs_path(knowledge_base_name: str, vector_name: str): + return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name) + + +def 
get_file_path(knowledge_base_name: str, doc_name: str): + doc_path = Path(get_doc_path(knowledge_base_name)).resolve() + file_path = (doc_path / doc_name).resolve() + if str(file_path).startswith(str(doc_path)): + return str(file_path) + + +def list_kbs_from_folder(): + return [ + f + for f in os.listdir(KB_ROOT_PATH) + if os.path.isdir(os.path.join(KB_ROOT_PATH, f)) + ] + + +def list_files_from_folder(kb_name: str): + doc_path = get_doc_path(kb_name) + result = [] + + def is_skiped_path(path: str): + tail = os.path.basename(path).lower() + for x in ["temp", "tmp", ".", "~$"]: + if tail.startswith(x): + return True + return False + + def process_entry(entry): + if is_skiped_path(entry.path): + return + + if entry.is_symlink(): + target_path = os.path.realpath(entry.path) + with os.scandir(target_path) as target_it: + for target_entry in target_it: + process_entry(target_entry) + elif entry.is_file(): + file_path = Path( + os.path.relpath(entry.path, doc_path) + ).as_posix() # 路径统一为 posix 格式 + result.append(file_path) + elif entry.is_dir(): + with os.scandir(entry.path) as it: + for sub_entry in it: + process_entry(sub_entry) + + with os.scandir(doc_path) as it: + for entry in it: + process_entry(entry) + + return result + + +LOADER_DICT = { + "UnstructuredHTMLLoader": [".html", ".htm"], + "MHTMLLoader": [".mhtml"], + "TextLoader": [".md"], + "UnstructuredMarkdownLoader": [".md"], + "JSONLoader": [".json"], + "JSONLinesLoader": [".jsonl"], + "CSVLoader": [".csv"], + # "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv + "RapidOCRPDFLoader": [".pdf"], + "RapidOCRDocLoader": [".docx", ".doc"], + "RapidOCRPPTLoader": [ + ".ppt", + ".pptx", + ], + "RapidOCRLoader": [".png", ".jpg", ".jpeg", ".bmp"], + "UnstructuredFileLoader": [ + ".eml", + ".msg", + ".rst", + ".rtf", + ".txt", + ".xml", + ".epub", + ".odt", + ".tsv", + ], + "UnstructuredEmailLoader": [".eml", ".msg"], + "UnstructuredEPubLoader": [".epub"], + "UnstructuredExcelLoader": [".xlsx", ".xls", ".xlsd"], + "NotebookLoader": [".ipynb"], + "UnstructuredODTLoader": [".odt"], + "PythonLoader": [".py"], + "UnstructuredRSTLoader": [".rst"], + "UnstructuredRTFLoader": [".rtf"], + "SRTLoader": [".srt"], + "TomlLoader": [".toml"], + "UnstructuredTSVLoader": [".tsv"], + "UnstructuredWordDocumentLoader": [".docx", ".doc"], + "UnstructuredXMLLoader": [".xml"], + "UnstructuredPowerPointLoader": [".ppt", ".pptx"], + "EverNoteLoader": [".enex"], +} +SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist] + + +# patch json.dumps to disable ensure_ascii +def _new_json_dumps(obj, **kwargs): + kwargs["ensure_ascii"] = False + return _origin_json_dumps(obj, **kwargs) + + +if json.dumps is not _new_json_dumps: + _origin_json_dumps = json.dumps + json.dumps = _new_json_dumps + + +class JSONLinesLoader(JSONLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._json_lines = True + + +langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader + + +def get_LoaderClass(file_extension): + for LoaderClass, extensions in LOADER_DICT.items(): + if file_extension in extensions: + return LoaderClass + + +def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None): + """ + 根据loader_name和文件路径或内容返回文档加载器。 + """ + loader_kwargs = loader_kwargs or {} + try: + if loader_name in [ + "RapidOCRPDFLoader", + "RapidOCRLoader", + "FilteredCSVLoader", + "RapidOCRDocLoader", + "RapidOCRPPTLoader", + ]: + document_loaders_module = importlib.import_module( + "chatchat.server.file_rag.document_loaders" + ) 
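+            # The loader names listed above are this project's OCR-based custom
+            # loaders; every other loader name is resolved from
+            # langchain_community.document_loaders, and any lookup failure falls
+            # back to UnstructuredFileLoader in the except branch below.
+            # NOTE: "chatchat.server.file_rag.document_loaders" is the module path
+            # of the upstream Langchain-Chatchat project; it may need to point at
+            # this package's own file_rag.document_loaders module instead.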
+ else: + document_loaders_module = importlib.import_module( + "langchain_community.document_loaders" + ) + DocumentLoader = getattr(document_loaders_module, loader_name) + except Exception as e: + msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}" + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e + ) + document_loaders_module = importlib.import_module( + "langchain_community.document_loaders" + ) + DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader") + + if loader_name == "UnstructuredFileLoader": + loader_kwargs.setdefault("autodetect_encoding", True) + elif loader_name == "CSVLoader": + if not loader_kwargs.get("encoding"): + # 如果未指定 encoding,自动识别文件编码类型,避免langchain loader 加载文件报编码错误 + with open(file_path, "rb") as struct_file: + encode_detect = chardet.detect(struct_file.read()) + if encode_detect is None: + encode_detect = {"encoding": "utf-8"} + loader_kwargs["encoding"] = encode_detect["encoding"] + + elif loader_name == "JSONLoader": + loader_kwargs.setdefault("jq_schema", ".") + loader_kwargs.setdefault("text_content", False) + elif loader_name == "JSONLinesLoader": + loader_kwargs.setdefault("jq_schema", ".") + loader_kwargs.setdefault("text_content", False) + + loader = DocumentLoader(file_path, **loader_kwargs) + return loader + + +@lru_cache() +def make_text_splitter(splitter_name, chunk_size, chunk_overlap): + """ + 根据参数获取特定的分词器 + """ + splitter_name = splitter_name or "SpacyTextSplitter" + try: + if ( + splitter_name == "MarkdownHeaderTextSplitter" + ): # MarkdownHeaderTextSplitter特殊判定 + headers_to_split_on = text_splitter_dict[splitter_name][ + "headers_to_split_on" + ] + text_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on, strip_headers=False + ) + else: + try: # 优先使用用户自定义的text_splitter + text_splitter_module = importlib.import_module("chatchat.server.file_rag.text_splitter") + TextSplitter = getattr(text_splitter_module, splitter_name) + except: # 否则使用langchain的text_splitter + text_splitter_module = importlib.import_module( + "langchain.text_splitter" + ) + TextSplitter = getattr(text_splitter_module, splitter_name) + + if ( + text_splitter_dict[splitter_name]["source"] == "tiktoken" + ): # 从tiktoken加载 + try: + text_splitter = TextSplitter.from_tiktoken_encoder( + encoding_name=text_splitter_dict[splitter_name][ + "tokenizer_name_or_path" + ], + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + except: + text_splitter = TextSplitter.from_tiktoken_encoder( + encoding_name=text_splitter_dict[splitter_name][ + "tokenizer_name_or_path" + ], + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + elif ( + text_splitter_dict[splitter_name]["source"] == "huggingface" + ): # 从huggingface加载 + if ( + text_splitter_dict[splitter_name]["tokenizer_name_or_path"] + == "gpt2" + ): + from langchain.text_splitter import CharacterTextSplitter + from mindnlp.transformers import GPT2TokenizerFast + + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + else: # 字符长度加载 + from mindnlp.transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + trust_remote_code=True, + ) + text_splitter = TextSplitter.from_huggingface_tokenizer( + tokenizer=tokenizer, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + else: + try: + text_splitter = TextSplitter( + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + except: + text_splitter = TextSplitter( + 
chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) + except Exception as e: + print(e) + text_splitter_module = importlib.import_module("langchain.text_splitter") + TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") + text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + + # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287 + # text_splitter._tokenizer.max_length = 37016792 + # text_splitter._tokenizer.prefer_gpu() + return text_splitter + + +class KnowledgeFile: + def __init__( + self, + filename: str, + knowledge_base_name: str, + loader_kwargs: Dict = {}, + ): + """ + 对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。 + """ + self.kb_name = knowledge_base_name + self.filename = str(Path(filename).as_posix()) + self.ext = os.path.splitext(filename)[-1].lower() + if self.ext not in SUPPORTED_EXTS: + raise ValueError(f"暂未支持的文件格式 {self.filename}") + self.loader_kwargs = loader_kwargs + self.filepath = get_file_path(knowledge_base_name, filename) + self.docs = None + self.splited_docs = None + self.document_loader_name = get_LoaderClass(self.ext) + self.text_splitter_name = TEXT_SPLITTER_NAME + + def file2docs(self, refresh: bool = False): + if self.docs is None or refresh: + logger.info(f"{self.document_loader_name} used for {self.filepath}") + loader = get_loader( + loader_name=self.document_loader_name, + file_path=self.filepath, + loader_kwargs=self.loader_kwargs, + ) + if isinstance(loader, TextLoader): + loader.encoding = "utf8" + self.docs = loader.load() + else: + self.docs = loader.load() + return self.docs + + def docs2texts( + self, + docs: List[Document] = None, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + docs = docs or self.file2docs(refresh=refresh) + if not docs: + return [] + if self.ext not in [".csv"]: + if text_splitter is None: + text_splitter = make_text_splitter( + splitter_name=self.text_splitter_name, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + if self.text_splitter_name == "MarkdownHeaderTextSplitter": + docs = text_splitter.split_text(docs[0].page_content) + else: + docs = text_splitter.split_documents(docs) + + if not docs: + return [] + + print(f"文档切分示例:{docs[0]}") + if zh_title_enhance: + docs = func_zh_title_enhance(docs) + self.splited_docs = docs + return self.splited_docs + + def file2text( + self, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, + refresh: bool = False, + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + text_splitter: TextSplitter = None, + ): + if self.splited_docs is None or refresh: + docs = self.file2docs() + self.splited_docs = self.docs2texts( + docs=docs, + zh_title_enhance=zh_title_enhance, + refresh=refresh, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + text_splitter=text_splitter, + ) + return self.splited_docs + + def file_exist(self): + return os.path.isfile(self.filepath) + + def get_mtime(self): + return os.path.getmtime(self.filepath) + + def get_size(self): + return os.path.getsize(self.filepath) + + +def files2docs_in_thread_file2docs( + *, file: KnowledgeFile, **kwargs +) -> Tuple[bool, Tuple[str, str, List[Document]]]: + try: + return True, (file.kb_name, file.filename, file.file2text(**kwargs)) + except Exception as e: + msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}" + logger.error( + f"{e.__class__.__name__}: {msg}", exc_info=e + ) + return False, (file.kb_name, 
file.filename, msg) + + +def files2docs_in_thread( + files: List[Union[KnowledgeFile, Tuple[str, str], Dict]], + chunk_size: int = CHUNK_SIZE, + chunk_overlap: int = OVERLAP_SIZE, + zh_title_enhance: bool = ZH_TITLE_ENHANCE, +) -> Generator: + """ + 利用多线程批量将磁盘文件转化成langchain Document. + 如果传入参数是Tuple,形式为(filename, kb_name) + 生成器返回值为 status, (kb_name, file_name, docs | error) + """ + + kwargs_list = [] + for i, file in enumerate(files): + kwargs = {} + try: + if isinstance(file, tuple) and len(file) >= 2: + filename = file[0] + kb_name = file[1] + file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + elif isinstance(file, dict): + filename = file.pop("filename") + kb_name = file.pop("kb_name") + kwargs.update(file) + file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name) + kwargs["file"] = file + kwargs["chunk_size"] = chunk_size + kwargs["chunk_overlap"] = chunk_overlap + kwargs["zh_title_enhance"] = zh_title_enhance + kwargs_list.append(kwargs) + except Exception as e: + yield False, (kb_name, filename, str(e)) + + for result in run_in_thread_pool( + func=files2docs_in_thread_file2docs, params=kwargs_list + ): + yield result + + +if __name__ == "__main__": + from pprint import pprint + + kb_file = KnowledgeFile( + filename="E:\\LLM\\Data\\Test.md", knowledge_base_name="samples" + ) + # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter" + kb_file.text_splitter_name = "MarkdownHeaderTextSplitter" + docs = kb_file.file2docs() + # pprint(docs[-1]) + texts = kb_file.docs2texts(docs) + for text in texts: + print(text) diff --git a/src/mindpilot/app/utils/system_utils.py b/src/mindpilot/app/utils/system_utils.py index 9f6469e..ff922e4 100644 --- a/src/mindpilot/app/utils/system_utils.py +++ b/src/mindpilot/app/utils/system_utils.py @@ -1,12 +1,7 @@ import asyncio import logging -import multiprocessing as mp -import os -import socket import sqlite3 -import sys -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed -from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import ( Any, Awaitable, @@ -14,29 +9,14 @@ from typing import ( Dict, Generator, List, - Literal, - Optional, - Tuple, Union, ) -import httpx -import openai -from fastapi import FastAPI from langchain.tools import BaseTool -from langchain_core.embeddings import Embeddings from langchain_openai.chat_models import ChatOpenAI -from langchain_openai.llms import OpenAI - -# from chatchat.configs import ( -# DEFAULT_EMBEDDING_MODEL, -# DEFAULT_LLM_MODEL, -# HTTPX_DEFAULT_TIMEOUT, -# MODEL_PLATFORMS, -# TEMPERATURE, -# log_verbose, -# ) from .pydantic_v2 import BaseModel, Field +from ..configs import DEFAULT_EMBEDDING_MODEL +from langchain_core.embeddings import Embeddings logger = logging.getLogger() @@ -209,7 +189,60 @@ class ListResponse(BaseResponse): } } + def get_mindpilot_db_connection(): conn = sqlite3.connect('mindpilot.db') conn.row_factory = sqlite3.Row return conn + + +def run_in_thread_pool( + func: Callable, + params: List[Dict] = [], +) -> Generator: + """ + 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 + 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 + """ + tasks = [] + with ThreadPoolExecutor() as pool: + for kwargs in params: + tasks.append(pool.submit(func, **kwargs)) + + for obj in as_completed(tasks): + try: + yield obj.result() + except Exception as e: + logger.error(f"error in sub thread: {e}", exc_info=True) + +def get_Embeddings( + embed_model: str = DEFAULT_EMBEDDING_MODEL, +) -> Embeddings: + + from 
..knowledge_base.embedding.localai_embeddings import ( + LocalAIEmbeddings, + ) + + params = dict(model=embed_model) + try: + params.update( + openai_api_base=f"http://127.0.0.1:7890/v1", + openai_api_key="EMPTY", + ) + return LocalAIEmbeddings(**params) + except Exception as e: + logger.error( + f"failed to create Embeddings for model: {embed_model}.", exc_info=True + ) + + +def check_embed_model(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> bool: + embeddings = get_Embeddings(embed_model=embed_model) + try: + embeddings.embed_query("this is a test") + return True + except Exception as e: + logger.error( + f"failed to access embed model '{embed_model}': {e}", exc_info=True + ) + return False diff --git a/src/mindpilot/main.py b/src/mindpilot/main.py index 338b226..2336bb7 100644 --- a/src/mindpilot/main.py +++ b/src/mindpilot/main.py @@ -120,6 +120,10 @@ def main(): print亮蓝(f"当前工作目录:{cwd}") print亮蓝(f"OpenAPI 文档地址:http://{HOST}:{PORT}/docs") + from app.knowledge_base.migrate import create_tables + + create_tables() + if sys.version_info < (3, 10): loop = asyncio.get_event_loop() else: