| @@ -21,4 +21,12 @@ metaphor_python==0.1.23 | |||
| langchainhub==0.1.20 | |||
| langchain-experimental==0.0.62 | |||
| arxiv==2.1.3 | |||
| spark_ai_python~=0.4.4 | |||
| spark_ai_python~=0.4.4 | |||
| colorama~=0.4.6 | |||
| tqdm~=4.66.4 | |||
| numpy~=1.26.4 | |||
| chardet~=5.2.0 | |||
| elasticsearch~=8.15.0 | |||
| tenacity~=8.5.0 | |||
| pymilvus~=2.4.5 | |||
| python-dateutil~=2.9.0.post0 | |||
| @@ -7,6 +7,7 @@ from .tool_routes import tool_router | |||
| from .agent_routes import agent_router | |||
| from .config_routes import config_router | |||
| from .conversation_routes import conversation_router | |||
| from .kb_routes import kb_router | |||
| def create_app(run_mode: str = None): | |||
| @@ -28,5 +29,6 @@ def create_app(run_mode: str = None): | |||
| app.include_router(agent_router) | |||
| app.include_router(config_router) | |||
| app.include_router(conversation_router) | |||
| app.include_router(kb_router) | |||
| return app | |||
| @@ -0,0 +1,92 @@ | |||
| from __future__ import annotations | |||
| from typing import List | |||
| from fastapi import APIRouter, Request | |||
| # from chatchat.server.chat.file_chat import upload_temp_docs | |||
| from ..knowledge_base.kb_api import create_kb, delete_kb, list_kbs | |||
| from ..knowledge_base.kb_doc_api import ( | |||
| delete_docs, | |||
| download_doc, | |||
| list_files, | |||
| recreate_vector_store, | |||
| search_docs, | |||
| update_docs, | |||
| update_info, | |||
| upload_docs, | |||
| ) | |||
| # from chatchat.server.knowledge_base.kb_summary_api import ( | |||
| # recreate_summary_vector_store, | |||
| # summary_doc_ids_to_vector_store, | |||
| # summary_file_to_vector_store, | |||
| # ) | |||
| from ..utils.system_utils import BaseResponse, ListResponse | |||
# Router exposing knowledge-base management endpoints.
kb_router = APIRouter(prefix="/knowledge_base", tags=["知识库管理"])

# Route table: (HTTP method, path, endpoint, extra route kwargs).
# Registering from a table keeps method/path/summary together instead of
# a dozen near-identical call sites; behavior is identical to calling
# kb_router.get/post(...) once per endpoint.
_KB_ROUTES = (
    ("get", "/list_knowledge_bases", list_kbs,
     {"response_model": ListResponse, "summary": "获取知识库列表"}),
    ("post", "/create_knowledge_base", create_kb,
     {"response_model": BaseResponse, "summary": "创建知识库"}),
    ("post", "/delete_knowledge_base", delete_kb,
     {"response_model": BaseResponse, "summary": "删除知识库"}),
    ("get", "/list_files", list_files,
     {"response_model": ListResponse, "summary": "获取知识库内的文件列表"}),
    ("post", "/search_docs", search_docs,
     {"response_model": List[dict], "summary": "搜索知识库"}),
    ("post", "/upload_docs", upload_docs,
     {"response_model": BaseResponse, "summary": "上传文件到知识库,并/或进行向量化"}),
    ("post", "/delete_docs", delete_docs,
     {"response_model": BaseResponse, "summary": "删除知识库内指定文件"}),
    ("post", "/update_info", update_info,
     {"response_model": BaseResponse, "summary": "更新知识库介绍"}),
    ("post", "/update_docs", update_docs,
     {"response_model": BaseResponse, "summary": "更新现有文件到知识库"}),
    ("get", "/download_doc", download_doc,
     {"summary": "下载对应的知识文件"}),
    ("post", "/recreate_vector_store", recreate_vector_store,
     {"summary": "根据content中文档重建向量库,流式输出处理进度。"}),
)

for _method, _path, _endpoint, _kwargs in _KB_ROUTES:
    getattr(kb_router, _method)(_path, **_kwargs)(_endpoint)
| # kb_router.post( | |||
| # "/upload_temp_docs", summary="上传文件到临时目录,用于文件对话。")(upload_temp_docs) | |||
| # | |||
| # | |||
| # summary_router = APIRouter(prefix="/kb_summary_api") | |||
| # summary_router.post( | |||
| # "/summary_file_to_vector_store", summary="单个知识库根据文件名称摘要" | |||
| # )(summary_file_to_vector_store) | |||
| # | |||
| # summary_router.post( | |||
| # "/summary_doc_ids_to_vector_store", | |||
| # summary="单个知识库根据doc_ids摘要", | |||
| # response_model=BaseResponse, | |||
| # )(summary_doc_ids_to_vector_store) | |||
| # | |||
| # summary_router.post( | |||
| # "/recreate_summary_vector_store", summary="重建单个知识库文件摘要" | |||
| # )(recreate_summary_vector_store) | |||
| # | |||
| # kb_router.include_router(summary_router) | |||
| @@ -2,6 +2,7 @@ from .system_config import HOST, PORT | |||
| from .model_config import MODEL_CONFIG | |||
| from .prompt_config import OPENAI_PROMPT, PROMPT_TEMPLATES | |||
| from .tool_config import TOOL_CONFIG | |||
| from .kb_config import * | |||
| __all__ = [ | |||
| "HOST", | |||
| @@ -9,5 +10,27 @@ __all__ = [ | |||
| "MODEL_CONFIG", | |||
| "OPENAI_PROMPT", | |||
| "TOOL_CONFIG", | |||
| "PROMPT_TEMPLATES" | |||
| ] | |||
| "PROMPT_TEMPLATES", | |||
| "DEFAULT_KNOWLEDGE_BASE", | |||
| "DEFAULT_VS_TYPE", | |||
| "CACHED_VS_NUM", | |||
| "CACHED_MEMO_VS_NUM", | |||
| "CHUNK_SIZE", | |||
| "OVERLAP_SIZE", | |||
| "VECTOR_SEARCH_TOP_K", | |||
| "SCORE_THRESHOLD", | |||
| "DEFAULT_SEARCH_ENGINE", | |||
| "SEARCH_ENGINE_TOP_K", | |||
| "ZH_TITLE_ENHANCE", | |||
| # "PDF_OCR_THRESHOLD", | |||
| "KB_INFO", | |||
| "CHATCHAT_ROOT", | |||
| "KB_ROOT_PATH", | |||
| "DB_ROOT_PATH", | |||
| "SQLALCHEMY_DATABASE_URI", | |||
| "kbs_config", | |||
| "text_splitter_dict", | |||
| "TEXT_SPLITTER_NAME", | |||
| "EMBEDDING_KEYWORD_FILE", | |||
| "DEFAULT_EMBEDDING_MODEL", | |||
| ] | |||
| @@ -0,0 +1,113 @@ | |||
import os
from pathlib import Path

# Default knowledge base name.
DEFAULT_KNOWLEDGE_BASE = "samples"

# Default vector store / full-text search engine type.
DEFAULT_VS_TYPE = "faiss"

# Number of cached vector stores (FAISS only).
CACHED_VS_NUM = 1

# Number of cached temporary vector stores (FAISS only), used for file chat.
CACHED_MEMO_VS_NUM = 10

# Chunk size for knowledge-base text (not used by MarkdownHeaderTextSplitter).
CHUNK_SIZE = 250

# Overlap between adjacent chunks (not used by MarkdownHeaderTextSplitter).
OVERLAP_SIZE = 50

# Number of vectors returned by a knowledge-base match.
VECTOR_SEARCH_TOP_K = 3

# Relevance score threshold in [0, 1]: a LOWER score means MORE relevant;
# 1 disables filtering entirely; around 0.5 is the suggested setting.
SCORE_THRESHOLD = 1

# Default search engine. Options: bing, duckduckgo, metaphor.
DEFAULT_SEARCH_ENGINE = "bing"

# Number of results returned by a search-engine match.
SEARCH_ENGINE_TOP_K = 3

# Whether to enable Chinese title enhancement (and its related settings):
# headings are detected and marked in metadata, then each text chunk is
# joined with its parent heading to enrich the chunk's context.
ZH_TITLE_ENHANCE = False

# PDF OCR control: only OCR images whose size ratios
# (image width / page width, image height / page height) exceed these
# thresholds — skips small images and speeds up non-scanned PDFs.
PDF_OCR_THRESHOLD = (0.6, 0.6)

# Initial description for each knowledge base, shown at initialization and
# used when the Agent decides whether to call it; a kb without a description
# will not be called by the Agent.
KB_INFO = {
    "samples": "关于本项目issue的解答",
}

# Project root: three levels above this config file.
CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent.parent)
KB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "knowledge_base")

# Default database storage path (a SQLite file under the project root).
DB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "mindpilot.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"

# Available vector store types and their connection settings.
kbs_config = {
    "faiss": {},
    "milvus": {
        "host": "127.0.0.1",
        "port": "19530",
        "user": "",
        "password": "",
        "secure": False,
    },
    "es": {
        "host": "127.0.0.1",
        "port": "9200",
        "index_name": "test_index",
        "user": "",
        "password": "",
    },
    "milvus_kwargs": {
        "search_params": {"metric_type": "L2"},  # add extra search_params here
        "index_params": {
            "metric_type": "L2",
            "index_type": "HNSW",
            "params": {"M": 8, "efConstruction": 64},
        },  # add extra index_params here
    },
}

# TextSplitter settings — leave unchanged unless you understand them.
text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "",  # "tiktoken" uses OpenAI's tokenizer; "huggingface" is also supported
        "tokenizer_name_or_path": "",
    },
    "SpacyTextSplitter": {
        "source": "huggingface",
        "tokenizer_name_or_path": "gpt2",
    },
    "RecursiveCharacterTextSplitter": {
        "source": "tiktoken",
        "tokenizer_name_or_path": "cl100k_base",
    },
    "MarkdownHeaderTextSplitter": {
        "headers_to_split_on": [
            ("#", "head1"),
            ("##", "head2"),
            ("###", "head3"),
            ("####", "head4"),
        ]
    },
}

# Name of the TextSplitter to use.
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"

# Vocabulary file with custom keywords for the embedding model.
EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"
DEFAULT_EMBEDDING_MODEL = "bce-embedding-base_v1"
| @@ -0,0 +1,17 @@ | |||
import json

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from sqlalchemy.orm import sessionmaker

from ...configs.kb_config import SQLALCHEMY_DATABASE_URI

# Engine bound to the configured database; ensure_ascii=False keeps
# non-ASCII (Chinese) text readable inside stored JSON columns.
engine = create_engine(
    SQLALCHEMY_DATABASE_URI,
    json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)

# Session factory; commits/flushes are driven explicitly by the session helpers.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class for all ORM models.
Base: DeclarativeMeta = declarative_base()
| @@ -0,0 +1,17 @@ | |||
| from datetime import datetime | |||
| from sqlalchemy import Column, DateTime, Integer, String | |||
class BaseModel:
    """Mixin providing the audit columns shared by ORM models."""

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, index=True, comment="主键ID")
    # NOTE(review): datetime.utcnow yields naive UTC and is deprecated since
    # Python 3.12; datetime.now(timezone.utc) would be the modern form but
    # changes stored values' tzinfo — confirm before switching.
    create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间")
    # Set on every UPDATE via onupdate; NULL until the first update.
    update_time = Column(
        DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间"
    )
    create_by = Column(String, default=None, comment="创建者")
    update_by = Column(String, default=None, comment="更新者")
| @@ -0,0 +1,39 @@ | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from sqlalchemy import Column, DateTime, Integer, String, func | |||
| from ..base import Base | |||
class KnowledgeBaseModel(Base):
    """Knowledge base ORM model (table ``knowledge_base``)."""

    __tablename__ = "knowledge_base"
    id = Column(Integer, primary_key=True, autoincrement=True, comment="知识库ID")
    kb_name = Column(String(50), comment="知识库名称")
    kb_info = Column(String(200), comment="知识库简介(用于Agent)")
    vs_type = Column(String(50), comment="向量库类型")
    embed_model = Column(String(50), comment="嵌入模型名称")
    file_count = Column(Integer, default=0, comment="文件数量")
    create_time = Column(DateTime, default=func.now(), comment="创建时间")

    def __repr__(self):
        # Fix: the previous f-string was missing the closing quote after
        # kb_info ("kb_intro='{self.kb_info} vs_type=..."), garbling output.
        return (
            f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}', "
            f"kb_info='{self.kb_info}', vs_type='{self.vs_type}', "
            f"embed_model='{self.embed_model}', file_count='{self.file_count}', "
            f"create_time='{self.create_time}')>"
        )
# Pydantic counterpart of KnowledgeBaseModel, used for (de)serialization.
class KnowledgeBaseSchema(BaseModel):
    """Pydantic schema mirroring ``KnowledgeBaseModel`` rows."""

    id: int
    kb_name: str
    kb_info: Optional[str]
    vs_type: Optional[str]
    embed_model: Optional[str]
    file_count: Optional[int]
    create_time: Optional[datetime]

    class Config:
        # Allow validation directly from ORM instances (pydantic v2 style).
        from_attributes = True
| @@ -0,0 +1,42 @@ | |||
| from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func | |||
| from ..base import Base | |||
class KnowledgeFileModel(Base):
    """Knowledge file ORM model (table ``knowledge_file``)."""

    __tablename__ = "knowledge_file"
    id = Column(Integer, primary_key=True, autoincrement=True, comment="知识文件ID")
    file_name = Column(String(255), comment="文件名")
    file_ext = Column(String(10), comment="文件扩展名")
    kb_name = Column(String(50), comment="所属知识库名称")
    document_loader_name = Column(String(50), comment="文档加载器名称")
    text_splitter_name = Column(String(50), comment="文本分割器名称")
    # Bumped on every re-upload of the same file (see add_file_to_db).
    file_version = Column(Integer, default=1, comment="文件版本")
    file_mtime = Column(Float, default=0.0, comment="文件修改时间")
    file_size = Column(Integer, default=0, comment="文件大小")
    custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
    docs_count = Column(Integer, default=0, comment="切分文档数量")
    create_time = Column(DateTime, default=func.now(), comment="创建时间")

    def __repr__(self):
        return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
class FileDocModel(Base):
    """File-to-vector-store document mapping (table ``file_doc``)."""

    __tablename__ = "file_doc"
    id = Column(Integer, primary_key=True, autoincrement=True, comment="ID")
    kb_name = Column(String(50), comment="知识库名称")
    file_name = Column(String(255), comment="文件名称")
    doc_id = Column(String(50), comment="向量库文档ID")
    # Fix: default was one shared mutable {}; a callable gives each row its
    # own fresh dict (standard SQLAlchemy practice for mutable defaults).
    meta_data = Column(JSON, default=dict)

    def __repr__(self):
        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"
| @@ -0,0 +1,31 @@ | |||
| from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func | |||
| from ..base import Base | |||
class SummaryChunkModel(Base):
    """Chunk-summary model (table ``summary_chunk``): one summary text per
    group of ``file_doc`` doc_ids.

    Data sources:
    - user input: per-file descriptions supplied at upload time, linked to
      the doc_ids generated for that file;
    - automatic splitting: per-page summaries generated from the page info
      stored in ``file_doc.meta_data``, linked to that page's doc_ids.

    Downstream, ``summary_context`` is indexed into a vector store whose
    metadata carries the doc_ids, enabling semantic matching between user
    descriptions and the auto-generated summaries.
    """

    __tablename__ = "summary_chunk"
    id = Column(Integer, primary_key=True, autoincrement=True, comment="ID")
    kb_name = Column(String(50), comment="知识库名称")
    summary_context = Column(String(255), comment="总结文本")
    summary_id = Column(String(255), comment="总结矢量id")
    doc_ids = Column(String(1024), comment="向量库id关联列表")
    # Fix: shared mutable {} default replaced with a per-row callable.
    meta_data = Column(JSON, default=dict)

    def __repr__(self):
        # Fix: previously read self.metadata, which on a declarative model is
        # SQLAlchemy's MetaData object — not this row's JSON column.
        return (
            f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}',"
            f" doc_ids='{self.doc_ids}', metadata='{self.meta_data}')>"
        )
| @@ -0,0 +1,3 @@ | |||
| from .knowledge_base_repository import * | |||
| from .knowledge_file_repository import * | |||
| @@ -0,0 +1,93 @@ | |||
| from ...db.models.knowledge_base_model import ( | |||
| KnowledgeBaseModel, | |||
| KnowledgeBaseSchema, | |||
| ) | |||
| from ...db.session import with_session | |||
@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
    """Insert a knowledge base record, or update it if the name exists.

    Name matching is case-insensitive (``ilike``). Always returns True.
    """
    existing = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    if existing is None:
        # First time this name is seen: create a fresh record.
        session.add(
            KnowledgeBaseModel(
                kb_name=kb_name,
                kb_info=kb_info,
                vs_type=vs_type,
                embed_model=embed_model,
            )
        )
    else:
        # Refresh the mutable attributes of the existing record.
        existing.kb_info = kb_info
        existing.vs_type = vs_type
        existing.embed_model = embed_model
    return True
@with_session
def list_kbs_from_db(session, min_file_count: int = -1):
    """Return all knowledge bases holding more than *min_file_count* files,
    validated into ``KnowledgeBaseSchema`` objects."""
    query = session.query(KnowledgeBaseModel).filter(
        KnowledgeBaseModel.file_count > min_file_count
    )
    return [KnowledgeBaseSchema.model_validate(row) for row in query.all()]
@with_session
def kb_exists(session, kb_name):
    """Return True if a knowledge base with this name exists
    (case-insensitive match)."""
    kb = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    # `True if kb else False` was redundant; express the intent directly.
    return kb is not None
@with_session
def load_kb_from_db(session, kb_name):
    """Look up a knowledge base by name (case-insensitive).

    Returns ``(kb_name, vs_type, embed_model)``; all three are None when
    the knowledge base does not exist.
    """
    record = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    if record is None:
        return None, None, None
    return record.kb_name, record.vs_type, record.embed_model
@with_session
def delete_kb_from_db(session, kb_name):
    """Delete the knowledge base row matching *kb_name* (case-insensitive).

    Always returns True, even when nothing matched.
    """
    record = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    if record is not None:
        session.delete(record)
    return True
@with_session
def get_kb_detail(session, kb_name: str) -> dict:
    """Return a dict describing the knowledge base, or {} when not found."""
    record: KnowledgeBaseModel = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
        .first()
    )
    if record is None:
        return {}
    # Same keys/values as before, built from one attribute list.
    fields = ("kb_name", "kb_info", "vs_type", "embed_model",
              "file_count", "create_time")
    return {name: getattr(record, name) for name in fields}
| @@ -0,0 +1,245 @@ | |||
| from typing import Dict, List | |||
| from ...db.models.knowledge_base_model import KnowledgeBaseModel | |||
| from ...db.models.knowledge_file_model import ( | |||
| FileDocModel, | |||
| KnowledgeFileModel, | |||
| ) | |||
| from ...db.session import with_session | |||
| from ...utils import KnowledgeFile | |||
@with_session
def list_file_num_docs_id_by_kb_name_and_file_name(
    session,
    kb_name: str,
    file_name: str,
) -> List[int]:
    """
    List the ids of all Documents belonging to one file of one knowledge base.

    Returns ``[int, ...]`` — each ``doc_id`` is cast to int (the previous
    docstring claimed ``[str, ...]``, which contradicted the code).
    """
    doc_ids = (
        session.query(FileDocModel.doc_id)
        .filter_by(kb_name=kb_name, file_name=file_name)
        .all()
    )
    # .all() on a single-column query yields 1-tuples; unpack and cast.
    return [int(_id[0]) for _id in doc_ids]
@with_session
def list_docs_from_db(
    session,
    kb_name: str,
    file_name: str = None,
    metadata: Dict = None,
) -> List[Dict]:
    """
    List all Documents of a knowledge base, optionally limited to one file.

    Returns ``[{"id": str, "metadata": dict}, ...]``.

    Fixes vs. previous version:
    - ``metadata`` defaulted to a shared mutable ``{}``; now None-guarded.
    - reads ``x.meta_data`` (the JSON column) instead of ``x.metadata``,
      which on a declarative model resolves to SQLAlchemy's MetaData object.
    """
    metadata = metadata or {}
    docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
    if file_name:
        docs = docs.filter(FileDocModel.file_name.ilike(file_name))
    for k, v in metadata.items():
        # Compare JSON fields as strings, matching the stored representation.
        docs = docs.filter(FileDocModel.meta_data[k].as_string() == str(v))
    return [{"id": x.doc_id, "metadata": x.meta_data} for x in docs.all()]
@with_session
def delete_docs_from_db(
    session,
    kb_name: str,
    file_name: str = None,
) -> List[Dict]:
    """
    Delete all Documents of a knowledge base (optionally one file only) and
    return the deleted entries as ``[{"id": str, "metadata": dict}, ...]``.
    """
    # Snapshot what is about to be removed so we can hand it back.
    removed = list_docs_from_db(kb_name=kb_name, file_name=file_name)
    query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
    if file_name:
        query = query.filter(FileDocModel.file_name.ilike(file_name))
    query.delete(synchronize_session=False)
    session.commit()
    return removed
@with_session
def add_docs_to_db(session, kb_name: str, file_name: str, doc_infos: List[Dict]):
    """
    Persist Document info for one file of a knowledge base.

    ``doc_infos`` has the form ``[{"id": str, "metadata": dict}, ...]``.
    Returns False (without raising) when ``doc_infos`` is None.
    """
    # ! doc_infos is sometimes None here; root cause still to be tracked down.
    if doc_infos is None:
        print(
            "输入的server.db.repository.knowledge_file_repository.add_docs_to_db的doc_infos参数为None"
        )
        return False
    session.add_all(
        FileDocModel(
            kb_name=kb_name,
            file_name=file_name,
            doc_id=info["id"],
            meta_data=info["metadata"],
        )
        for info in doc_infos
    )
    return True
@with_session
def count_files_from_db(session, kb_name: str) -> int:
    """Number of files recorded for the given knowledge base
    (case-insensitive name match)."""
    query = session.query(KnowledgeFileModel).filter(
        KnowledgeFileModel.kb_name.ilike(kb_name)
    )
    return query.count()
@with_session
def list_files_from_db(session, kb_name):
    """Return the file names recorded for the given knowledge base."""
    rows = (
        session.query(KnowledgeFileModel)
        .filter(KnowledgeFileModel.kb_name.ilike(kb_name))
        .all()
    )
    return [row.file_name for row in rows]
@with_session
def add_file_to_db(
    session,
    kb_file: KnowledgeFile,
    docs_count: int = 0,
    custom_docs: bool = False,
    doc_infos: List[Dict] = None,  # form: [{"id": str, "metadata": dict}, ...]
):
    """
    Record *kb_file* in the database: bump version/mtime/size when the file
    already exists, otherwise insert a new row and increment the knowledge
    base's file count; then store the document infos via ``add_docs_to_db``.

    Fix: ``doc_infos`` previously defaulted to a shared mutable ``[]``.

    NOTE(review): returns True even when the knowledge base does not exist
    (in which case nothing is written) — confirm callers expect that.
    """
    if doc_infos is None:
        doc_infos = []
    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
    if kb:
        # If the file is already known, refresh its info and bump the version.
        existing_file: KnowledgeFileModel = (
            session.query(KnowledgeFileModel)
            .filter(
                KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
                KnowledgeFileModel.file_name.ilike(kb_file.filename),
            )
            .first()
        )
        mtime = kb_file.get_mtime()
        size = kb_file.get_size()
        if existing_file:
            existing_file.file_mtime = mtime
            existing_file.file_size = size
            existing_file.docs_count = docs_count
            existing_file.custom_docs = custom_docs
            existing_file.file_version += 1
        # Otherwise, insert a brand-new file record.
        else:
            new_file = KnowledgeFileModel(
                file_name=kb_file.filename,
                file_ext=kb_file.ext,
                kb_name=kb_file.kb_name,
                document_loader_name=kb_file.document_loader_name,
                text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
                file_mtime=mtime,
                file_size=size,
                docs_count=docs_count,
                custom_docs=custom_docs,
            )
            kb.file_count += 1
            session.add(new_file)
        add_docs_to_db(
            kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos
        )
    return True
@with_session
def delete_file_from_db(session, kb_file: KnowledgeFile):
    """
    Delete one file record (and its Document rows) from the database, then
    decrement the owning knowledge base's file count.

    Always returns True, even when the file was not found.
    """
    existing_file = (
        session.query(KnowledgeFileModel)
        .filter(
            KnowledgeFileModel.file_name.ilike(kb_file.filename),
            KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
        )
        .first()
    )
    if existing_file:
        session.delete(existing_file)
        # Remove the per-document rows that belonged to this file.
        delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
        session.commit()
        kb = (
            session.query(KnowledgeBaseModel)
            .filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name))
            .first()
        )
        if kb:
            kb.file_count -= 1
            session.commit()
    return True
@with_session
def delete_files_from_db(session, knowledge_base_name: str):
    """Wipe every file and document row of a knowledge base and reset its
    file count to zero. Always returns True."""
    # Both tables are filtered by the same case-insensitive kb_name match.
    for model in (KnowledgeFileModel, FileDocModel):
        session.query(model).filter(
            model.kb_name.ilike(knowledge_base_name)
        ).delete(synchronize_session=False)
    kb = (
        session.query(KnowledgeBaseModel)
        .filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name))
        .first()
    )
    if kb is not None:
        kb.file_count = 0
    session.commit()
    return True
@with_session
def file_exists_in_db(session, kb_file: KnowledgeFile):
    """Return True if *kb_file* is already recorded for its knowledge base
    (case-insensitive match on both names)."""
    existing_file = (
        session.query(KnowledgeFileModel)
        .filter(
            KnowledgeFileModel.file_name.ilike(kb_file.filename),
            KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
        )
        .first()
    )
    # `True if x else False` was redundant boilerplate.
    return existing_file is not None
@with_session
def get_file_detail(session, kb_name: str, filename: str) -> dict:
    """Return a dict describing one knowledge file, or {} when not found."""
    record: KnowledgeFileModel = (
        session.query(KnowledgeFileModel)
        .filter(
            KnowledgeFileModel.file_name.ilike(filename),
            KnowledgeFileModel.kb_name.ilike(kb_name),
        )
        .first()
    )
    if record is None:
        return {}
    return {
        "kb_name": record.kb_name,
        "file_name": record.file_name,
        "file_ext": record.file_ext,
        "file_version": record.file_version,
        "document_loader": record.document_loader_name,
        "text_splitter": record.text_splitter_name,
        "create_time": record.create_time,
        "file_mtime": record.file_mtime,
        "file_size": record.file_size,
        "custom_docs": record.custom_docs,
        "docs_count": record.docs_count,
    }
| @@ -0,0 +1,77 @@ | |||
| from typing import Dict, List | |||
| from ...db.models.knowledge_metadata_model import SummaryChunkModel | |||
| from ...db.session import with_session | |||
@with_session
def list_summary_from_db(
    session,
    kb_name: str,
    metadata: Dict = None,
) -> List[Dict]:
    """
    List the chunk summaries of one knowledge base.

    Returns ``[{"id": str, "summary_context": str, "doc_ids": str}, ...]``.

    Fixes vs. previous version:
    - ``metadata`` defaulted to a shared mutable ``{}``; now None-guarded.
    - reads ``x.meta_data`` (the JSON column) instead of ``x.metadata``,
      which on a declarative model is SQLAlchemy's MetaData object.
    """
    metadata = metadata or {}
    docs = session.query(SummaryChunkModel).filter(
        SummaryChunkModel.kb_name.ilike(kb_name)
    )
    for k, v in metadata.items():
        docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
    return [
        {
            "id": x.id,
            "summary_context": x.summary_context,
            "summary_id": x.summary_id,
            "doc_ids": x.doc_ids,
            "metadata": x.meta_data,
        }
        for x in docs.all()
    ]
@with_session
def delete_summary_from_db(session, kb_name: str) -> List[Dict]:
    """
    Delete all chunk summaries of a knowledge base and return the deleted
    entries as ``[{"id": str, "summary_context": str, "doc_ids": str}, ...]``.
    """
    # Snapshot before deletion so the caller gets the removed rows back.
    removed = list_summary_from_db(kb_name=kb_name)
    session.query(SummaryChunkModel).filter(
        SummaryChunkModel.kb_name.ilike(kb_name)
    ).delete(synchronize_session=False)
    session.commit()
    return removed
@with_session
def add_summary_to_db(session, kb_name: str, summary_infos: List[Dict]):
    """
    Store summary records for a knowledge base.

    ``summary_infos`` has the form
    ``[{"summary_context": str, "doc_ids": str}, ...]``.
    Always returns True.
    """
    session.add_all(
        SummaryChunkModel(
            kb_name=kb_name,
            summary_context=info["summary_context"],
            summary_id=info["summary_id"],
            doc_ids=info["doc_ids"],
            meta_data=info["metadata"],
        )
        for info in summary_infos
    )
    session.commit()
    return True
@with_session
def count_summary_from_db(session, kb_name: str) -> int:
    """Number of chunk summaries stored for the given knowledge base."""
    query = session.query(SummaryChunkModel).filter(
        SummaryChunkModel.kb_name.ilike(kb_name)
    )
    return query.count()
| @@ -0,0 +1,48 @@ | |||
| from contextlib import contextmanager | |||
| from functools import wraps | |||
| from sqlalchemy.orm import Session | |||
| from .base import SessionLocal | |||
@contextmanager
def session_scope() -> Session:
    """Yield a Session, committing on success; roll back and re-raise on any
    failure; the session is always closed."""
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except:  # noqa: E722 — intentionally broad: roll back on ANY error, then re-raise
        session.rollback()
        raise
    finally:
        session.close()
def with_session(f):
    """Decorator injecting a fresh Session as the first argument of *f*.

    Commits after *f* returns; rolls back and re-raises on error.
    NOTE(review): session_scope() already commits/rolls back, so the inner
    commit/rollback here is redundant but harmless — confirm before removing.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        with session_scope() as session:
            try:
                result = f(session, *args, **kwargs)
                session.commit()
                return result
            except:  # noqa: E722 — re-raised after rollback
                session.rollback()
                raise

    return wrapper
def get_db():
    """FastAPI-style dependency: yield a Session, closing it afterwards.

    Fix: the previous ``-> SessionLocal`` annotation was wrong — this is a
    generator function yielding a Session, not a sessionmaker — so the
    misleading annotation was removed.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
def get_db0() -> SessionLocal:
    """Return a raw, caller-managed Session (the caller must close it)."""
    return SessionLocal()
| @@ -0,0 +1,380 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| import warnings | |||
| from typing import ( | |||
| Any, | |||
| Callable, | |||
| Dict, | |||
| List, | |||
| Literal, | |||
| Optional, | |||
| Sequence, | |||
| Set, | |||
| Tuple, | |||
| Union, | |||
| ) | |||
| from langchain_community.utils.openai import is_openai_v1 | |||
| from langchain_core.embeddings import Embeddings | |||
| from langchain_core.pydantic_v1 import BaseModel, Field, root_validator | |||
| from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names | |||
| from tenacity import ( | |||
| AsyncRetrying, | |||
| before_sleep_log, | |||
| retry, | |||
| retry_if_exception_type, | |||
| stop_after_attempt, | |||
| wait_exponential, | |||
| ) | |||
| from ...utils.system_utils import run_in_thread_pool | |||
| logger = logging.getLogger(__name__) | |||
def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator bounded by
    ``embeddings.max_retries``, retrying only on transient OpenAI-client
    error types."""
    import openai  # local import keeps openai optional at module import time

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.Timeout)
            | retry_if_exception_type(openai.APIError)
            | retry_if_exception_type(openai.APIConnectionError)
            | retry_if_exception_type(openai.RateLimitError)
            | retry_if_exception_type(openai.InternalServerError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
    """Async counterpart of ``_create_retry_decorator``: returns a decorator
    that retries a coroutine via tenacity's ``AsyncRetrying``."""
    import openai  # local import keeps openai optional at module import time

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    async_retrying = AsyncRetrying(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.Timeout)
            | retry_if_exception_type(openai.APIError)
            | retry_if_exception_type(openai.APIConnectionError)
            | retry_if_exception_type(openai.RateLimitError)
            | retry_if_exception_type(openai.InternalServerError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

    def wrap(func: Callable) -> Callable:
        async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
            # Each iteration of the AsyncRetrying loop is one attempt; the
            # first successful attempt returns immediately.
            async for _ in async_retrying:
                return await func(*args, **kwargs)
            raise AssertionError("this is unreachable")

        return wrapped_f

    return wrap
| # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings | |||
| def _check_response(response: dict) -> dict: | |||
| if any([len(d.embedding) == 1 for d in response.data]): | |||
| import openai | |||
| raise openai.APIError("LocalAI API returned an empty embedding") | |||
| return response | |||
def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call."""
    retry_decorator = _create_retry_decorator(embeddings)

    @retry_decorator
    def _embed_with_retry(**kwargs: Any) -> Any:
        # embeddings.client is the (sync) OpenAI-compatible embeddings client.
        response = embeddings.client.create(**kwargs)
        # Guard against LocalAI's length-1 "empty embedding" failure mode.
        return _check_response(response)

    return _embed_with_retry(**kwargs)
| async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any: | |||
| """Use tenacity to retry the embedding call.""" | |||
| @_async_retry_decorator(embeddings) | |||
| async def _async_embed_with_retry(**kwargs: Any) -> Any: | |||
| response = await embeddings.async_client.create(**kwargs) | |||
| return _check_response(response) | |||
| return await _async_embed_with_retry(**kwargs) | |||
class LocalAIEmbeddings(BaseModel, Embeddings):
    """LocalAI embedding models.

    Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
    uses the ``openai`` Python package's ``openai.Embedding`` as its client.
    Thus, you should have the ``openai`` python package installed, and set
    the environment variable ``OPENAI_API_KEY`` to any non-empty string
    (LocalAI does not check it, but the client library requires one).
    You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
    service endpoint.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import LocalAIEmbeddings
            openai = LocalAIEmbeddings(
                openai_api_key="random-string",
                openai_api_base="http://localhost:8080"
            )
    """

    # Sync / async embeddings clients; populated by validate_environment.
    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = "text-embedding-ada-002"
    # NOTE: evaluated once at class-definition time, so the default is the
    # literal "text-embedding-ada-002"; it does not track a per-instance
    # `model` override unless `deployment` is set explicitly.
    deployment: str = model
    openai_api_version: Optional[str] = None
    openai_api_base: Optional[str] = Field(default=None, alias="base_url")
    # to support explicit proxy for LocalAI
    openai_proxy: Optional[str] = None
    embedding_ctx_length: int = 8191
    """The maximum number of tokens to embed at once."""
    openai_api_key: Optional[str] = Field(default=None, alias="api_key")
    openai_organization: Optional[str] = Field(default=None, alias="organization")
    allowed_special: Union[Literal["all"], Set[str]] = set()
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch"""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout in seconds for the LocalAI request."""
    headers: Any = None
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""

    class Config:
        """Configuration for this pydantic object."""

        # Allow constructing with either field names or their aliases
        # (e.g. `openai_api_key=` or `api_key=`).
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Unknown keyword arguments are moved into ``model_kwargs`` (with a
        warning) so they are forwarded verbatim to the ``create`` call;
        declared fields supplied inside ``model_kwargs`` are rejected.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        # Resolve settings from explicit kwargs first, then env variables.
        values["openai_api_key"] = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        values["openai_api_base"] = get_from_dict_or_env(
            values,
            "openai_api_base",
            "OPENAI_API_BASE",
            default="",
        )
        values["openai_proxy"] = get_from_dict_or_env(
            values,
            "openai_proxy",
            "OPENAI_PROXY",
            default="",
        )
        default_api_version = ""
        values["openai_api_version"] = get_from_dict_or_env(
            values,
            "openai_api_version",
            "OPENAI_API_VERSION",
            default=default_api_version,
        )
        values["openai_organization"] = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        try:
            import openai

            # openai>=1.0 exposes client objects; older versions use the
            # module-level openai.Embedding resource instead.
            if is_openai_v1():
                client_params = {
                    "api_key": values["openai_api_key"],
                    "organization": values["openai_organization"],
                    "base_url": values["openai_api_base"],
                    "timeout": values["request_timeout"],
                    "max_retries": values["max_retries"],
                }
                # Only build clients that were not injected by the caller.
                if not values.get("client"):
                    values["client"] = openai.OpenAI(**client_params).embeddings
                if not values.get("async_client"):
                    values["async_client"] = openai.AsyncOpenAI(
                        **client_params
                    ).embeddings
            elif not values.get("client"):
                values["client"] = openai.Embedding
            else:
                pass
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    @property
    def _invocation_params(self) -> Dict:
        # Keyword arguments forwarded to every embeddings `create` call.
        openai_args = {
            "model": self.model,
            "timeout": self.request_timeout,
            "extra_headers": self.headers,
            **self.model_kwargs,
        }
        if self.openai_proxy:
            import openai

            # NOTE: module-level proxy assignment is the pre-v1 openai style;
            # with openai>=1.0 this attribute is presumably ignored — verify.
            openai.proxy = {
                "http": self.openai_proxy,
                "https": self.openai_proxy,
            }  # type: ignore[assignment]  # noqa: E501
        return openai_args

    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint."""
        # handle large input text
        if self.model.endswith("001"):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
        return (
            embed_with_retry(
                self,
                input=[text],
                **self._invocation_params,
            )
            .data[0]
            .embedding
        )

    async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint."""
        # handle large input text
        if self.model.endswith("001"):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            text = text.replace("\n", " ")
        return (
            (
                await async_embed_with_retry(
                    self,
                    input=[text],
                    **self._invocation_params,
                )
            )
            .data[0]
            .embedding
        )

    def embed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 0
    ) -> List[List[float]]:
        """Call out to LocalAI's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: Accepted for interface compatibility but unused here —
                each text is embedded in its own thread-pool task.

        Returns:
            List of embeddings, one for each text, in input order.
        """

        # call _embedding_func for each text with multithreads
        def task(seq, text):
            # Tag each result with its input position so order can be restored.
            return (seq, self._embedding_func(text, engine=self.deployment))

        params = [{"seq": i, "text": text} for i, text in enumerate(texts)]
        # run_in_thread_pool is a project helper; presumably yields results as
        # tasks complete (hence the re-sort below) — TODO confirm.
        result = list(run_in_thread_pool(func=task, params=params))
        result = sorted(result, key=lambda x: x[0])
        return [x[1] for x in result]

    async def aembed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 0
    ) -> List[List[float]]:
        """Call out to LocalAI's embedding endpoint async for embedding search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: Accepted for interface compatibility but unused here —
                texts are embedded sequentially, one request each.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            response = await self._aembedding_func(text, engine=self.deployment)
            embeddings.append(response)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = self._embedding_func(text, engine=self.deployment)
        return embedding

    async def aembed_query(self, text: str) -> List[float]:
        """Call out to LocalAI's embedding endpoint async for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        embedding = await self._aembedding_func(text, engine=self.deployment)
        return embedding
| @@ -0,0 +1,87 @@ | |||
## CSV file loader that reads only the specified columns
| import csv | |||
| from io import TextIOWrapper | |||
| from typing import Dict, List, Optional | |||
| from langchain.docstore.document import Document | |||
| from langchain_community.document_loaders import CSVLoader | |||
| from langchain_community.document_loaders.helpers import detect_file_encodings | |||
class FilteredCSVLoader(CSVLoader):
    """CSV loader that only reads a fixed set of columns.

    Each row becomes one ``Document`` whose ``page_content`` is the selected
    columns rendered as ``"col:value"`` lines; ``metadata_columns`` values are
    copied into the document metadata.
    """

    def __init__(
        self,
        file_path: str,
        columns_to_read: List[str],
        source_column: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = False,
    ):
        """
        Args:
            file_path: Path of the CSV file to load.
            columns_to_read: Columns whose values make up ``page_content``;
                every listed column must exist in the file.
            source_column: Column used as each document's ``source`` metadata;
                defaults to ``file_path`` when omitted.
            metadata_columns: Extra columns copied into document metadata.
            csv_args: Keyword arguments forwarded to ``csv.DictReader``.
            encoding: Text encoding used to open the file.
            autodetect_encoding: Retry with detected encodings on decode error.
        """
        # `metadata_columns=None` replaces the original mutable default ([]);
        # an empty list is substituted here so callers see the same behavior.
        super().__init__(
            file_path=file_path,
            source_column=source_column,
            metadata_columns=metadata_columns if metadata_columns is not None else [],
            csv_args=csv_args,
            encoding=encoding,
            autodetect_encoding=autodetect_encoding,
        )
        self.columns_to_read = columns_to_read

    def load(self) -> List[Document]:
        """Load data into document objects.

        Raises:
            RuntimeError: If the file cannot be read or decoded (including
                when every auto-detected encoding fails).
        """
        docs = []
        try:
            with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
                docs = self.__read_file(csvfile)
        except UnicodeDecodeError as e:
            if self.autodetect_encoding:
                detected_encodings = detect_file_encodings(self.file_path)
                for encoding in detected_encodings:
                    try:
                        with open(
                            self.file_path, newline="", encoding=encoding.encoding
                        ) as csvfile:
                            docs = self.__read_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self.file_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error loading {self.file_path}") from e
        return docs

    def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
        """Parse an open CSV stream into documents, one per row."""
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = []
            for col in self.columns_to_read:
                if col in row:
                    content.append(f"{col}:{str(row[col])}")
                else:
                    # Report the column that is actually missing (the original
                    # message always named columns_to_read[0]).
                    raise ValueError(f"Column '{col}' not found in CSV file.")
            content = "\n".join(content)
            # Extract the source if available
            source = (
                row.get(self.source_column, None)
                if self.source_column is not None
                else self.file_path
            )
            metadata = {"source": source, "row": i}
            for col in self.metadata_columns:
                if col in row:
                    metadata[col] = row[col]
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)
        return docs
| @@ -0,0 +1,4 @@ | |||
| from .mydocloader import RapidOCRDocLoader | |||
| from .myimgloader import RapidOCRLoader | |||
| from .mypdfloader import RapidOCRPDFLoader | |||
| from .mypptloader import RapidOCRPPTLoader | |||
| @@ -0,0 +1,79 @@ | |||
| from typing import List | |||
| import tqdm | |||
| from langchain_community.document_loaders.unstructured import UnstructuredFileLoader | |||
class RapidOCRDocLoader(UnstructuredFileLoader):
    """Loader for .docx files that also OCRs embedded images.

    Paragraph and table text is read via python-docx in document order;
    inline pictures are decoded with PIL, run through RapidOCR, and the
    recognized lines are appended to the extracted text.
    """

    def _get_elements(self) -> List:
        def doc2text(filepath):
            # Walk the document and accumulate all text (including OCR text
            # from embedded images) into a single string.
            from io import BytesIO

            import numpy as np
            from docx import Document, ImagePart
            from docx.oxml.table import CT_Tbl
            from docx.oxml.text.paragraph import CT_P
            from docx.table import Table, _Cell
            from docx.text.paragraph import Paragraph
            from PIL import Image
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            doc = Document(filepath)
            resp = ""

            def iter_block_items(parent):
                # Yield paragraphs and tables in document order, from either
                # the document body or a table cell.
                from docx.document import Document

                if isinstance(parent, Document):
                    parent_elm = parent.element.body
                elif isinstance(parent, _Cell):
                    parent_elm = parent._tc
                else:
                    raise ValueError("RapidOCRDocLoader parse fail")
                for child in parent_elm.iterchildren():
                    if isinstance(child, CT_P):
                        yield Paragraph(child, parent)
                    elif isinstance(child, CT_Tbl):
                        yield Table(child, parent)

            b_unit = tqdm.tqdm(
                total=len(doc.paragraphs) + len(doc.tables),
                desc="RapidOCRDocLoader block index: 0",
            )
            for i, block in enumerate(iter_block_items(doc)):
                b_unit.set_description("RapidOCRDocLoader block index: {}".format(i))
                b_unit.refresh()
                if isinstance(block, Paragraph):
                    resp += block.text.strip() + "\n"
                    images = block._element.xpath(".//pic:pic")  # all inline pictures
                    for image in images:
                        for img_id in image.xpath(".//a:blip/@r:embed"):  # picture relationship ids
                            part = doc.part.related_parts[
                                img_id
                            ]  # resolve the id to its image part
                            if isinstance(part, ImagePart):
                                image = Image.open(BytesIO(part._blob))
                                result, _ = ocr(np.array(image))
                                if result:
                                    ocr_result = [line[1] for line in result]
                                    resp += "\n".join(ocr_result)
                elif isinstance(block, Table):
                    for row in block.rows:
                        for cell in row.cells:
                            for paragraph in cell.paragraphs:
                                resp += paragraph.text.strip() + "\n"
                b_unit.update(1)
            return resp

        text = doc2text(self.file_path)
        from unstructured.partition.text import partition_text

        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == "__main__":
    # Ad-hoc manual test: OCR a sample .docx and print the parsed documents.
    loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
    docs = loader.load()
    print(docs)
| @@ -0,0 +1,28 @@ | |||
| from typing import List | |||
| from langchain_community.document_loaders.unstructured import UnstructuredFileLoader | |||
| from chatchat.server.file_rag.document_loaders.ocr import get_ocr | |||
class RapidOCRLoader(UnstructuredFileLoader):
    """Image loader: extract text with RapidOCR, then partition it."""

    def _get_elements(self) -> List:
        def _ocr_file(path):
            # Run OCR on the image file and join the recognized lines.
            engine = get_ocr()
            result, _ = engine(path)
            lines = [item[1] for item in result] if result else []
            return "\n".join(lines)

        text = _ocr_file(self.file_path)
        from unstructured.partition.text import partition_text

        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == "__main__":
    # Ad-hoc manual test: OCR a sample image and print the parsed documents.
    loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
    docs = loader.load()
    print(docs)
| @@ -0,0 +1,102 @@ | |||
| from typing import List | |||
| import cv2 | |||
| import numpy as np | |||
| import tqdm | |||
| from langchain_community.document_loaders.unstructured import UnstructuredFileLoader | |||
| from PIL import Image | |||
| from chatchat.configs import PDF_OCR_THRESHOLD | |||
| from chatchat.server.file_rag.document_loaders.ocr import get_ocr | |||
class RapidOCRPDFLoader(UnstructuredFileLoader):
    """PDF loader that combines embedded text with OCR of page images.

    Page text is extracted with PyMuPDF; images large enough relative to the
    page (per ``PDF_OCR_THRESHOLD``) are OCR'd and their text appended.
    """

    def _get_elements(self) -> List:
        def rotate_img(img, angle):
            """Rotate an OpenCV image by ``angle`` degrees (positive =
            counter-clockwise), expanding the canvas so nothing is cropped.
            """
            h, w = img.shape[:2]
            rotate_center = (w / 2, h / 2)
            # Rotation matrix about the image centre, scale factor 1.0.
            M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
            # Bounding box of the rotated image.
            new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
            new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
            # Shift so the rotated content stays centred in the new canvas.
            M[0, 2] += (new_w - w) / 2
            M[1, 2] += (new_h - h) / 2
            rotated_img = cv2.warpAffine(img, M, (new_w, new_h))
            return rotated_img

        def pdf2text(filepath):
            import fitz  # PyMuPDF's fitz module (not the unrelated `fitz` on PyPI)
            import numpy as np

            ocr = get_ocr()
            doc = fitz.open(filepath)
            resp = ""

            b_unit = tqdm.tqdm(
                total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0"
            )
            for i, page in enumerate(doc):
                b_unit.set_description(
                    "RapidOCRPDFLoader context page index: {}".format(i)
                )
                b_unit.refresh()
                text = page.get_text("")
                resp += text + "\n"

                img_list = page.get_image_info(xrefs=True)
                for img in img_list:
                    if xref := img.get("xref"):
                        bbox = img["bbox"]
                        # Skip images smaller than the configured fraction of
                        # the page in either dimension.
                        if (bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[
                            0
                        ] or (bbox[3] - bbox[1]) / (
                            page.rect.height
                        ) < PDF_OCR_THRESHOLD[1]:
                            continue
                        pix = fitz.Pixmap(doc, xref)
                        # (Removed dead `samples = pix.samples` assignment —
                        # the value was never used.)
                        if int(page.rotation) != 0:  # rotate the image back if the page is rotated
                            img_array = np.frombuffer(
                                pix.samples, dtype=np.uint8
                            ).reshape(pix.height, pix.width, -1)
                            tmp_img = Image.fromarray(img_array)
                            ori_img = cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR)
                            rot_img = rotate_img(img=ori_img, angle=360 - page.rotation)
                            img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR)
                        else:
                            img_array = np.frombuffer(
                                pix.samples, dtype=np.uint8
                            ).reshape(pix.height, pix.width, -1)
                        result, _ = ocr(img_array)
                        if result:
                            ocr_result = [line[1] for line in result]
                            resp += "\n".join(ocr_result)

                # advance the progress bar
                b_unit.update(1)
            return resp

        text = pdf2text(self.file_path)
        from unstructured.partition.text import partition_text

        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == "__main__":
    # Ad-hoc manual test: parse a local PDF and print the parsed documents.
    loader = RapidOCRPDFLoader(file_path="/Users/tonysong/Desktop/test.pdf")
    docs = loader.load()
    print(docs)
| @@ -0,0 +1,66 @@ | |||
| from typing import List | |||
| import tqdm | |||
| from langchain_community.document_loaders.unstructured import UnstructuredFileLoader | |||
class RapidOCRPPTLoader(UnstructuredFileLoader):
    """Loader for .pptx files that also OCRs embedded pictures.

    Text frames and tables are read with python-pptx; picture shapes are
    OCR'd with RapidOCR; grouped shapes are handled recursively.
    """

    def _get_elements(self) -> List:
        def ppt2text(filepath):
            from io import BytesIO

            import numpy as np
            from PIL import Image
            from pptx import Presentation
            from rapidocr_onnxruntime import RapidOCR

            ocr = RapidOCR()
            prs = Presentation(filepath)
            resp = ""

            def extract_text(shape):
                # Appends to the enclosing `resp` accumulator.
                nonlocal resp
                if shape.has_text_frame:
                    resp += shape.text.strip() + "\n"
                if shape.has_table:
                    for row in shape.table.rows:
                        for cell in row.cells:
                            for paragraph in cell.text_frame.paragraphs:
                                resp += paragraph.text.strip() + "\n"
                if shape.shape_type == 13:  # 13 = picture shape
                    image = Image.open(BytesIO(shape.image.blob))
                    result, _ = ocr(np.array(image))
                    if result:
                        ocr_result = [line[1] for line in result]
                        resp += "\n".join(ocr_result)
                elif shape.shape_type == 6:  # 6 = group shape, recurse
                    for child_shape in shape.shapes:
                        extract_text(child_shape)

            b_unit = tqdm.tqdm(
                total=len(prs.slides), desc="RapidOCRPPTLoader slide index: 1"
            )
            # iterate over all slides
            for slide_number, slide in enumerate(prs.slides, start=1):
                b_unit.set_description(
                    "RapidOCRPPTLoader slide index: {}".format(slide_number)
                )
                b_unit.refresh()
                sorted_shapes = sorted(
                    slide.shapes, key=lambda x: (x.top, x.left)
                )  # traverse top-to-bottom, left-to-right
                for shape in sorted_shapes:
                    extract_text(shape)
                b_unit.update(1)
            return resp

        text = ppt2text(self.file_path)
        from unstructured.partition.text import partition_text

        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == "__main__":
    # Ad-hoc manual test: OCR a sample .pptx and print the parsed documents.
    loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
    docs = loader.load()
    print(docs)
| @@ -0,0 +1,21 @@ | |||
| from typing import TYPE_CHECKING | |||
| if TYPE_CHECKING: | |||
| try: | |||
| from rapidocr_paddle import RapidOCR | |||
| except ImportError: | |||
| from rapidocr_onnxruntime import RapidOCR | |||
def get_ocr(use_cuda: bool = True) -> "RapidOCR":
    """Return a RapidOCR engine, preferring the paddle backend when installed.

    Falls back to the onnxruntime implementation (CPU-only constructor) when
    rapidocr_paddle is not available.
    """
    try:
        from rapidocr_paddle import RapidOCR

        return RapidOCR(
            det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda
        )
    except ImportError:
        from rapidocr_onnxruntime import RapidOCR

        return RapidOCR()
| @@ -0,0 +1,3 @@ | |||
| from ..retrievers.base import BaseRetrieverService | |||
| from ..retrievers.ensemble import EnsembleRetrieverService | |||
| from ..retrievers.vectorstore import VectorstoreRetrieverService | |||
| @@ -0,0 +1,24 @@ | |||
| from abc import ABCMeta, abstractmethod | |||
| from langchain.vectorstores import VectorStore | |||
class BaseRetrieverService(metaclass=ABCMeta):
    """Abstract interface for knowledge-base retriever services.

    Concrete services are built via the static factory ``from_vectorstore``
    and answer queries through ``get_relevant_documents``.
    """

    def __init__(self, **kwargs):
        # Delegate all construction to the subclass hook.
        self.do_init(**kwargs)

    @abstractmethod
    def do_init(self, **kwargs):
        """Initialize subclass state from constructor keyword arguments."""
        pass

    @staticmethod
    @abstractmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # `float` accepts ints too (PEP 484); the original `int or float`
        # annotation evaluated to just `int` at runtime.
        score_threshold: float,
    ):
        """Build a retriever service over *vectorstore* (static factory)."""
        pass

    @abstractmethod
    def get_relevant_documents(self, query: str):
        """Return documents relevant to *query*."""
        pass
| @@ -0,0 +1,46 @@ | |||
| from langchain.retrievers import EnsembleRetriever | |||
| from langchain.vectorstores import VectorStore | |||
| from langchain_community.retrievers import BM25Retriever | |||
| from langchain_core.retrievers import BaseRetriever | |||
| from .base import BaseRetrieverService | |||
class EnsembleRetrieverService(BaseRetrieverService):
    """Hybrid retriever combining BM25 keyword search with vector search."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        # `vs` is kept for interface parity with sibling retriever services.
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # `float` accepts ints too; the original `int or float` annotation
        # evaluated to just `int` at runtime.
        score_threshold: float,
    ):
        """Build a 50/50 ensemble of BM25 and similarity retrieval.

        Args:
            vectorstore: FAISS-style store; its docstore supplies the corpus
                for the BM25 branch.
            top_k: Number of documents each sub-retriever returns.
            score_threshold: Minimum similarity score for the vector branch.
        """
        faiss_retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": score_threshold, "k": top_k},
        )
        # TODO: 换个不用torch的实现方式
        # from cutword.cutword import Cutter
        import jieba

        # cutter = Cutter()
        docs = list(vectorstore.docstore._dict.values())
        bm25_retriever = BM25Retriever.from_documents(
            docs,
            preprocess_func=jieba.lcut_for_search,
        )
        bm25_retriever.k = top_k
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
        )
        return EnsembleRetrieverService(retriever=ensemble_retriever)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents relevant to *query*."""
        return self.retriever.get_relevant_documents(query)[: self.top_k]
| @@ -0,0 +1,30 @@ | |||
| from langchain.vectorstores import VectorStore | |||
| from langchain_core.retrievers import BaseRetriever | |||
| from .base import BaseRetrieverService | |||
class VectorstoreRetrieverService(BaseRetrieverService):
    """Plain similarity-search retriever over a vector store."""

    def do_init(
        self,
        retriever: BaseRetriever = None,
        top_k: int = 5,
    ):
        # `vs` is kept for interface parity with sibling retriever services.
        self.vs = None
        self.top_k = top_k
        self.retriever = retriever

    @staticmethod
    def from_vectorstore(
        vectorstore: VectorStore,
        top_k: int,
        # `float` accepts ints too; the original `int or float` annotation
        # evaluated to just `int` at runtime.
        score_threshold: float,
    ):
        """Build a similarity-score-threshold retriever over *vectorstore*."""
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": score_threshold, "k": top_k},
        )
        return VectorstoreRetrieverService(retriever=retriever)

    def get_relevant_documents(self, query: str):
        """Return at most ``top_k`` documents relevant to *query*."""
        return self.retriever.get_relevant_documents(query)[: self.top_k]
| @@ -0,0 +1,4 @@ | |||
| from .ali_text_splitter import AliTextSplitter | |||
| from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter | |||
| from .chinese_text_splitter import ChineseTextSplitter | |||
| from .zh_title_enhance import zh_title_enhance | |||
| @@ -0,0 +1,35 @@ | |||
| import re | |||
| from typing import List | |||
| from langchain.text_splitter import CharacterTextSplitter | |||
class AliTextSplitter(CharacterTextSplitter):
    """Semantic splitter backed by DAMO's document-segmentation model.

    Uses ModelScope's ``nlp_bert_document-segmentation_chinese-base`` pipeline
    (paper: https://arxiv.org/abs/2107.09278) to split text at semantic
    boundaries. Requires ``pip install "modelscope[nlp]"``
    (-f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html).
    The model is loaded on CPU to stay friendly to low-end GPUs; change
    ``device`` below to a GPU id if desired.
    """

    def __init__(self, pdf: bool = False, **kwargs):
        """
        Args:
            pdf: When True, normalize whitespace artifacts typical of
                PDF-extracted text before segmentation.
            **kwargs: Forwarded to ``CharacterTextSplitter``.
        """
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        """Split *text* into semantic segments (non-empty strings)."""
        if self.pdf:
            # Collapse PDF artifacts: runs of newlines, stray whitespace,
            # and leftover empty lines. Raw strings fix the invalid escape
            # sequence ("\s") in the original, which warns on Python >= 3.12.
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)
        try:
            from modelscope.pipelines import pipeline
        except ImportError:
            raise ImportError(
                "Could not import modelscope python package. "
                "Please install modelscope with `pip install modelscope`. "
            )
        p = pipeline(
            task="document-segmentation",
            model="damo/nlp_bert_document-segmentation_chinese-base",
            device="cpu",
        )
        result = p(documents=text)
        # The pipeline returns segments separated by "\n\t"; drop empties.
        sent_list = [i for i in result["text"].split("\n\t") if i]
        return sent_list
| @@ -0,0 +1,106 @@ | |||
| import logging | |||
| import re | |||
| from typing import Any, List, Optional | |||
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |||
| logger = logging.getLogger(__name__) | |||
| def _split_text_with_regex_from_end( | |||
| text: str, separator: str, keep_separator: bool | |||
| ) -> List[str]: | |||
| # Now that we have the separator, split the text | |||
| if separator: | |||
| if keep_separator: | |||
| # The parentheses in the pattern keep the delimiters in the result. | |||
| _splits = re.split(f"({separator})", text) | |||
| splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] | |||
| if len(_splits) % 2 == 1: | |||
| splits += _splits[-1:] | |||
| # splits = [_splits[0]] + splits | |||
| else: | |||
| splits = re.split(separator, text) | |||
| else: | |||
| splits = list(text) | |||
| return [s for s in splits if s != ""] | |||
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter tuned for Chinese text.

    Tries separators in order — paragraph breaks, line breaks, Chinese and
    English sentence enders, then clause punctuation — recursing with the
    remaining separators on chunks that are still too long.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        # Raw strings: the escapes (\., \s, \!, \?) are regex syntax, not
        # string escapes — the original non-raw forms are invalid escape
        # sequences and emit SyntaxWarning on Python >= 3.12.
        self._separators = separators or [
            "\n\n",
            "\n",
            "。|!|?",
            r"\.\s|\!\s|\?\s",
            r";|;\s",
            r",|,\s",
        ]
        self._is_separator_regex = is_separator_regex

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use: the first one that occurs in the
        # text; the separators after it are kept for recursive splitting.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break
        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # No weaker separator left: keep the oversized chunk as-is.
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        # Collapse duplicate newlines and drop empty chunks.
        return [
            re.sub(r"\n{2,}", "\n", chunk.strip())
            for chunk in final_chunks
            if chunk.strip() != ""
        ]
if __name__ == "__main__":
    # Ad-hoc manual test: split a long Chinese passage into <=50-char chunks
    # and print them.
    text_splitter = ChineseRecursiveTextSplitter(
        keep_separator=True, is_separator_regex=True, chunk_size=50, chunk_overlap=0
    )
    ls = [
        """中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""",
    ]
    # text = """"""
    for inum, text in enumerate(ls):
        print(inum)
        chunks = text_splitter.split_text(text)
        for chunk in chunks:
            print(chunk)
| @@ -0,0 +1,77 @@ | |||
| import re | |||
| from typing import List | |||
| from langchain.text_splitter import CharacterTextSplitter | |||
class ChineseTextSplitter(CharacterTextSplitter):
    """Sentence-level splitter for Chinese text.

    Splits on Chinese/English sentence-ending punctuation, then further
    breaks sentences longer than ``sentence_size`` on progressively weaker
    punctuation (commas, then whitespace).
    """

    def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs):
        # pdf: normalize whitespace artifacts of PDF-extracted text first.
        # sentence_size: maximum characters per returned piece.
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        """Simple variant: split on sentence enders, keeping each ender
        attached to the sentence that precedes it."""
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub("\s", " ", text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile(
            '([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'
        )  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            # Separator matches are glued back onto the previous sentence.
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:  ## TODO: this logic needs further refinement
        """Split *text* into sentences no longer than ``sentence_size``."""
        if self.pdf:
            # Normalize PDF extraction artifacts before splitting.
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub("\s", " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text)  # single-char sentence enders
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(
            r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text
        )
        # A closing quote only ends a sentence when preceded by a terminator,
        # so the break (\n) is placed after the quote; the rules above were
        # careful to preserve quotes for this reason.
        text = text.rstrip()  # drop any trailing newlines at the end
        # Semicolons, dashes, and English double quotes are deliberately not
        # treated as sentence breaks here; extend the patterns above if needed.
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                # Too long: break on commas first ...
                ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        # ... then on newline runs / double spaces ...
                        ele_ele2 = re.sub(
                            r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])',
                            r"\1\n\2",
                            ele_ele1,
                        )
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                # ... and finally on single spaces.
                                ele_ele3 = re.sub(
                                    '( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
                                )
                                # Splice the finer pieces back in place.
                                # NOTE(review): .index() finds the FIRST equal
                                # element, which could mis-splice when pieces
                                # repeat — confirm inputs make this safe.
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = (
                                    ele2_ls[:ele2_id]
                                    + [i for i in ele_ele3.split("\n") if i]
                                    + ele2_ls[ele2_id + 1 :]
                                )
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = (
                            ele1_ls[:ele_id]
                            + [i for i in ele2_ls if i]
                            + ele1_ls[ele_id + 1 :]
                        )
                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1 :]
        return ls
| @@ -0,0 +1,100 @@ | |||
| import re | |||
| from langchain.docstore.document import Document | |||
def under_non_alpha_ratio(text: str, threshold: float = 0.5):
    """Check whether the proportion of alphabetic characters in ``text`` is
    below ``threshold``.

    This helps prevent text like "-----------BREAK---------" from being tagged
    as a title or narrative text. Spaces are not counted.

    Parameters
    ----------
    text
        The input string to test
    threshold
        Return True when the alpha ratio is strictly below this value
    """
    if not text:
        return False
    # Only non-whitespace characters participate in the ratio.
    alpha_count = sum(1 for char in text if char.strip() and char.isalpha())
    total_count = sum(1 for char in text if char.strip())
    if total_count == 0:
        # All-whitespace input: nothing to measure (original swallowed the
        # ZeroDivisionError with a bare except and returned False).
        return False
    return alpha_count / total_count < threshold
def is_possible_title(
    text: str,
    title_max_word_length: int = 20,
    non_alpha_threshold: float = 0.5,
) -> bool:
    """Check whether ``text`` passes all heuristics for a valid section title.

    Parameters
    ----------
    text
        The input text to check
    title_max_word_length
        The maximum number of characters a title can contain
    non_alpha_threshold
        The minimum alpha-character ratio the text needs to be considered a title
    """
    # Empty text can never be a title.
    if len(text) == 0:
        print("Not a title. Text is empty.")
        return False
    # Text ending in punctuation is not a title.
    if re.search(r"[^\w\s]\Z", text) is not None:
        return False
    # Titles must not exceed the configured length. Plain character count is
    # used instead of word tokenizing because it is cheaper and tokenization
    # doesn't add much value for a length check.
    if len(text) > title_max_word_length:
        return False
    # Reject text dominated by non-alphabetic characters (digits, symbols).
    if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
        return False
    # Prevent flagging salutations like "To My Dearest Friends," as titles.
    if text.endswith((",", ".", ",", "。")):
        return False
    if text.isnumeric():
        print(f"Not a title. Text is all numeric:\n\n{text}")  # type: ignore
        return False
    # Require at least one digit within the first 5 characters, e.g. the
    # chapter/section numbering in "第1章" (original misnamed this count
    # "alpha_in_text_5" although it counts numerics).
    prefix = text if len(text) < 5 else text[:5]
    digits_in_prefix = sum(char.isnumeric() for char in prefix)
    if not digits_in_prefix:
        return False
    return True
def zh_title_enhance(docs: "list[Document]") -> "list[Document]":
    """Prefix body documents with the most recently detected title.

    Walks ``docs`` in order: a document whose content looks like a title is
    tagged with ``metadata["category"] = "cn_Title"``; every subsequent
    non-title document gets a reference to that title prepended to its
    content. (The original annotated the parameter as a single ``Document``,
    but a sequence is both iterated and ``len()``-checked.)

    NOTE(review): on empty input this prints and implicitly returns None;
    behavior preserved in case callers rely on it.
    """
    title = None
    if len(docs) > 0:
        for doc in docs:
            if is_possible_title(doc.page_content):
                doc.metadata["category"] = "cn_Title"
                title = doc.page_content
            elif title:
                doc.page_content = f"下文与({title})有关。{doc.page_content}"
        return docs
    else:
        print("文件不存在")
| @@ -0,0 +1,14 @@ | |||
| from .retrievers import ( | |||
| BaseRetrieverService, | |||
| EnsembleRetrieverService, | |||
| VectorstoreRetrieverService, | |||
| ) | |||
# Registry mapping retriever type name -> retriever service class.
# NOTE(review): "Retrivals" is a typo for "Retrievals"; kept as-is because the
# name may be imported by other modules.
Retrivals = {
    "vectorstore": VectorstoreRetrieverService,
    "ensemble": EnsembleRetrieverService,
}
def get_Retriever(type: str = "vectorstore") -> "type[BaseRetrieverService]":
    """Look up a retriever service class by its registry key.

    Returns the class itself, not an instance (the original annotation
    claimed an instance). Raises KeyError for an unknown ``type``.
    """
    return Retrivals[type]
| @@ -0,0 +1,72 @@ | |||
| import urllib | |||
| from fastapi import Body | |||
| from .db.repository.knowledge_base_repository import list_kbs_from_db | |||
| from .kb_service.base import KBServiceFactory | |||
| from .utils import validate_kb_name | |||
| from ..utils.system_utils import BaseResponse, ListResponse, logger | |||
def list_kbs():
    """Return every knowledge base recorded in the database."""
    kb_names = list_kbs_from_db()
    return ListResponse(data=kb_names)
def create_kb(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    vector_store_type: str = Body("faiss"),
    kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
    embed_model: str = Body(..., examples=["bce-embedding-base_v1", "bge-large-zh-v1.5"]),
) -> BaseResponse:
    """Create a new knowledge base backed by the given vector store and embedding model."""
    # Guard against malicious / path-traversal style names.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    if knowledge_base_name is None or knowledge_base_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")

    # Refuse to clobber an existing knowledge base of the same name.
    existing = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if existing is not None:
        return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")

    service = KBServiceFactory.get_service(
        knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info
    )
    try:
        service.create_kb()
    except Exception as e:
        msg = f"创建知识库出错: {e}"
        logger.error(f"{e.__class__.__name__}: {msg}", exc_info=e)
        return BaseResponse(code=500, msg=msg)

    return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
def delete_kb(
    knowledge_base_name: str = Body(..., examples=["samples"]),
) -> BaseResponse:
    """Drop a knowledge base: clear its vector store, then remove it."""
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)

    service = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if service is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    try:
        service.clear_vs()
        dropped = service.drop_kb()
        if dropped:
            return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
    except Exception as e:
        msg = f"删除知识库时出现意外: {e}"
        logger.error(f"{e.__class__.__name__}: {msg}", exc_info=e)
        return BaseResponse(code=500, msg=msg)

    # drop_kb() reported failure without raising.
    return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
| @@ -0,0 +1,94 @@ | |||
| import threading | |||
| from collections import OrderedDict | |||
| from contextlib import contextmanager | |||
| from typing import Any, List, Tuple, Union | |||
| from langchain_community.vectorstores import FAISS | |||
class ThreadSafeObject:
    """Wraps an arbitrary object with a reentrant lock plus a loading event.

    Used by CachePool to hand out exclusive access to cached resources
    (e.g. vector stores) across threads: writers hold ``acquire`` while
    loading; readers block in ``wait_for_loading`` until ``finish_loading``.
    """

    def __init__(
        self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None
    ):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        # Set once the wrapped object has finished loading.
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        """Context manager yielding the wrapped object while holding the lock.

        While held, the entry is refreshed to the MRU end of the owning pool.
        ``owner`` and ``msg`` are informational only. (The original annotated
        the return as FAISS, but any wrapped object is yielded.)
        """
        owner = owner or f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
            if self._pool is not None:
                # Mark this entry as most-recently-used in the pool's LRU order.
                self._pool._cache.move_to_end(self.key)
            yield self._obj
        finally:
            self._lock.release()

    def start_loading(self):
        self._loaded.clear()

    def finish_loading(self):
        self._loaded.set()

    def wait_for_loading(self):
        self._loaded.wait()

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
class CachePool:
    """A small thread-aware LRU pool of ThreadSafeObject entries."""

    def __init__(self, cache_num: int = -1):
        # cache_num <= 0 means "unbounded".
        self._cache_num = cache_num
        self._cache = OrderedDict()
        # Coarse lock used by subclasses to serialize cache-miss handling.
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        return [*self._cache]

    def _check_count(self):
        # Evict least-recently-used entries until within the size budget.
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        entry = self._cache.get(key)
        if entry:
            # Block until the entry has finished loading before handing it out.
            entry.wait_for_loading()
            return entry

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        if key is None:
            # No key: evict the least-recently-used entry.
            return self._cache.popitem(last=False)
        return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        entry = self.get(key)
        if entry is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if isinstance(entry, ThreadSafeObject):
            self._cache.move_to_end(key)
            return entry.acquire(owner=owner, msg=msg)
        return entry
| @@ -0,0 +1,211 @@ | |||
| import logging | |||
| import os | |||
| from langchain_community.docstore.in_memory import InMemoryDocstore | |||
| from langchain.schema import Document | |||
| from ...configs import CACHED_MEMO_VS_NUM, CACHED_VS_NUM, DEFAULT_EMBEDDING_MODEL | |||
| from .base import * | |||
| from ..utils import get_vs_path | |||
| from ...utils.system_utils import get_Embeddings | |||
logger = logging.getLogger()  # root logger; used below for load/save messages
# patch FAISS to include doc id in Document.metadata
def _new_ds_search(self, search: str) -> Union[str, Document]:
    # Drop-in replacement for InMemoryDocstore.search: identical lookup, but
    # stamps the docstore id into the returned Document's metadata so callers
    # can reference the vector-store entry.
    if search not in self._dict:
        return f"ID {search} not found."
    else:
        doc = self._dict[search]
        if isinstance(doc, Document):
            doc.metadata["id"] = search
        return doc
InMemoryDocstore.search = _new_ds_search
class ThreadSafeFaiss(ThreadSafeObject):
    """ThreadSafeObject specialization whose wrapped object is a FAISS store."""

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        """Number of documents currently held in the wrapped docstore."""
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the vector store to ``path`` while holding the lock."""
        with self.acquire():
            if create_path and not os.path.isdir(path):
                os.makedirs(path)
            result = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return result

    def clear(self):
        """Delete every document from the wrapped store while holding the lock."""
        deleted = []
        with self.acquire():
            doc_ids = list(self._obj.docstore._dict.keys())
            if doc_ids:
                deleted = self._obj.delete(doc_ids)
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return deleted
class _FaissPool(CachePool):
    """CachePool of FAISS vector stores with create/save/unload helpers."""

    def _empty_vector_store(self, embed_model: str) -> FAISS:
        # Build an empty FAISS store: FAISS.from_documents needs at least one
        # document, so seed with a dummy doc and immediately delete it.
        embeddings = get_Embeddings(embed_model=embed_model)
        doc = Document(page_content="init", metadata={})
        vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
        ids = list(vector_store.docstore._dict.keys())
        vector_store.delete(ids)
        return vector_store

    def new_vector_store(
        self,
        kb_name: str,
        embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> FAISS:
        """Create an empty vector store for a knowledge base.

        ``kb_name`` is accepted for interface compatibility but unused
        (it was unused in the original, too).
        """
        return self._empty_vector_store(embed_model)

    def new_temp_vector_store(
        self,
        embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> FAISS:
        """Create an empty temporary (in-memory) vector store.

        Originally a verbatim copy of ``new_vector_store``; both now share
        ``_empty_vector_store``.
        """
        return self._empty_vector_store(embed_model)

    def save_vector_store(self, kb_name: str, path: str = None):
        """Persist the cached store for ``kb_name`` to ``path``, if loaded."""
        if cache := self.get(kb_name):
            return cache.save(path)

    def unload_vector_store(self, kb_name: str):
        """Drop the cached store for ``kb_name`` from the pool."""
        if cache := self.get(kb_name):
            self.pop(kb_name)
            logger.info(f"成功释放向量库:{kb_name}")
class KBFaissPool(_FaissPool):
    # Pool of per-knowledge-base FAISS stores, loaded from / persisted to disk.
    def load_vector_store(
        self,
        kb_name: str,
        vector_name: str = None,
        create: bool = True,
        embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> ThreadSafeFaiss:
        """Return the (possibly freshly loaded) vector store for ``kb_name``.

        Thread-safe: ``atomic`` serializes the cache-miss check, and the
        per-item lock is held during loading so concurrent callers block in
        ``wait_for_loading`` until the store is ready. If neither an on-disk
        index exists nor ``create`` is set, raises RuntimeError.
        """
        self.atomic.acquire()
        # Track whether we still hold `atomic` so the except-block knows
        # whether releasing is still our responsibility.
        locked = True
        vector_name = vector_name or embed_model
        cache = self.get((kb_name, vector_name))  # a tuple key beats string concatenation
        try:
            if cache is None:
                item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
                self.set((kb_name, vector_name), item)
                with item.acquire(msg="初始化"):
                    # Safe to release the pool lock now: the placeholder is in
                    # the cache and its own lock is held while we load.
                    self.atomic.release()
                    locked = False
                    logger.info(
                        f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk."
                    )
                    vs_path = get_vs_path(kb_name, vector_name)
                    if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                        # Existing index on disk: deserialize it.
                        embeddings = get_Embeddings(embed_model=embed_model)
                        vector_store = FAISS.load_local(
                            vs_path,
                            embeddings,
                            normalize_L2=True,
                            allow_dangerous_deserialization=True,
                        )
                    elif create:
                        # create an empty vector store
                        if not os.path.exists(vs_path):
                            os.makedirs(vs_path)
                        vector_store = self.new_vector_store(
                            kb_name=kb_name, embed_model=embed_model
                        )
                        vector_store.save_local(vs_path)
                    else:
                        raise RuntimeError(f"knowledge base {kb_name} not exist.")
                    item.obj = vector_store
                    item.finish_loading()
            else:
                self.atomic.release()
                locked = False
        except Exception as e:
            if locked:  # we don't know exception raised before or after atomic.release
                self.atomic.release()
            logger.error(e, exc_info=True)
            raise RuntimeError(f"向量库 {kb_name} 加载失败。")
        return self.get((kb_name, vector_name))
class MemoFaissPool(_FaissPool):
    r"""
    Cache pool for temporary (in-memory) vector stores.
    """

    def load_vector_store(
        self,
        kb_name: str,
        embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> ThreadSafeFaiss:
        """Return the in-memory store for ``kb_name``, creating it on first use."""
        # Serialize the cache-miss check. NOTE(review): unlike KBFaissPool,
        # there is no try/except here, so a failure while building the store
        # would leave `atomic` held -- confirm this path cannot fail in practice.
        self.atomic.acquire()
        cache = self.get(kb_name)
        if cache is None:
            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            # Hold the item's own lock during creation; concurrent callers
            # block in get() -> wait_for_loading() until finish_loading().
            with item.acquire(msg="初始化"):
                self.atomic.release()
                logger.info(f"loading vector store in '{kb_name}' to memory.")
                # create an empty vector store
                vector_store = self.new_temp_vector_store(embed_model=embed_model)
                item.obj = vector_store
                item.finish_loading()
        else:
            self.atomic.release()
        return self.get(kb_name)
# Module-level singleton pools shared by the knowledge-base service layer.
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
| # | |||
| # | |||
| # if __name__ == "__main__": | |||
| # import time, random | |||
| # from pprint import pprint | |||
| # | |||
| # kb_names = ["vs1", "vs2", "vs3"] | |||
| # # for name in kb_names: | |||
| # # memo_faiss_pool.load_vector_store(name) | |||
| # | |||
| # def worker(vs_name: str, name: str): | |||
| # vs_name = "samples" | |||
| # time.sleep(random.randint(1, 5)) | |||
| # embeddings = load_local_embeddings() | |||
| # r = random.randint(1, 3) | |||
| # | |||
| # with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs: | |||
| # if r == 1: # add docs | |||
| # ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings) | |||
| # pprint(ids) | |||
| # elif r == 2: # search docs | |||
| # docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0) | |||
| # pprint(docs) | |||
| # if r == 3: # delete docs | |||
| # logger.warning(f"清除 {vs_name} by {name}") | |||
| # kb_faiss_pool.get(vs_name).clear() | |||
| # | |||
| # threads = [] | |||
| # for n in range(1, 30): | |||
| # t = threading.Thread(target=worker, | |||
| # kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"}, | |||
| # daemon=True) | |||
| # t.start() | |||
| # threads.append(t) | |||
| # | |||
| # for t in threads: | |||
| # t.join() | |||
| @@ -0,0 +1,451 @@ | |||
| import json | |||
| import os | |||
| import urllib | |||
| from typing import Dict, List | |||
| from fastapi import Body, File, Form, Query, UploadFile | |||
| from fastapi.responses import FileResponse | |||
| from langchain.docstore.document import Document | |||
| from sse_starlette import EventSourceResponse | |||
| from ..configs import ( | |||
| CHUNK_SIZE, | |||
| DEFAULT_VS_TYPE, | |||
| OVERLAP_SIZE, | |||
| SCORE_THRESHOLD, | |||
| VECTOR_SEARCH_TOP_K, | |||
| ZH_TITLE_ENHANCE, | |||
| ) | |||
| from .db.repository.knowledge_file_repository import get_file_detail | |||
| from .kb_service.base import ( | |||
| KBServiceFactory, | |||
| get_kb_file_details, | |||
| ) | |||
| from .model.kb_document_model import DocumentWithVSId | |||
| from .utils import ( | |||
| KnowledgeFile, | |||
| files2docs_in_thread, | |||
| get_file_path, | |||
| list_files_from_folder, | |||
| validate_kb_name, | |||
| ) | |||
| from ..utils.system_utils import ( | |||
| BaseResponse, | |||
| ListResponse, | |||
| run_in_thread_pool, | |||
| ) | |||
def search_docs(
    query: str = Body("", description="用户输入", examples=["你好"]),
    knowledge_base_name: str = Body(
        ..., description="知识库名称", examples=["samples"]
    ),
    top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
    score_threshold: float = Body(
        SCORE_THRESHOLD,
        description="知识库匹配相关度阈值,取值范围在0-1之间,"
        "SCORE越小,相关度越高,"
        "取到1相当于不筛选,建议设置在0.5左右",
        ge=0.0,
        le=1.0,
    ),
    file_name: str = Body("", description="文件名称,支持 sql 通配符"),
    metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[Dict]:
    """Search a knowledge base.

    With a non-empty ``query``, runs a vector similarity search; otherwise,
    when ``file_name`` or ``metadata`` is given, lists matching stored docs
    (stripping raw embedding vectors from their metadata). Returns a list of
    ``DocumentWithVSId`` dicts; empty when the knowledge base is unknown.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    data = []
    if kb is not None:
        if query:
            docs = kb.search_docs(query, top_k, score_threshold)
            # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
            data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
        elif file_name or metadata:
            data = kb.list_docs(file_name=file_name, metadata=metadata)
            for d in data:
                if "vector" in d.metadata:
                    # Raw embedding vectors are large and not JSON-friendly.
                    del d.metadata["vector"]
    return [x.dict() for x in data]
def list_files(knowledge_base_name: str) -> ListResponse:
    """List the names of all document files stored in a knowledge base."""
    if not validate_kb_name(knowledge_base_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(
            code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]
        )
    file_details = get_kb_file_details(knowledge_base_name)
    return ListResponse(data=file_details)
def _save_files_in_thread(
    files: List[UploadFile], knowledge_base_name: str, override: bool
):
    """
    Save uploaded files into the knowledge-base directory using a thread pool.

    Yields one result dict per file:
    {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
    """

    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
        """
        Save a single uploaded file. Never raises; returns a status dict.
        """
        filename = file.filename
        # Build `data` before anything that can raise, so the except-block can
        # always reference it (originally it could be unbound -> NameError).
        data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
        try:
            file_path = get_file_path(
                knowledge_base_name=knowledge_base_name, doc_name=filename
            )
            file_content = file.file.read()  # read the uploaded file's content
            if (
                os.path.isfile(file_path)
                and not override
                and os.path.getsize(file_path) == len(file_content)
            ):
                # Same name, same size, and overwrite not requested: skip.
                # (Original messages contained a literal "(unknown)" where the
                # filename belongs -- restored the placeholder.)
                file_status = f"文件 {filename} 已存在。"
                return dict(code=404, msg=file_status, data=data)
            if not os.path.isdir(os.path.dirname(file_path)):
                os.makedirs(os.path.dirname(file_path))
            with open(file_path, "wb") as f:
                f.write(file_content)
            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
        except Exception as e:
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
            return dict(code=500, msg=msg, data=data)

    params = [
        {"file": file, "knowledge_base_name": knowledge_base_name, "override": override}
        for file in files
    ]
    for result in run_in_thread_pool(save_file, params=params):
        yield result
| # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), | |||
| # knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), | |||
| # override: bool = Form(False, description="覆盖已有文件"), | |||
| # save: bool = Form(True, description="是否将文件保存到知识库目录")): | |||
| # def save_files(files, knowledge_base_name, override): | |||
| # for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override): | |||
| # yield json.dumps(result, ensure_ascii=False) | |||
| # def files_to_docs(files): | |||
| # for result in files2docs_in_thread(files): | |||
| # yield json.dumps(result, ensure_ascii=False) | |||
def upload_docs(
    files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
    knowledge_base_name: str = Form(
        ..., description="知识库名称", examples=["samples"]
    ),
    override: bool = Form(False, description="覆盖已有文件"),
    to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
    chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
    chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
    zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
    docs: str = Form("", description="自定义的docs,需要转为json字符串"),
    not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
    API endpoint: upload files and/or vectorize them.

    Files are first written to the knowledge-base folder; when
    ``to_vector_store`` is true they are then chunked and embedded via
    ``update_docs``. Per-file failures are collected into
    ``data["failed_files"]`` instead of aborting the request.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    # ``docs`` arrives as a JSON string mapping file_name -> custom documents.
    docs = json.loads(docs) if docs else {}
    failed_files = {}
    file_names = list(docs.keys())
    # Save the uploaded files to disk first.
    for result in _save_files_in_thread(
        files, knowledge_base_name=knowledge_base_name, override=override
    ):
        filename = result["data"]["file_name"]
        if result["code"] != 200:
            failed_files[filename] = result["msg"]
        if filename not in file_names:
            file_names.append(filename)
    # Vectorize the saved files.
    if to_vector_store:
        # NOTE(review): ``docs`` is passed as an already-parsed dict although
        # update_docs declares it as a JSON string -- verify update_docs
        # tolerates dict input.
        result = update_docs(
            knowledge_base_name=knowledge_base_name,
            file_names=file_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failed_files.update(result.data["failed_files"])
        if not not_refresh_vs_cache:
            # Persist the vector store once after the whole batch (FAISS only).
            kb.save_vector_store()
    return BaseResponse(
        code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}
    )
def delete_docs(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
    delete_content: bool = Body(False),
    not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """Remove documents from a knowledge base.

    ``delete_content`` also removes the underlying file from disk. Per-file
    errors are collected in ``data["failed_files"]``; the endpoint returns
    code 200 once the loop finishes.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    failed_files = {}
    for file_name in file_names:
        if not kb.exist_doc(file_name):
            # NOTE(review): the missing file is recorded but deletion is still
            # attempted below -- confirm this fall-through is intentional.
            failed_files[file_name] = f"未找到文件 {file_name}"
        try:
            kb_file = KnowledgeFile(
                filename=file_name, knowledge_base_name=knowledge_base_name
            )
            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"{file_name} 文件删除失败,错误信息:{e}"
            failed_files[file_name] = msg
    if not not_refresh_vs_cache:
        # Persist the vector store once after the whole batch (FAISS only).
        kb.save_vector_store()
    return BaseResponse(
        code=200, msg=f"文件删除完成", data={"failed_files": failed_files}
    )
def update_info(
    knowledge_base_name: str = Body(
        ..., description="知识库名称", examples=["samples"]
    ),
    kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
    """Update the human-readable description of a knowledge base."""
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    service = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if service is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    service.update_info(kb_info)
    return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})
def update_docs(
    knowledge_base_name: str = Body(
        ..., description="知识库名称", examples=["samples"]
    ),
    file_names: List[str] = Body(
        ..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]
    ),
    chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
    chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
    zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
    override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
    docs: str = Body("", description="自定义的docs,需要转为json字符串"),
    not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """
    Update knowledge-base documents: re-chunk and re-vectorize the named
    files, plus vectorize any caller-supplied custom docs.

    Per-file failures are collected into ``data["failed_files"]``.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    failed_files = {}
    kb_files = []
    # ``docs`` arrives as a JSON string over HTTP, but upload_docs calls this
    # function directly with an already-parsed dict; json.loads on a dict
    # would raise TypeError, so accept both forms.
    if isinstance(docs, str):
        docs = json.loads(docs) if docs else {}
    else:
        docs = docs or {}
    # Build the list of files whose docs must be (re)loaded from disk.
    for file_name in file_names:
        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
        # Files previously stored with custom docs are skipped unless the
        # caller explicitly asks to override them.
        if file_detail.get("custom_docs") and not override_custom_docs:
            continue
        if file_name not in docs:
            try:
                kb_files.append(
                    KnowledgeFile(
                        filename=file_name, knowledge_base_name=knowledge_base_name
                    )
                )
            except Exception as e:
                msg = f"加载文档 {file_name} 时出错:{e}"
                failed_files[file_name] = msg
    # Generate docs from the files and vectorize them. This leverages
    # KnowledgeFile's caching: Documents are loaded in worker threads and
    # handed back to a fresh KnowledgeFile here.
    for status, result in files2docs_in_thread(
        kb_files,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        zh_title_enhance=zh_title_enhance,
    ):
        if status:
            kb_name, file_name, new_docs = result
            kb_file = KnowledgeFile(
                filename=file_name, knowledge_base_name=knowledge_base_name
            )
            kb_file.splited_docs = new_docs
            kb.update_doc(kb_file, not_refresh_vs_cache=True)
        else:
            kb_name, file_name, error = result
            failed_files[file_name] = error
    # Vectorize the caller-supplied custom docs.
    for file_name, v in docs.items():
        try:
            v = [x if isinstance(x, Document) else Document(**x) for x in v]
            kb_file = KnowledgeFile(
                filename=file_name, knowledge_base_name=knowledge_base_name
            )
            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"为 {file_name} 添加自定义docs时出错:{e}"
            failed_files[file_name] = msg
    if not not_refresh_vs_cache:
        # Persist the vector store once after the whole batch (FAISS only).
        kb.save_vector_store()
    return BaseResponse(
        code=200, msg=f"更新文档完成", data={"failed_files": failed_files}
    )
def download_doc(
    knowledge_base_name: str = Query(
        ..., description="知识库名称", examples=["samples"]
    ),
    file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
    preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
):
    """
    Download (or preview inline) a document stored in a knowledge base.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    # "inline" lets the browser render the file; None falls back to attachment.
    if preview:
        content_disposition_type = "inline"
    else:
        content_disposition_type = None
    try:
        kb_file = KnowledgeFile(
            filename=file_name, knowledge_base_name=knowledge_base_name
        )
        if os.path.exists(kb_file.filepath):
            return FileResponse(
                path=kb_file.filepath,
                filename=kb_file.filename,
                media_type="multipart/form-data",
                content_disposition_type=content_disposition_type,
            )
    except Exception as e:
        # Use file_name here: the original referenced kb_file.filename, which
        # is unbound (NameError) when the KnowledgeFile constructor raises.
        msg = f"{file_name} 读取文件失败,错误信息是:{e}"
        return BaseResponse(code=500, msg=msg)
    return BaseResponse(code=500, msg=f"{file_name} 读取文件失败")
def recreate_vector_store(
    knowledge_base_name: str = Body(..., examples=["samples"]),
    allow_empty_kb: bool = Body(True),
    vs_type: str = Body(DEFAULT_VS_TYPE),
    embed_model: str = Body(...),
    chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
    chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
    zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
    not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
):
    """
    Recreate the vector store from the content folder, streaming progress as
    server-sent events.

    This is useful when users copy files into the content folder directly
    instead of uploading them through the network.
    By default, get_service_by_name only returns knowledge bases that are in
    info.db and have document files in them; set allow_empty_kb to True to
    also apply this to empty knowledge bases that are not in info.db or have
    no documents.
    """

    def output():
        # Generator consumed by EventSourceResponse: one event per file.
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # NOTE(review): this branch yields a raw dict while later branches
            # yield json.dumps strings -- confirm the SSE layer accepts both.
            yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
        else:
            error_msg = (
                f"could not recreate vector store because failed to access embed model."
            )
            if not kb.check_embed_model(error_msg):
                yield {"code": 404, "msg": error_msg}
            else:
                if kb.exists():
                    # Drop the old vectors before rebuilding from scratch.
                    kb.clear_vs()
                kb.create_kb()
                files = list_files_from_folder(knowledge_base_name)
                kb_files = [(file, knowledge_base_name) for file in files]
                i = 0
                for status, result in files2docs_in_thread(
                    kb_files,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    zh_title_enhance=zh_title_enhance,
                ):
                    if status:
                        kb_name, file_name, docs = result
                        kb_file = KnowledgeFile(
                            filename=file_name, knowledge_base_name=kb_name
                        )
                        kb_file.splited_docs = docs
                        # Progress is reported before add_doc so the client
                        # sees the file currently being processed.
                        yield json.dumps(
                            {
                                "code": 200,
                                "msg": f"({i + 1} / {len(files)}): {file_name}",
                                "total": len(files),
                                "finished": i + 1,
                                "doc": file_name,
                            },
                            ensure_ascii=False,
                        )
                        kb.add_doc(kb_file, not_refresh_vs_cache=True)
                    else:
                        kb_name, file_name, error = result
                        msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
                        yield json.dumps(
                            {
                                "code": 500,
                                "msg": msg,
                            }
                        )
                    i += 1
                if not not_refresh_vs_cache:
                    # Persist the vector store once after the whole batch (FAISS only).
                    kb.save_vector_store()

    return EventSourceResponse(output())
| @@ -0,0 +1,500 @@ | |||
| import operator | |||
| import os | |||
| from abc import ABC, abstractmethod | |||
| from pathlib import Path | |||
| from typing import Dict, List, Optional, Tuple, Union | |||
| from langchain.docstore.document import Document | |||
| from ...configs import ( | |||
| KB_INFO, | |||
| SCORE_THRESHOLD, | |||
| VECTOR_SEARCH_TOP_K, | |||
| kbs_config, | |||
| DEFAULT_EMBEDDING_MODEL | |||
| ) | |||
| from ..db.models.knowledge_base_model import KnowledgeBaseSchema | |||
| from ..db.repository.knowledge_base_repository import ( | |||
| add_kb_to_db, | |||
| delete_kb_from_db, | |||
| kb_exists, | |||
| list_kbs_from_db, | |||
| load_kb_from_db, | |||
| ) | |||
| from ..db.repository.knowledge_file_repository import ( | |||
| add_file_to_db, | |||
| count_files_from_db, | |||
| delete_file_from_db, | |||
| delete_files_from_db, | |||
| file_exists_in_db, | |||
| get_file_detail, | |||
| list_docs_from_db, | |||
| list_files_from_db, | |||
| ) | |||
| from ..model.kb_document_model import DocumentWithVSId | |||
| from ..utils import ( | |||
| KnowledgeFile, | |||
| get_doc_path, | |||
| get_kb_path, | |||
| list_files_from_folder, | |||
| list_kbs_from_folder, | |||
| ) | |||
| from ...utils.system_utils import check_embed_model as _check_embed_model | |||
class SupportedVSType:
    """String identifiers for the vector-store backends known to KBServiceFactory.

    These values are compared against the ``vector_store_type`` persisted in
    info.db and against keys of ``kbs_config``.
    """

    FAISS = "faiss"      # local FAISS index persisted on disk
    MILVUS = "milvus"    # Milvus server collection
    DEFAULT = "default"  # fallback; factory currently maps this to Milvus
    ES = "es"            # Elasticsearch index
class KBService(ABC):
    """Abstract base class for a knowledge-base backend (FAISS / Milvus / ES / ...).

    Concrete subclasses implement the ``do_*`` hooks; this base class keeps the
    relational metadata database (knowledge bases, files, per-doc info) in sync
    with the subclass's vector store.
    """

    def __init__(
        self,
        knowledge_base_name: str,
        kb_info: str = None,
        embed_model: str = DEFAULT_EMBEDDING_MODEL,
    ):
        # Name doubles as the primary key in the metadata DB and the folder name.
        self.kb_name = knowledge_base_name
        # Human-readable description: explicit arg, then configured KB_INFO,
        # then a generated Chinese default.
        self.kb_info = kb_info or KB_INFO.get(
            knowledge_base_name, f"关于{knowledge_base_name}的知识库"
        )
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        # Subclass hook: set up vector-store connection/state.
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        """
        Persist the vector store: FAISS saves to disk, Milvus to its own DB.
        No-op by default; subclasses override as needed. (PGVector unsupported.)
        """
        pass

    def check_embed_model(self, error_msg: str) -> bool:
        """Return True if the configured embedding model is reachable.

        NOTE(review): ``error_msg`` is accepted but never used/logged here;
        callers only reuse it for their own responses — confirm intent.
        """
        if not _check_embed_model(self.embed_model):
            return False
        else:
            return True

    def create_kb(self):
        """
        Create the knowledge base: ensure its content folder exists, register it
        in the metadata DB, then run the subclass's creation hook.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        status = add_kb_to_db(
            self.kb_name, self.kb_info, self.vs_type(), self.embed_model
        )
        if status:
            self.do_create_kb()
        return status

    def clear_vs(self):
        """
        Delete all content from the vector store and drop all file rows for
        this KB from the metadata DB.
        """
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """
        Delete the knowledge base itself (vector store + DB registration).
        """
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
        """
        Add a file to the knowledge base.

        If ``docs`` is provided, the texts are used as-is (no re-splitting) and
        the DB row is flagged custom_docs=True.

        NOTE(review): mutable default ``docs=[]`` — it is not mutated here, but
        a ``None`` default would be safer.
        """
        if not self.check_embed_model(
            f"could not add docs because failed to access embed model."
        ):
            return False
        if docs:
            custom_docs = True
        else:
            docs = kb_file.file2text()
            custom_docs = False
        if docs:
            # Rewrite metadata["source"] as a path relative to this KB's doc dir,
            # so lookups are stable regardless of where the KB root lives.
            for doc in docs:
                try:
                    doc.metadata.setdefault("source", kb_file.filename)
                    source = doc.metadata.get("source", "")
                    if os.path.isabs(source):
                        rel_path = Path(source).relative_to(self.doc_path)
                        doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
                except Exception as e:
                    print(
                        f"cannot convert absolute path ({source}) to relative path. error is : {e}"
                    )
            # Remove any previous version of this file before re-adding.
            self.delete_doc(kb_file)
            doc_infos = self.do_add_doc(docs, **kwargs)
            status = add_file_to_db(
                kb_file,
                custom_docs=custom_docs,
                docs_count=len(docs),
                doc_infos=doc_infos,
            )
        else:
            status = False
        return status

    def delete_doc(
        self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs
    ):
        """
        Remove a file from the knowledge base (vector store + DB row); when
        ``delete_content`` is True, also delete the file from disk.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """
        Update the knowledge base description (upserts the DB row).
        """
        self.kb_info = kb_info
        status = add_kb_to_db(
            self.kb_name, self.kb_info, self.vs_type(), self.embed_model
        )
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
        """
        Re-vectorize a file from the content folder (delete + add).
        If ``docs`` is given, those custom docs are used and the DB row is
        flagged custom_docs=True.
        """
        if not self.check_embed_model(
            f"could not update docs because failed to access embed model."
        ):
            return False
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        # True if the file is registered in the metadata DB.
        return file_exists_in_db(
            KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name)
        )

    def list_files(self):
        # File names registered in the metadata DB for this KB.
        return list_files_from_db(self.kb_name)

    def count_files(self):
        return count_files_from_db(self.kb_name)

    def search_docs(
        self,
        query: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        score_threshold: float = SCORE_THRESHOLD,
    ) -> List[Document]:
        """Similarity search; returns [] if the embed model is unreachable."""
        if not self.check_embed_model(
            f"could not search docs because failed to access embed model."
        ):
            return []
        docs = self.do_search(query, top_k, score_threshold)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        # Default: id-based retrieval unsupported; subclasses override.
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        """
        Update documents keyed by vector-store id: {doc_id: Document, ...}.
        Entries whose value is None or whose page_content is blank are deleted
        (all ids are deleted first, then non-empty docs are re-added).
        """
        if not self.check_embed_model(
            f"could not update docs because failed to access embed model."
        ):
            return False
        self.del_doc_by_ids(list(docs.keys()))
        pending_docs = []
        ids = []
        for _id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            ids.append(_id)
            pending_docs.append(doc)
        self.do_add_doc(docs=pending_docs, ids=ids)
        return True

    def list_docs(
        self, file_name: str = None, metadata: Dict = {}
    ) -> List[DocumentWithVSId]:
        """
        Retrieve Documents filtered by file_name and/or metadata, pairing each
        DB row with the full document fetched from the vector store.

        NOTE(review): ``self.get_doc_by_ids([x["id"]])[0]`` raises IndexError
        when the subclass returns an empty list for a stale id — confirm.
        """
        doc_infos = list_docs_from_db(
            kb_name=self.kb_name, file_name=file_name, metadata=metadata
        )
        docs = []
        for x in doc_infos:
            doc_info = self.get_doc_by_ids([x["id"]])[0]
            if doc_info is not None:
                # Vector store still has the doc: wrap it with its id.
                doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
                docs.append(doc_with_id)
            else:
                # Doc missing from the vector store: skip it.
                pass
        return docs

    def get_relative_source_path(self, filepath: str):
        """
        Convert a file path to a path relative to this KB's doc dir, so that
        stored "source" metadata is consistent at query time.

        NOTE(review): if ``filepath`` is not absolute, or ``relative_to``
        fails, ``relative_path`` remains a plain str and ``as_posix`` would
        raise AttributeError — confirm callers always pass absolute paths
        under ``doc_path``.
        """
        relative_path = filepath
        if os.path.isabs(relative_path):
            try:
                relative_path = Path(filepath).relative_to(self.doc_path)
            except Exception as e:
                print(
                    f"cannot convert absolute path ({relative_path}) to relative path. error is : {e}"
                )
        relative_path = str(relative_path.as_posix().strip("/"))
        return relative_path

    @abstractmethod
    def do_create_kb(self):
        """
        Subclass hook: backend-specific knowledge-base creation.
        """
        pass

    @staticmethod
    def list_kbs_type():
        # Backend type names come from the configured kbs_config mapping.
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        # Check DB registration; defaults to this instance's KB.
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        pass

    @abstractmethod
    def do_init(self):
        pass

    @abstractmethod
    def do_drop_kb(self):
        """
        Subclass hook: backend-specific knowledge-base deletion.
        """
        pass

    @abstractmethod
    def do_search(
        self,
        query: str,
        top_k: int,
        score_threshold: float,
    ) -> List[Tuple[Document, float]]:
        """
        Subclass hook: backend-specific similarity search.
        """
        pass

    @abstractmethod
    def do_add_doc(
        self,
        docs: List[Document],
        **kwargs,
    ) -> List[Dict]:
        """
        Subclass hook: backend-specific document insertion.
        """
        pass

    @abstractmethod
    def do_delete_doc(self, kb_file: KnowledgeFile):
        """
        Subclass hook: backend-specific document deletion.
        """
        pass

    @abstractmethod
    def do_clear_vs(self):
        """
        Subclass hook: backend-specific deletion of all vectors.
        """
        pass
class KBServiceFactory:
    """Factory that maps a vector-store type name to a concrete KBService."""

    @staticmethod
    def get_service(
        kb_name: str,
        vector_store_type: Union[str, SupportedVSType],
        embed_model: str = DEFAULT_EMBEDDING_MODEL,
        kb_info: str = None,
    ) -> KBService:
        """Instantiate the KBService for ``vector_store_type``.

        Accepts either a SupportedVSType value or its (case-insensitive) name.
        Returns None for an unrecognized type, matching get_service_by_name.
        Backend imports are deferred so unused backends cost nothing.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
        params = {
            "knowledge_base_name": kb_name,
            "embed_model": embed_model,
            "kb_info": kb_info,
        }
        if SupportedVSType.FAISS == vector_store_type:
            from ..kb_service.faiss_kb_service import (
                FaissKBService,
            )

            return FaissKBService(**params)
        elif SupportedVSType.MILVUS == vector_store_type:
            from ..kb_service.milvus_kb_service import (
                MilvusKBService,
            )

            return MilvusKBService(**params)
        elif SupportedVSType.DEFAULT == vector_store_type:
            # DEFAULT is served by Milvus; other milvus parameters are set in
            # model_config.kbs_config.
            # Fix: a second `DEFAULT` branch returning DefaultKBService existed
            # below but was unreachable (this branch always matched first);
            # the dead code has been removed.
            from ..kb_service.milvus_kb_service import (
                MilvusKBService,
            )

            return MilvusKBService(**params)
        elif SupportedVSType.ES == vector_store_type:
            from ..kb_service.es_kb_service import (
                ESKBService,
            )

            return ESKBService(**params)

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """Look up a KB in info.db and build its service; None if unregistered."""
        _, vs_type, embed_model = load_kb_from_db(kb_name)
        if _ is None:  # kb not in db, just return None
            return None
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Service for the built-in "default" knowledge base."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
def get_kb_details() -> List[Dict]:
    """Merge knowledge-base info from the disk folders and the metadata DB.

    Every KB found on disk or in the DB yields one dict; the ``in_folder`` /
    ``in_db`` flags record where it was found, and ``No`` is a 1-based index.
    """
    merged = {}
    # Seed with everything present on disk; DB fields default to empty.
    for folder_kb in list_kbs_from_folder():
        merged[folder_kb] = {
            "kb_name": folder_kb,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
    # Overlay DB rows; rows without a folder are added as db-only entries.
    for schema in KBService.list_kbs():
        row = schema.model_dump()
        name = row["kb_name"]
        row["in_db"] = True
        if name in merged:
            merged[name].update(row)
        else:
            row["in_folder"] = False
            merged[name] = row
    details = []
    for no, entry in enumerate(merged.values(), start=1):
        entry["No"] = no
        details.append(entry)
    return details
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """Merge per-file info for one KB from its content folder and the DB.

    Returns [] when the KB is not registered. Each dict carries ``in_folder`` /
    ``in_db`` flags and a 1-based ``No`` index.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    if service is None:
        return []
    folder_files = list_files_from_folder(kb_name)
    db_files = service.list_files()
    merged = {}
    # Seed with files on disk; DB fields default to empty values.
    for fname in folder_files:
        merged[fname] = {
            "kb_name": kb_name,
            "file_name": fname,
            "file_ext": os.path.splitext(fname)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
    # Case-insensitive mapping so DB rows match folder entries regardless of case.
    lower_names = {x.lower(): x for x in merged}
    for fname in db_files:
        detail = get_file_detail(kb_name, fname)
        if not detail:
            continue
        detail["in_db"] = True
        folder_key = lower_names.get(fname.lower())
        if folder_key is not None:
            merged[folder_key].update(detail)
        else:
            detail["in_folder"] = False
            merged[fname] = detail
    details = []
    for no, entry in enumerate(merged.values(), start=1):
        entry["No"] = no
        details.append(entry)
    return details
def score_threshold_process(score_threshold, k, docs):
    """Filter scored docs by threshold, then truncate to the first k.

    ``docs`` is a list of (document, similarity) pairs. When ``score_threshold``
    is not None, only pairs whose similarity is <= the threshold survive
    (lower similarity score == closer match for L2-style distances).
    """
    if score_threshold is not None:
        docs = [pair for pair in docs if pair[1] <= score_threshold]
    return docs[:k]
| @@ -0,0 +1,38 @@ | |||
| from typing import List | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| from ..kb_service.base import KBService | |||
class DefaultKBService(KBService):
    """Placeholder knowledge-base backend: every hook is a no-op.

    Lets code paths that expect a KBService run without any real vector store
    behind them.
    """

    def do_init(self):
        ...

    def do_create_kb(self):
        ...

    def do_drop_kb(self):
        ...

    def do_add_doc(self, docs: List[Document]):
        ...

    def do_delete_doc(self):
        ...

    def do_clear_vs(self):
        ...

    def do_search(self):
        ...

    def do_insert_multi_knowledge(self):
        ...

    def do_insert_one_knowledge(self):
        ...

    def vs_type(self) -> str:
        # Matches SupportedVSType.DEFAULT.
        return "default"
| @@ -0,0 +1,226 @@ | |||
| import logging | |||
| import os | |||
| import shutil | |||
| from typing import List | |||
| from elasticsearch import BadRequestError, Elasticsearch | |||
| from langchain.schema import Document | |||
| from langchain_community.vectorstores.elasticsearch import ( | |||
| ApproxRetrievalStrategy, | |||
| ElasticsearchStore, | |||
| ) | |||
| from src.mindpilot.app.utils.system_utils import get_Embeddings | |||
| from ...configs import KB_ROOT_PATH, kbs_config | |||
| from ..file_rag.utils import get_Retriever | |||
| from ..kb_service.base import KBService, SupportedVSType | |||
| from ..utils import KnowledgeFile | |||
| logger = logging.getLogger() | |||
class ESKBService(KBService):
    """Knowledge-base backend on Elasticsearch.

    A raw ``Elasticsearch`` client handles index management and id-based
    access; a langchain ``ElasticsearchStore`` handles embedding search.
    The ES index name is the KB folder name.
    """

    def do_init(self):
        """Connect to ES, ensure the KB's index exists, build the store."""
        self.kb_path = self.get_kb_path(self.kb_name)
        # Index name == last path component == kb name.
        self.index_name = os.path.split(self.kb_path)[-1]
        self.IP = kbs_config[self.vs_type()]["host"]
        self.PORT = kbs_config[self.vs_type()]["port"]
        self.user = kbs_config[self.vs_type()].get("user", "")
        self.password = kbs_config[self.vs_type()].get("password", "")
        self.dims_length = kbs_config[self.vs_type()].get("dims_length", None)
        self.embeddings_model = get_Embeddings(self.embed_model)
        try:
            # Raw ES python client (connection only).
            if self.user != "" and self.password != "":
                self.es_client_python = Elasticsearch(
                    f"http://{self.IP}:{self.PORT}",
                    basic_auth=(self.user, self.password),
                )
            else:
                logger.warning("ES未配置用户名和密码")
                self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}")
        except ConnectionError:
            logger.error("连接到 Elasticsearch 失败!")
            raise ConnectionError
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise e
        try:
            # First attempt: create the index via the raw client. A
            # BadRequestError (e.g. the index already exists) is logged and we
            # fall through to the langchain store below.
            mappings = {
                "properties": {
                    "dense_vector": {
                        "type": "dense_vector",
                        "dims": self.dims_length,
                        "index": True,
                    }
                }
            }
            self.es_client_python.indices.create(
                index=self.index_name, mappings=mappings
            )
        except BadRequestError as e:
            logger.error("创建索引失败,重新")
            logger.error(e)
        try:
            # langchain ES connection + index handle.
            params = dict(
                es_url=f"http://{self.IP}:{self.PORT}",
                index_name=self.index_name,
                query_field="context",
                vector_query_field="dense_vector",
                embedding=self.embeddings_model,
                strategy=ApproxRetrievalStrategy(),
                es_params={
                    "timeout": 60,
                },
            )
            if self.user != "" and self.password != "":
                params.update(es_user=self.user, es_password=self.password)
            self.db = ElasticsearchStore(**params)
        except ConnectionError:
            logger.error("### 初始化 Elasticsearch 失败!")
            raise ConnectionError
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise e
        try:
            # Second attempt via the store's own helper; failure here is only
            # logged so a pre-existing index does not abort init.
            self.db._create_index_if_not_exists(
                index_name=self.index_name, dims_length=self.dims_length
            )
        except Exception as e:
            logger.error("创建索引失败...")
            logger.error(e)
            # raise e

    @staticmethod
    def get_kb_path(knowledge_base_name: str):
        return os.path.join(KB_ROOT_PATH, knowledge_base_name)

    @staticmethod
    def get_vs_path(knowledge_base_name: str):
        return os.path.join(
            ESKBService.get_kb_path(knowledge_base_name), "vector_store"
        )

    def do_create_kb(self):
        # Index creation already happens in do_init.
        ...

    def vs_type(self) -> str:
        return SupportedVSType.ES

    def do_search(self, query: str, top_k: int, score_threshold: float):
        """Similarity search via the configured vectorstore retriever."""
        retriever = get_Retriever("vectorstore").from_vectorstore(
            self.db,
            top_k=top_k,
            score_threshold=score_threshold,
        )
        docs = retriever.get_relevant_documents(query)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by ES _id; unknown/failed ids are skipped."""
        results = []
        for doc_id in ids:
            try:
                response = self.es_client_python.get(index=self.index_name, id=doc_id)
                source = response["_source"]
                # Document text lives in "context", metadata in "metadata".
                text = source.get("context", "")
                metadata = source.get("metadata", {})
                results.append(Document(page_content=text, metadata=metadata))
            except Exception as e:
                logger.error(f"Error retrieving document from Elasticsearch! {e}")
        return results

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by ES _id; failures are logged and skipped."""
        for doc_id in ids:
            try:
                self.es_client_python.delete(
                    index=self.index_name, id=doc_id, refresh=True
                )
            except Exception as e:
                logger.error(f"ES Docs Delete Error! {e}")
        # Fix: the signature declares -> bool but nothing was returned.
        return True

    def do_delete_doc(self, kb_file, **kwargs):
        """Delete every chunk of ``kb_file`` from the index (source is a keyword)."""
        if self.es_client_python.indices.exists(index=self.index_name):
            query = {
                "query": {
                    "term": {
                        "metadata.source.keyword": self.get_relative_source_path(
                            kb_file.filepath
                        )
                    }
                },
                "track_total_hits": True,
            }
            # ES returns 10 hits by default; with track_total_hits=True the
            # first search yields the true count, used as size for the second.
            size = self.es_client_python.search(body=query)["hits"]["total"]["value"]
            search_results = self.es_client_python.search(body=query, size=size)
            delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
            if len(delete_list) == 0:
                return None
            else:
                for doc_id in delete_list:
                    try:
                        self.es_client_python.delete(
                            index=self.index_name, id=doc_id, refresh=True
                        )
                    except Exception as e:
                        logger.error(f"ES Docs Delete Error! {e}")
            # self.db.delete(ids=delete_list)
            # self.es_client_python.indices.refresh(index=self.index_name)

    def do_add_doc(self, docs: List[Document], **kwargs):
        """Add documents to the index and return [{"id", "metadata"}, ...]."""
        print(
            f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}"
        )
        print("*" * 100)
        self.db.add_documents(documents=docs)
        print("写入数据成功.")
        print("*" * 100)
        if self.es_client_python.indices.exists(index=self.index_name):
            file_path = docs[0].metadata.get("source")
            # Fix: the original dict literal used the key "term" twice, so the
            # metadata.source filter was silently discarded and the query
            # matched on _index alone. Combine both conditions with bool/must.
            query = {
                "query": {
                    "bool": {
                        "must": [
                            {"term": {"metadata.source.keyword": file_path}},
                            {"term": {"_index": self.index_name}},
                        ]
                    }
                }
            }
            # NOTE(review): size=50 caps the ids returned for one file —
            # confirm files never split into more than 50 chunks.
            search_results = self.es_client_python.search(body=query, size=50)
            if len(search_results["hits"]["hits"]) == 0:
                raise ValueError("召回元素个数为0")
            info_docs = [
                {"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
                for hit in search_results["hits"]["hits"]
            ]
            return info_docs

    def do_clear_vs(self):
        """Drop the whole index (index name == kb name) to clear all vectors."""
        if self.es_client_python.indices.exists(index=self.kb_name):
            self.es_client_python.indices.delete(index=self.kb_name)

    def do_drop_kb(self):
        """Delete the knowledge base folder on disk."""
        if os.path.exists(self.kb_path):
            shutil.rmtree(self.kb_path)
if __name__ == "__main__":
    # Ad-hoc manual smoke test: requires a reachable Elasticsearch instance
    # and a "test" knowledge base with README.md in its content folder.
    esKBService = ESKBService("test")
    # esKBService.clear_vs()
    # esKBService.create_kb()
    esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
    print(esKBService.search_docs("如何启动api服务"))
| @@ -0,0 +1,136 @@ | |||
| import os | |||
| import shutil | |||
| from typing import Dict, List, Tuple | |||
| from langchain.docstore.document import Document | |||
| from ...configs import SCORE_THRESHOLD | |||
| from ..file_rag.utils import get_Retriever | |||
| from ..kb_cache.faiss_cache import ( | |||
| ThreadSafeFaiss, | |||
| kb_faiss_pool, | |||
| ) | |||
| from ..kb_service.base import KBService, SupportedVSType | |||
| from ..utils import KnowledgeFile, get_kb_path, get_vs_path | |||
class FaissKBService(KBService):
    """Knowledge-base backend on a local FAISS index.

    The index lives under the KB folder and is held in the shared
    ``kb_faiss_pool`` cache; disk writes happen via ``save_local`` /
    ``save_vector_store``.
    """

    vs_path: str     # on-disk vector-store directory
    kb_path: str     # knowledge-base root directory
    vector_name: str = None  # subdirectory name; defaults to the embed model

    def vs_type(self) -> str:
        return SupportedVSType.FAISS

    def get_vs_path(self):
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Fetch (or lazily load) this KB's FAISS store from the shared pool."""
        return kb_faiss_pool.load_vector_store(
            kb_name=self.kb_name,
            vector_name=self.vector_name,
            embed_model=self.embed_model,
        )

    def save_vector_store(self):
        """Persist the in-memory FAISS index to disk."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        # NOTE(review): unknown ids yield None entries in the result list —
        # callers such as KBService.list_docs rely on that; confirm.
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(id) for id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        # Fix: the signature declares -> bool but nothing was returned.
        return True

    def do_init(self):
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Create the vector-store directory and warm the pool entry."""
        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear vectors then remove the whole KB directory (best-effort)."""
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            pass

    def do_search(
        self,
        query: str,
        top_k: int,
        score_threshold: float = SCORE_THRESHOLD,
    ) -> List[Tuple[Document, float]]:
        """Similarity search via the configured ensemble retriever."""
        with self.load_vector_store().acquire() as vs:
            retriever = get_Retriever("ensemble").from_vectorstore(
                vs,
                top_k=top_k,
                score_threshold=score_threshold,
            )
            docs = retriever.get_relevant_documents(query)
        return docs

    def do_add_doc(
        self,
        docs: List[Document],
        **kwargs,
    ) -> List[Dict]:
        """Embed and insert docs; returns [{"id", "metadata"}, ...].

        Texts are embedded outside add_embeddings so embedding failures cannot
        leave the store partially updated.
        """
        texts = [x.page_content for x in docs]
        metadatas = [x.metadata for x in docs]
        with self.load_vector_store().acquire() as vs:
            embeddings = vs.embeddings.embed_documents(texts)
            ids = vs.add_embeddings(
                text_embeddings=zip(texts, embeddings), metadatas=metadatas
            )
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete all vectors whose metadata source matches the file name."""
        with self.load_vector_store().acquire() as vs:
            # Fix: metadata.get("source") can be None when a doc has no
            # "source" key, which crashed on .lower(); treat missing as "".
            ids = [
                k
                for k, v in vs.docstore._dict.items()
                if (v.metadata.get("source") or "").lower() == kb_file.filename.lower()
            ]
            if len(ids) > 0:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Drop the cached store, wipe the on-disk index, recreate the dir."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            ...
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Return "in_db" / "in_folder" / False depending on where the file is."""
        if super().exist_doc(file_name):
            return "in_db"
        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        else:
            return False
if __name__ == "__main__":
    # Ad-hoc manual smoke test: requires a local "test" knowledge base with a
    # README.md in its content folder and a reachable embedding model.
    faissService = FaissKBService("test")
    faissService.add_doc(KnowledgeFile("README.md", "test"))
    faissService.delete_doc(KnowledgeFile("README.md", "test"))
    faissService.do_drop_kb()
    print(faissService.search_docs("如何启动api服务"))
| @@ -0,0 +1,125 @@ | |||
| import os | |||
| from typing import Dict, List, Optional | |||
| from langchain.schema import Document | |||
| from langchain.vectorstores.milvus import Milvus | |||
| from ...configs import kbs_config | |||
| from ..db.repository import list_file_num_docs_id_by_kb_name_and_file_name | |||
| from ...utils.system_utils import get_Embeddings | |||
| from ..file_rag.utils import get_Retriever | |||
| from ..kb_service.base import ( | |||
| KBService, | |||
| SupportedVSType, | |||
| score_threshold_process, | |||
| ) | |||
| from ..utils import KnowledgeFile | |||
| class MilvusKBService(KBService): | |||
| milvus: Milvus | |||
| @staticmethod | |||
| def get_collection(milvus_name): | |||
| from pymilvus import Collection | |||
| return Collection(milvus_name) | |||
| def get_doc_by_ids(self, ids: List[str]) -> List[Document]: | |||
| result = [] | |||
| if self.milvus.col: | |||
| # ids = [int(id) for id in ids] # for milvus if needed #pr 2725 | |||
| data_list = self.milvus.col.query( | |||
| expr=f"pk in {[int(_id) for _id in ids]}", output_fields=["*"] | |||
| ) | |||
| for data in data_list: | |||
| text = data.pop("text") | |||
| result.append(Document(page_content=text, metadata=data)) | |||
| return result | |||
| def del_doc_by_ids(self, ids: List[str]) -> bool: | |||
| self.milvus.col.delete(expr=f"pk in {ids}") | |||
| @staticmethod | |||
| def search(milvus_name, content, limit=3): | |||
| search_params = { | |||
| "metric_type": "L2", | |||
| "params": {"nprobe": 10}, | |||
| } | |||
| c = MilvusKBService.get_collection(milvus_name) | |||
| return c.search( | |||
| content, "embeddings", search_params, limit=limit, output_fields=["content"] | |||
| ) | |||
| def do_create_kb(self): | |||
| pass | |||
| def vs_type(self) -> str: | |||
| return SupportedVSType.MILVUS | |||
| def _load_milvus(self): | |||
| self.milvus = Milvus( | |||
| embedding_function=get_Embeddings(self.embed_model), | |||
| collection_name=self.kb_name, | |||
| connection_args=kbs_config.get("milvus"), | |||
| index_params=kbs_config.get("milvus_kwargs")["index_params"], | |||
| search_params=kbs_config.get("milvus_kwargs")["search_params"], | |||
| auto_id=True, | |||
| ) | |||
| def do_init(self): | |||
| self._load_milvus() | |||
| def do_drop_kb(self): | |||
| if self.milvus.col: | |||
| self.milvus.col.release() | |||
| self.milvus.col.drop() | |||
| def do_search(self, query: str, top_k: int, score_threshold: float): | |||
| self._load_milvus() | |||
| # embed_func = get_Embeddings(self.embed_model) | |||
| # embeddings = embed_func.embed_query(query) | |||
| # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) | |||
| retriever = get_Retriever("vectorstore").from_vectorstore( | |||
| self.milvus, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| ) | |||
| docs = retriever.get_relevant_documents(query) | |||
| return docs | |||
| def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: | |||
| for doc in docs: | |||
| for k, v in doc.metadata.items(): | |||
| doc.metadata[k] = str(v) | |||
| for field in self.milvus.fields: | |||
| doc.metadata.setdefault(field, "") | |||
| doc.metadata.pop(self.milvus._text_field, None) | |||
| doc.metadata.pop(self.milvus._vector_field, None) | |||
| ids = self.milvus.add_documents(docs) | |||
| doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] | |||
| return doc_infos | |||
    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete all vectors belonging to `kb_file` from the collection.

        Primary keys for the file are looked up in the relational DB, so we
        delete by pk rather than querying Milvus by source path (the latter
        was unreliable on Windows — see upstream issue #2846).
        """
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(
            kb_file.kb_name, kb_file.filename
        )
        if self.milvus.col:
            self.milvus.col.delete(expr=f"pk in {id_list}")
    def do_clear_vs(self):
        # Clearing is implemented as drop-and-recreate of the collection.
        if self.milvus.col:
            self.do_drop_kb()
            self.do_init()
| @@ -0,0 +1,234 @@ | |||
| import os | |||
| from datetime import datetime | |||
| from typing import List, Literal | |||
| from dateutil.parser import parse | |||
| from ..configs import ( | |||
| CHUNK_SIZE, | |||
| DEFAULT_EMBEDDING_MODEL, | |||
| DEFAULT_VS_TYPE, | |||
| OVERLAP_SIZE, | |||
| ZH_TITLE_ENHANCE, | |||
| ) | |||
| from .db.base import Base, engine | |||
| from .db.repository.knowledge_file_repository import ( | |||
| add_file_to_db, | |||
| ) | |||
| from .db.session import session_scope | |||
| from .kb_service.base import ( | |||
| KBServiceFactory, | |||
| SupportedVSType, | |||
| ) | |||
| from .utils import ( | |||
| KnowledgeFile, | |||
| files2docs_in_thread, | |||
| get_file_path, | |||
| list_files_from_folder, | |||
| list_kbs_from_folder, | |||
| ) | |||
def create_tables():
    """Create all knowledge-base tables registered on the ORM Base (no-op for existing tables)."""
    Base.metadata.create_all(bind=engine)
def reset_tables():
    """Drop and recreate all tables. Destroys all stored KB metadata."""
    Base.metadata.drop_all(bind=engine)
    create_tables()
def import_from_db(
    sqlite_path: str = None,
    # csv_path: str = None,
) -> bool:
    """
    Import data from a backup database into info.db, for the case where the
    knowledge base and vector store are unchanged (e.g. a version upgrade
    changed the info.db schema but re-vectorization is not needed).

    Table names and the imported column names must match on both sides.
    Only sqlite is supported at the moment.

    Returns True on success, False on any error.
    """
    import sqlite3 as sql
    from pprint import pprint

    models = list(Base.registry.mappers)

    con = None
    try:
        con = sql.connect(sqlite_path)
        con.row_factory = sql.Row
        cur = con.cursor()
        tables = [
            x["name"]
            for x in cur.execute(
                "select name from sqlite_master where type='table'"
            ).fetchall()
        ]
        for model in models:
            table = model.local_table.fullname
            if table not in tables:
                continue
            print(f"processing table: {table}")
            with session_scope() as session:
                for row in cur.execute(f"select * from {table}").fetchall():
                    # Only copy columns that exist in the current schema.
                    data = {k: row[k] for k in row.keys() if k in model.columns}
                    if "create_time" in data:
                        # Stored as text in sqlite; parse back to datetime.
                        data["create_time"] = parse(data["create_time"])
                    pprint(data)
                    session.add(model.class_(**data))
        return True
    except Exception as e:
        print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
        return False
    finally:
        # BUG fix: the original only closed the connection on the success
        # path, leaking it whenever an error occurred mid-import.
        if con is not None:
            con.close()
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
    """Wrap raw file names into KnowledgeFile objects, skipping bad ones.

    Files that KnowledgeFile rejects (e.g. unsupported extension) are
    skipped; the error is reported instead of being silently discarded.
    """
    kb_files = []
    for file in files:
        try:
            kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
            kb_files.append(kb_file)
        except Exception as e:
            msg = f"{e},已跳过"
            # BUG fix: `msg` was previously built but never reported.
            print(msg)
    return kb_files
def folder2db(
    kb_names: List[str],
    mode: Literal["recreate_vs", "update_in_db", "increment"],
    vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
    embed_model: str = DEFAULT_EMBEDDING_MODEL,
    chunk_size: int = CHUNK_SIZE,
    chunk_overlap: int = OVERLAP_SIZE,
    zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
    """
    use existed files in local folder to populate database and/or vector store.
    set parameter `mode` to:
        recreate_vs: recreate all vector store and fill info to database using existed files in local folder
        fill_info_only(disabled): do not create vector store, fill info to db using existed files only
        update_in_db: update vector store and database info using local files that existed in database only
        increment: create vector store and database info for local files that not existed in database only

    Returns the per-file result list of the LAST knowledge base processed
    (empty list when nothing was processed).
    """

    def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List:
        # Convert kb_files to docs in worker threads and add them to the
        # current kb service (closure variable `kb`), collecting results.
        result = []
        for success, res in files2docs_in_thread(
            kb_files,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
        ):
            if success:
                _, filename, docs = res
                print(
                    f"正在将 {kb_name}/(unknown) 添加到向量库,共包含{len(docs)}条文档"
                )
                kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
                kb_file.splited_docs = docs
                kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
                result.append({"kb_name": kb_name, "file": filename, "docs": docs})
            else:
                print(res)
        return result

    kb_names = kb_names or list_kbs_from_folder()
    # BUG fix: `result` was unbound when kb_names was empty.
    result = []
    for kb_name in kb_names:
        start = datetime.now()
        kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
        if not kb.exists():
            kb.create_kb()

        # BUG fix: both were previously unbound when `mode` was invalid,
        # which crashed the summary below with NameError.
        kb_files = []
        result = []

        # 清除向量库,从本地文件重建
        if mode == "recreate_vs":
            kb.clear_vs()
            kb.create_kb()
            kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
            result = files2vs(kb_name, kb_files)
            kb.save_vector_store()
        # 以数据库中文件列表为基准,利用本地文件更新向量库
        elif mode == "update_in_db":
            files = kb.list_files()
            kb_files = file_to_kbfile(kb_name, files)
            result = files2vs(kb_name, kb_files)
            kb.save_vector_store()
        # 对比本地目录与数据库中的文件列表,进行增量向量化
        elif mode == "increment":
            db_files = kb.list_files()
            folder_files = list_files_from_folder(kb_name)
            files = list(set(folder_files) - set(db_files))
            kb_files = file_to_kbfile(kb_name, files)
            result = files2vs(kb_name, kb_files)
            kb.save_vector_store()
        else:
            print(f"unsupported migrate mode: {mode}")

        end = datetime.now()
        # Only FAISS keeps an on-disk path worth reporting.
        kb_path = (
            f"知识库路径\t:{kb.kb_path}\n"
            if kb.vs_type() == SupportedVSType.FAISS
            else ""
        )
        file_count = len(kb_files)
        success_count = len(result)
        docs_count = sum(len(x["docs"]) for x in result)
        print("\n" + "-" * 100)
        print(
            (
                f"知识库名称\t:{kb_name}\n"
                f"知识库类型\t:{kb.vs_type()}\n"
                f"向量模型:\t:{kb.embed_model}\n"
            )
            + kb_path
            + (
                f"文件总数量\t:{file_count}\n"
                f"入库文件数\t:{success_count}\n"
                f"知识条目数\t:{docs_count}\n"
                f"用时\t\t:{end-start}"
            )
        )
        print("-" * 100 + "\n")
    return result
def prune_db_docs(kb_names: List[str]):
    """
    delete docs in database that not existed in local folder.
    it is used to delete database docs after user deleted some doc files in file browser
    """
    for name in kb_names:
        service = KBServiceFactory.get_service_by_name(name)
        if service is None:
            continue
        # Files tracked in the DB but no longer present on disk.
        stale = set(service.list_files()) - set(list_files_from_folder(name))
        for kb_file in file_to_kbfile(name, list(stale)):
            service.delete_doc(kb_file, not_refresh_vs_cache=True)
            print(f"success to delete docs for file: {name}/{kb_file.filename}")
        service.save_vector_store()
def prune_folder_files(kb_names: List[str]):
    """
    delete doc files in local folder that not existed in database.
    it is used to free local disk space by delete unused doc files.
    """
    for name in kb_names:
        service = KBServiceFactory.get_service_by_name(name)
        if service is None:
            continue
        # Files present on disk but not tracked in the database.
        orphans = set(list_files_from_folder(name)) - set(service.list_files())
        for fname in orphans:
            os.remove(get_file_path(name, fname))
            print(f"success to delete file: {name}/{fname}")
| @@ -0,0 +1,10 @@ | |||
| from langchain.docstore.document import Document | |||
class DocumentWithVSId(Document):
    """
    A vectorized document: a langchain Document enriched with the id it was
    stored under in the vector store and a retrieval score.
    """

    # Id assigned by the vector store; None until the doc has been stored.
    id: str = None
    # Similarity score from retrieval. NOTE(review): 3.0 looks like a
    # sentinel meaning "no score computed" — confirm against callers.
    score: float = 3.0
| @@ -0,0 +1,483 @@ | |||
| import importlib | |||
| import json | |||
| import logging | |||
| import os | |||
| from functools import lru_cache | |||
| from pathlib import Path | |||
| from typing import Dict, Generator, List, Tuple, Union | |||
| import chardet | |||
| import langchain_community.document_loaders | |||
| from langchain.docstore.document import Document | |||
| from langchain.text_splitter import MarkdownHeaderTextSplitter, TextSplitter | |||
| from langchain_community.document_loaders import JSONLoader, TextLoader | |||
| from ..configs import ( | |||
| CHUNK_SIZE, | |||
| KB_ROOT_PATH, | |||
| OVERLAP_SIZE, | |||
| TEXT_SPLITTER_NAME, | |||
| ZH_TITLE_ENHANCE, | |||
| text_splitter_dict, | |||
| ) | |||
| from .file_rag.text_splitter import ( | |||
| zh_title_enhance as func_zh_title_enhance, | |||
| ) | |||
| from ..utils.system_utils import run_in_thread_pool | |||
| logger = logging.getLogger() | |||
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return False for KB names that could escape KB_ROOT_PATH.

    Rejects POSIX ("../") and Windows ("..\\") traversal sequences and
    absolute paths. The original check only caught "../", leaving the
    Windows separator and absolute-path cases open.
    """
    if "../" in knowledge_base_id or "..\\" in knowledge_base_id:
        return False
    if os.path.isabs(knowledge_base_id):
        return False
    return True
def get_kb_path(knowledge_base_name: str):
    """Return the root directory that stores the given knowledge base."""
    return str(Path(KB_ROOT_PATH) / knowledge_base_name)
def get_doc_path(knowledge_base_name: str):
    """Return the "content" directory holding the KB's raw documents."""
    return str(Path(get_kb_path(knowledge_base_name)) / "content")
def get_vs_path(knowledge_base_name: str, vector_name: str):
    """Return the directory of one named vector store inside the KB."""
    return str(Path(get_kb_path(knowledge_base_name)) / "vector_store" / vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Resolve `doc_name` inside the KB content dir; None if it escapes.

    Path-traversal guard. BUG fix: the original used a plain
    str.startswith() prefix test, which wrongly accepted sibling
    directories such as ".../content_evil" as being inside ".../content".
    """
    doc_path = Path(get_doc_path(knowledge_base_name)).resolve()
    file_path = (doc_path / doc_name).resolve()
    # Accept only the content dir itself or true descendants of it.
    if file_path == doc_path or doc_path in file_path.parents:
        return str(file_path)
def list_kbs_from_folder():
    """List knowledge-base names: every sub-directory of KB_ROOT_PATH."""
    root = Path(KB_ROOT_PATH)
    return [entry.name for entry in root.iterdir() if entry.is_dir()]
def list_files_from_folder(kb_name: str):
    """Recursively list document files under the KB's content directory.

    Returns paths relative to the content dir, in posix form. Temp files,
    hidden files and Office lock files ("~$...") are skipped.

    BUG fix: the original special-cased symlinks by scanning
    os.path.realpath(target), which (a) crashed with NotADirectoryError on
    a symlink pointing at a regular file, and (b) produced relative paths
    computed against doc_path but rooted at the link target, i.e. paths
    outside the KB. DirEntry.is_file()/is_dir() follow symlinks by
    default, so recursing via entry.path handles both cases correctly.
    """
    doc_path = get_doc_path(kb_name)
    result = []

    def _skip(path: str) -> bool:
        # Temp/hidden/lock files are not real documents.
        tail = os.path.basename(path).lower()
        return any(tail.startswith(prefix) for prefix in ("temp", "tmp", ".", "~$"))

    def _walk(entry):
        if _skip(entry.path):
            return
        if entry.is_file():  # follows symlinks
            # Normalize to posix so DB entries are platform-independent.
            result.append(Path(os.path.relpath(entry.path, doc_path)).as_posix())
        elif entry.is_dir():  # follows symlinks
            with os.scandir(entry.path) as it:
                for sub in it:
                    _walk(sub)

    with os.scandir(doc_path) as it:
        for entry in it:
            _walk(entry)
    return result
# Maps loader class name -> file extensions it handles.
# NOTE(review): get_LoaderClass returns the FIRST matching entry, so for
# ".md" TextLoader always wins and UnstructuredMarkdownLoader is never
# selected — confirm whether that is intended.
LOADER_DICT = {
    "UnstructuredHTMLLoader": [".html", ".htm"],
    "MHTMLLoader": [".mhtml"],
    "TextLoader": [".md"],
    "UnstructuredMarkdownLoader": [".md"],
    "JSONLoader": [".json"],
    "JSONLinesLoader": [".jsonl"],
    "CSVLoader": [".csv"],
    # "FilteredCSVLoader": [".csv"], enable to use the custom CSV splitter
    "RapidOCRPDFLoader": [".pdf"],
    "RapidOCRDocLoader": [".docx", ".doc"],
    "RapidOCRPPTLoader": [
        ".ppt",
        ".pptx",
    ],
    "RapidOCRLoader": [".png", ".jpg", ".jpeg", ".bmp"],
    # Generic fallback loader for plain-text-like formats.
    "UnstructuredFileLoader": [
        ".eml",
        ".msg",
        ".rst",
        ".rtf",
        ".txt",
        ".xml",
        ".epub",
        ".odt",
        ".tsv",
    ],
    "UnstructuredEmailLoader": [".eml", ".msg"],
    "UnstructuredEPubLoader": [".epub"],
    # NOTE(review): ".xlsd" looks like a typo (not a standard Excel
    # extension) — confirm before removing.
    "UnstructuredExcelLoader": [".xlsx", ".xls", ".xlsd"],
    "NotebookLoader": [".ipynb"],
    "UnstructuredODTLoader": [".odt"],
    "PythonLoader": [".py"],
    "UnstructuredRSTLoader": [".rst"],
    "UnstructuredRTFLoader": [".rtf"],
    "SRTLoader": [".srt"],
    "TomlLoader": [".toml"],
    "UnstructuredTSVLoader": [".tsv"],
    "UnstructuredWordDocumentLoader": [".docx", ".doc"],
    "UnstructuredXMLLoader": [".xml"],
    "UnstructuredPowerPointLoader": [".ppt", ".pptx"],
    "EverNoteLoader": [".enex"],
}

# Flat list of every supported extension, in LOADER_DICT order (may
# contain duplicates where several loaders claim the same extension).
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
# patch json.dumps to disable ensure_ascii
# NOTE(review): this monkey-patches json.dumps PROCESS-WIDE so dumped text
# keeps non-ASCII (e.g. Chinese) characters readable. It also overrides an
# explicit ensure_ascii=True passed by any caller. The identity check below
# makes the patch idempotent if this module body runs more than once.
def _new_json_dumps(obj, **kwargs):
    kwargs["ensure_ascii"] = False
    return _origin_json_dumps(obj, **kwargs)


if json.dumps is not _new_json_dumps:
    # Capture the original before swapping it in, so _new_json_dumps can
    # delegate to it.
    _origin_json_dumps = json.dumps
    json.dumps = _new_json_dumps
class JSONLinesLoader(JSONLoader):
    """JSONLoader variant for .jsonl files (one JSON object per line)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # JSONLoader reads this private flag to switch to JSON-Lines mode.
        self._json_lines = True


# Expose the loader on langchain_community so get_loader() can resolve it
# by name via getattr on that module, like the built-in loaders.
langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader
def get_LoaderClass(file_extension):
    """Return the name of the first loader supporting `file_extension`, or None."""
    matches = (
        loader_name
        for loader_name, extensions in LOADER_DICT.items()
        if file_extension in extensions
    )
    return next(matches, None)
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
    """
    Return a document loader instance for `file_path`.

    Custom OCR loaders are resolved from chatchat.server.file_rag.document_loaders,
    everything else from langchain_community.document_loaders. Falls back to
    UnstructuredFileLoader when the requested loader cannot be found.
    Per-loader default kwargs (encoding detection, jq schema) are filled in
    unless the caller provided them.
    """
    loader_kwargs = loader_kwargs or {}
    try:
        if loader_name in [
            "RapidOCRPDFLoader",
            "RapidOCRLoader",
            "FilteredCSVLoader",
            "RapidOCRDocLoader",
            "RapidOCRPPTLoader",
        ]:
            document_loaders_module = importlib.import_module(
                "chatchat.server.file_rag.document_loaders"
            )
        else:
            document_loaders_module = importlib.import_module(
                "langchain_community.document_loaders"
            )
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
        logger.error(
            f"{e.__class__.__name__}: {msg}", exc_info=e
        )
        # Fallback: the generic unstructured loader handles most text files.
        document_loaders_module = importlib.import_module(
            "langchain_community.document_loaders"
        )
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

    if loader_name == "UnstructuredFileLoader":
        loader_kwargs.setdefault("autodetect_encoding", True)
    elif loader_name == "CSVLoader":
        if not loader_kwargs.get("encoding"):
            # Auto-detect the encoding so langchain's CSVLoader does not
            # fail on non-UTF-8 files.
            with open(file_path, "rb") as struct_file:
                encode_detect = chardet.detect(struct_file.read())
            # BUG fix: chardet returns {"encoding": None} for e.g. empty
            # files; fall back to utf-8 instead of passing None through.
            loader_kwargs["encoding"] = (encode_detect or {}).get("encoding") or "utf-8"
    elif loader_name in ("JSONLoader", "JSONLinesLoader"):
        # Both JSON loaders share the same defaults (branches deduplicated).
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)

    loader = DocumentLoader(file_path, **loader_kwargs)
    return loader
@lru_cache()
def make_text_splitter(splitter_name, chunk_size, chunk_overlap):
    """
    Build a text splitter instance by name (lru_cached per argument tuple).

    Resolution order: user-defined splitters in
    chatchat.server.file_rag.text_splitter, then langchain.text_splitter.
    The tokenizer backend ("tiktoken" / "huggingface") and its name/path
    come from `text_splitter_dict`. Any failure falls back to
    RecursiveCharacterTextSplitter with the given sizes.
    """
    splitter_name = splitter_name or "SpacyTextSplitter"
    try:
        if (
            splitter_name == "MarkdownHeaderTextSplitter"
        ):  # special case: configured by headers, not chunk sizes
            headers_to_split_on = text_splitter_dict[splitter_name][
                "headers_to_split_on"
            ]
            text_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=headers_to_split_on, strip_headers=False
            )
        else:
            try:  # prefer the user-defined text_splitter
                text_splitter_module = importlib.import_module("chatchat.server.file_rag.text_splitter")
                TextSplitter = getattr(text_splitter_module, splitter_name)
            except:  # otherwise use langchain's text_splitter
                text_splitter_module = importlib.import_module(
                    "langchain.text_splitter"
                )
                TextSplitter = getattr(text_splitter_module, splitter_name)
            if (
                text_splitter_dict[splitter_name]["source"] == "tiktoken"
            ):  # tokenizer loaded via tiktoken
                try:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name][
                            "tokenizer_name_or_path"
                        ],
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
                except:  # some splitters do not accept a `pipeline` kwarg
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name][
                            "tokenizer_name_or_path"
                        ],
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
            elif (
                text_splitter_dict[splitter_name]["source"] == "huggingface"
            ):  # tokenizer loaded via huggingface / mindnlp
                if (
                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"]
                    == "gpt2"
                ):
                    from langchain.text_splitter import CharacterTextSplitter
                    from mindnlp.transformers import GPT2TokenizerFast

                    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
                else:  # generic tokenizer by name/path
                    from mindnlp.transformers import AutoTokenizer

                    tokenizer = AutoTokenizer.from_pretrained(
                        text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        trust_remote_code=True,
                    )
                text_splitter = TextSplitter.from_huggingface_tokenizer(
                    tokenizer=tokenizer,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            else:
                try:  # splitters that accept a spacy pipeline argument
                    text_splitter = TextSplitter(
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
                except:  # plain size-based splitters
                    text_splitter = TextSplitter(
                        chunk_size=chunk_size, chunk_overlap=chunk_overlap
                    )
    except Exception as e:
        # Last-resort fallback that only depends on langchain itself.
        print(e)
        text_splitter_module = importlib.import_module("langchain.text_splitter")
        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
        text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
    # text_splitter._tokenizer.max_length = 37016792
    # text_splitter._tokenizer.prefer_gpu()
    return text_splitter
class KnowledgeFile:
    """A file inside a knowledge-base directory.

    The file must exist on disk before it can be loaded or vectorized.
    Loaded documents (`docs`) and split chunks (`splited_docs`) are cached
    on the instance.
    """

    def __init__(
        self,
        filename: str,
        knowledge_base_name: str,
        loader_kwargs: Dict = None,
    ):
        """
        Raises ValueError when the file extension is not supported.
        """
        self.kb_name = knowledge_base_name
        self.filename = str(Path(filename).as_posix())
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.filename}")
        # BUG fix: the original used a shared mutable default ({}), so all
        # instances created without loader_kwargs shared one dict.
        self.loader_kwargs = loader_kwargs or {}
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None          # raw Documents from the loader
        self.splited_docs = None  # Documents after text splitting
        self.document_loader_name = get_LoaderClass(self.ext)
        self.text_splitter_name = TEXT_SPLITTER_NAME

    def file2docs(self, refresh: bool = False):
        """Load the file into Documents (cached unless refresh=True)."""
        if self.docs is None or refresh:
            logger.info(f"{self.document_loader_name} used for {self.filepath}")
            loader = get_loader(
                loader_name=self.document_loader_name,
                file_path=self.filepath,
                loader_kwargs=self.loader_kwargs,
            )
            # TextLoader needs an explicit encoding; other loaders manage
            # encoding themselves. (The duplicated load() calls of the
            # original if/else were merged.)
            if isinstance(loader, TextLoader):
                loader.encoding = "utf8"
            self.docs = loader.load()
        return self.docs

    def docs2texts(
        self,
        docs: List[Document] = None,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
        refresh: bool = False,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        text_splitter: TextSplitter = None,
    ):
        """Split documents into chunks; optionally enhance Chinese titles.

        Returns the split documents (also cached in self.splited_docs),
        or [] when there is nothing to split.
        """
        docs = docs or self.file2docs(refresh=refresh)
        if not docs:
            return []
        # CSV rows are already chunk-sized; only split other formats.
        if self.ext not in [".csv"]:
            if text_splitter is None:
                text_splitter = make_text_splitter(
                    splitter_name=self.text_splitter_name,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            if self.text_splitter_name == "MarkdownHeaderTextSplitter":
                # Header splitter works on raw text, not Document objects.
                docs = text_splitter.split_text(docs[0].page_content)
            else:
                docs = text_splitter.split_documents(docs)
        if not docs:
            return []
        print(f"文档切分示例:{docs[0]}")
        if zh_title_enhance:
            docs = func_zh_title_enhance(docs)
        self.splited_docs = docs
        return self.splited_docs

    def file2text(
        self,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
        refresh: bool = False,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        text_splitter: TextSplitter = None,
    ):
        """Load + split in one step (cached unless refresh=True)."""
        if self.splited_docs is None or refresh:
            docs = self.file2docs()
            self.splited_docs = self.docs2texts(
                docs=docs,
                zh_title_enhance=zh_title_enhance,
                refresh=refresh,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                text_splitter=text_splitter,
            )
        return self.splited_docs

    def file_exist(self):
        """True if the file currently exists on disk."""
        return os.path.isfile(self.filepath)

    def get_mtime(self):
        """Last-modified timestamp of the file on disk."""
        return os.path.getmtime(self.filepath)

    def get_size(self):
        """Size of the file on disk, in bytes."""
        return os.path.getsize(self.filepath)
def files2docs_in_thread_file2docs(
    *, file: KnowledgeFile, **kwargs
) -> Tuple[bool, Tuple[str, str, List[Document]]]:
    """Worker for files2docs_in_thread: convert one file to split docs.

    Returns (True, (kb_name, filename, docs)) on success, or
    (False, (kb_name, filename, error_message)) on failure.
    """
    try:
        docs = file.file2text(**kwargs)
    except Exception as e:
        msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
        logger.error(
            f"{e.__class__.__name__}: {msg}", exc_info=e
        )
        return False, (file.kb_name, file.filename, msg)
    return True, (file.kb_name, file.filename, docs)
def files2docs_in_thread(
    files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
    chunk_size: int = CHUNK_SIZE,
    chunk_overlap: int = OVERLAP_SIZE,
    zh_title_enhance: bool = ZH_TITLE_ENHANCE,
) -> Generator:
    """
    Convert files on disk to langchain Documents using a thread pool.

    Each item of `files` may be a KnowledgeFile, a (filename, kb_name)
    tuple, or a dict with at least "filename" and "kb_name" keys (extra
    dict keys are forwarded to file2text).
    Yields: status, (kb_name, file_name, docs | error)
    """
    kwargs_list = []
    for file in files:
        kwargs = {}
        # BUG fix: defaults so the except clause below can always report
        # something. Previously, when `file` was already a KnowledgeFile
        # (or a malformed dict) and an error occurred, kb_name/filename
        # were unbound and the yield raised NameError.
        filename = getattr(file, "filename", "")
        kb_name = getattr(file, "kb_name", "")
        try:
            if isinstance(file, tuple) and len(file) >= 2:
                filename = file[0]
                kb_name = file[1]
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, dict):
                filename = file.pop("filename")
                kb_name = file.pop("kb_name")
                kwargs.update(file)
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            kwargs["file"] = file
            kwargs["chunk_size"] = chunk_size
            kwargs["chunk_overlap"] = chunk_overlap
            kwargs["zh_title_enhance"] = zh_title_enhance
            kwargs_list.append(kwargs)
        except Exception as e:
            yield False, (kb_name, filename, str(e))

    for result in run_in_thread_pool(
        func=files2docs_in_thread_file2docs, params=kwargs_list
    ):
        yield result
if __name__ == "__main__":
    # Manual smoke test: load one markdown file and print its split chunks.
    kb_file = KnowledgeFile(
        filename="E:\\LLM\\Data\\Test.md", knowledge_base_name="samples"
    )
    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
    kb_file.text_splitter_name = "MarkdownHeaderTextSplitter"
    docs = kb_file.file2docs()
    texts = kb_file.docs2texts(docs)
    for text in texts:
        print(text)
| @@ -1,12 +1,7 @@ | |||
| import asyncio | |||
| import logging | |||
| import multiprocessing as mp | |||
| import os | |||
| import socket | |||
| import sqlite3 | |||
| import sys | |||
| from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed | |||
| from pathlib import Path | |||
| from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| from typing import ( | |||
| Any, | |||
| Awaitable, | |||
| @@ -14,29 +9,14 @@ from typing import ( | |||
| Dict, | |||
| Generator, | |||
| List, | |||
| Literal, | |||
| Optional, | |||
| Tuple, | |||
| Union, | |||
| ) | |||
| import httpx | |||
| import openai | |||
| from fastapi import FastAPI | |||
| from langchain.tools import BaseTool | |||
| from langchain_core.embeddings import Embeddings | |||
| from langchain_openai.chat_models import ChatOpenAI | |||
| from langchain_openai.llms import OpenAI | |||
| # from chatchat.configs import ( | |||
| # DEFAULT_EMBEDDING_MODEL, | |||
| # DEFAULT_LLM_MODEL, | |||
| # HTTPX_DEFAULT_TIMEOUT, | |||
| # MODEL_PLATFORMS, | |||
| # TEMPERATURE, | |||
| # log_verbose, | |||
| # ) | |||
| from .pydantic_v2 import BaseModel, Field | |||
| from ..configs import DEFAULT_EMBEDDING_MODEL | |||
| from langchain_core.embeddings import Embeddings | |||
| logger = logging.getLogger() | |||
| @@ -209,7 +189,60 @@ class ListResponse(BaseResponse): | |||
| } | |||
| } | |||
def get_mindpilot_db_connection():
    """Open mindpilot.db (relative to the CWD) with dict-like Row access.

    The caller owns the connection and must close it.
    """
    connection = sqlite3.connect('mindpilot.db')
    connection.row_factory = sqlite3.Row
    return connection
def run_in_thread_pool(
    func: Callable,
    params: List[Dict] = None,
) -> Generator:
    """
    Run `func` once per kwargs-dict in `params` on a thread pool and yield
    results as they complete (completion order, not input order).

    All operations inside the tasks must be thread-safe; `func` is called
    with keyword arguments only. A task that raises is logged and its
    result is skipped; remaining results are still yielded.
    """
    # BUG fix: the original used the mutable default `params=[]`
    # (shared-across-calls anti-pattern).
    params = params or []
    tasks = []
    with ThreadPoolExecutor() as pool:
        for kwargs in params:
            tasks.append(pool.submit(func, **kwargs))
        for obj in as_completed(tasks):
            try:
                yield obj.result()
            except Exception as e:
                logger.error(f"error in sub thread: {e}", exc_info=True)
def get_Embeddings(
    embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> Embeddings:
    """Build an Embeddings client for `embed_model`.

    Returns None when construction fails (callers must handle that case —
    see check_embed_model).
    NOTE(review): the endpoint is hard-coded to http://127.0.0.1:7890/v1,
    presumably a local OpenAI-compatible server — consider making it
    configurable.
    """
    # Imported lazily to avoid a circular import at module load time —
    # TODO confirm that is the reason for the local import.
    from ..knowledge_base.embedding.localai_embeddings import (
        LocalAIEmbeddings,
    )

    params = dict(model=embed_model)
    try:
        params.update(
            openai_api_base=f"http://127.0.0.1:7890/v1",
            openai_api_key="EMPTY",
        )
        return LocalAIEmbeddings(**params)
    except Exception as e:
        logger.error(
            f"failed to create Embeddings for model: {embed_model}.", exc_info=True
        )
def check_embed_model(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> bool:
    """Probe `embed_model` with a tiny query; True iff it responds."""
    embeddings = get_Embeddings(embed_model=embed_model)
    try:
        embeddings.embed_query("this is a test")
    except Exception as e:
        logger.error(
            f"failed to access embed model '{embed_model}': {e}", exc_info=True
        )
        return False
    return True
| @@ -120,6 +120,10 @@ def main(): | |||
| print亮蓝(f"当前工作目录:{cwd}") | |||
| print亮蓝(f"OpenAPI 文档地址:http://{HOST}:{PORT}/docs") | |||
| from app.knowledge_base.migrate import create_tables | |||
| create_tables() | |||
| if sys.version_info < (3, 10): | |||
| loop = asyncio.get_event_loop() | |||
| else: | |||