
feat: knowledge base CRUD (create, read, update, delete)

main
gjl committed 1 year ago
parent commit 41fde6390b
52 changed files with 4699 additions and 26 deletions
  1. requirements.txt (+9, -1)
  2. src/mindpilot/app/api/api_server.py (+2, -0)
  3. src/mindpilot/app/api/kb_routes.py (+92, -0)
  4. src/mindpilot/app/configs/__init__.py (+25, -2)
  5. src/mindpilot/app/configs/kb_config.py (+113, -0)
  6. src/mindpilot/app/knowledge_base/__init__.py (+0, -0)
  7. src/mindpilot/app/knowledge_base/db/__init__.py (+0, -0)
  8. src/mindpilot/app/knowledge_base/db/base.py (+17, -0)
  9. src/mindpilot/app/knowledge_base/db/models/__init__.py (+0, -0)
  10. src/mindpilot/app/knowledge_base/db/models/base.py (+17, -0)
  11. src/mindpilot/app/knowledge_base/db/models/knowledge_base_model.py (+39, -0)
  12. src/mindpilot/app/knowledge_base/db/models/knowledge_file_model.py (+42, -0)
  13. src/mindpilot/app/knowledge_base/db/models/knowledge_metadata_model.py (+31, -0)
  14. src/mindpilot/app/knowledge_base/db/repository/__init__.py (+3, -0)
  15. src/mindpilot/app/knowledge_base/db/repository/knowledge_base_repository.py (+93, -0)
  16. src/mindpilot/app/knowledge_base/db/repository/knowledge_file_repository.py (+245, -0)
  17. src/mindpilot/app/knowledge_base/db/repository/knowledge_metadata_repository.py (+77, -0)
  18. src/mindpilot/app/knowledge_base/db/session.py (+48, -0)
  19. src/mindpilot/app/knowledge_base/embedding/localai_embeddings.py (+380, -0)
  20. src/mindpilot/app/knowledge_base/file_rag/__init__.py (+0, -0)
  21. src/mindpilot/app/knowledge_base/file_rag/document_loaders/FilteredCSVloader.py (+87, -0)
  22. src/mindpilot/app/knowledge_base/file_rag/document_loaders/__init__.py (+4, -0)
  23. src/mindpilot/app/knowledge_base/file_rag/document_loaders/mydocloader.py (+79, -0)
  24. src/mindpilot/app/knowledge_base/file_rag/document_loaders/myimgloader.py (+28, -0)
  25. src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypdfloader.py (+102, -0)
  26. src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypptloader.py (+66, -0)
  27. src/mindpilot/app/knowledge_base/file_rag/document_loaders/ocr.py (+21, -0)
  28. src/mindpilot/app/knowledge_base/file_rag/retrievers/__init__.py (+3, -0)
  29. src/mindpilot/app/knowledge_base/file_rag/retrievers/base.py (+24, -0)
  30. src/mindpilot/app/knowledge_base/file_rag/retrievers/ensemble.py (+46, -0)
  31. src/mindpilot/app/knowledge_base/file_rag/retrievers/vectorstore.py (+30, -0)
  32. src/mindpilot/app/knowledge_base/file_rag/text_splitter/__init__.py (+4, -0)
  33. src/mindpilot/app/knowledge_base/file_rag/text_splitter/ali_text_splitter.py (+35, -0)
  34. src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_recursive_text_splitter.py (+106, -0)
  35. src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_text_splitter.py (+77, -0)
  36. src/mindpilot/app/knowledge_base/file_rag/text_splitter/zh_title_enhance.py (+100, -0)
  37. src/mindpilot/app/knowledge_base/file_rag/utils.py (+14, -0)
  38. src/mindpilot/app/knowledge_base/kb_api.py (+72, -0)
  39. src/mindpilot/app/knowledge_base/kb_cache/base.py (+94, -0)
  40. src/mindpilot/app/knowledge_base/kb_cache/faiss_cache.py (+211, -0)
  41. src/mindpilot/app/knowledge_base/kb_doc_api.py (+451, -0)
  42. src/mindpilot/app/knowledge_base/kb_service/__init__.py (+0, -0)
  43. src/mindpilot/app/knowledge_base/kb_service/base.py (+500, -0)
  44. src/mindpilot/app/knowledge_base/kb_service/default_kb_service.py (+38, -0)
  45. src/mindpilot/app/knowledge_base/kb_service/es_kb_service.py (+226, -0)
  46. src/mindpilot/app/knowledge_base/kb_service/faiss_kb_service.py (+136, -0)
  47. src/mindpilot/app/knowledge_base/kb_service/milvus_kb_service.py (+125, -0)
  48. src/mindpilot/app/knowledge_base/migrate.py (+234, -0)
  49. src/mindpilot/app/knowledge_base/model/kb_document_model.py (+10, -0)
  50. src/mindpilot/app/knowledge_base/utils.py (+483, -0)
  51. src/mindpilot/app/utils/system_utils.py (+56, -23)
  52. src/mindpilot/main.py (+4, -0)

requirements.txt (+9, -1)

@@ -21,4 +21,12 @@ metaphor_python==0.1.23
langchainhub==0.1.20
langchain-experimental==0.0.62
arxiv==2.1.3
spark_ai_python~=0.4.4
colorama~=0.4.6
tqdm~=4.66.4
numpy~=1.26.4
chardet~=5.2.0
elasticsearch~=8.15.0
tenacity~=8.5.0
pymilvus~=2.4.5
python-dateutil~=2.9.0post0

src/mindpilot/app/api/api_server.py (+2, -0)

@@ -7,6 +7,7 @@ from .tool_routes import tool_router
from .agent_routes import agent_router
from .config_routes import config_router
from .conversation_routes import conversation_router
from .kb_routes import kb_router


def create_app(run_mode: str = None):
@@ -28,5 +29,6 @@ def create_app(run_mode: str = None):
app.include_router(agent_router)
app.include_router(config_router)
app.include_router(conversation_router)
app.include_router(kb_router)

return app

src/mindpilot/app/api/kb_routes.py (+92, -0)

@@ -0,0 +1,92 @@
from __future__ import annotations
from typing import List
from fastapi import APIRouter, Request

# from chatchat.server.chat.file_chat import upload_temp_docs
from ..knowledge_base.kb_api import create_kb, delete_kb, list_kbs
from ..knowledge_base.kb_doc_api import (
delete_docs,
download_doc,
list_files,
recreate_vector_store,
search_docs,
update_docs,
update_info,
upload_docs,
)
# from chatchat.server.knowledge_base.kb_summary_api import (
# recreate_summary_vector_store,
# summary_doc_ids_to_vector_store,
# summary_file_to_vector_store,
# )
from ..utils.system_utils import BaseResponse, ListResponse

kb_router = APIRouter(prefix="/knowledge_base", tags=["知识库管理"])


kb_router.get(
"/list_knowledge_bases", response_model=ListResponse, summary="获取知识库列表"
)(list_kbs)

kb_router.post(
"/create_knowledge_base", response_model=BaseResponse, summary="创建知识库"
)(create_kb)

kb_router.post(
"/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库"
)(delete_kb)

kb_router.get(
"/list_files", response_model=ListResponse, summary="获取知识库内的文件列表"
)(list_files)

kb_router.post(
"/search_docs", response_model=List[dict], summary="搜索知识库"
)(search_docs)

kb_router.post(
"/upload_docs",
response_model=BaseResponse,
summary="上传文件到知识库,并/或进行向量化",
)(upload_docs)

kb_router.post(
"/delete_docs", response_model=BaseResponse, summary="删除知识库内指定文件"
)(delete_docs)

kb_router.post(
"/update_info", response_model=BaseResponse, summary="更新知识库介绍"
)(update_info)

kb_router.post(
"/update_docs", response_model=BaseResponse, summary="更新现有文件到知识库"
)(update_docs)

kb_router.get(
"/download_doc", summary="下载对应的知识文件"
)(download_doc)

kb_router.post(
"/recreate_vector_store", summary="根据content中文档重建向量库,流式输出处理进度。"
)(recreate_vector_store)

# kb_router.post(
# "/upload_temp_docs", summary="上传文件到临时目录,用于文件对话。")(upload_temp_docs)
#
#
# summary_router = APIRouter(prefix="/kb_summary_api")
# summary_router.post(
# "/summary_file_to_vector_store", summary="单个知识库根据文件名称摘要"
# )(summary_file_to_vector_store)
#
# summary_router.post(
# "/summary_doc_ids_to_vector_store",
# summary="单个知识库根据doc_ids摘要",
# response_model=BaseResponse,
# )(summary_doc_ids_to_vector_store)
#
# summary_router.post(
# "/recreate_summary_vector_store", summary="重建单个知识库文件摘要"
# )(recreate_summary_vector_store)
#
# kb_router.include_router(summary_router)
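
A quick usage sketch for the routes above (not part of the commit; the host, port, and JSON field names follow the chatchat-style API this module is ported from and are assumptions here):

import requests

base = "http://127.0.0.1:7861/knowledge_base"  # host/port assumed

resp = requests.post(f"{base}/create_knowledge_base", json={
    "knowledge_base_name": "samples",          # field names assumed
    "vector_store_type": "faiss",
    "embed_model": "bce-embedding-base_v1",
})
print(resp.json())
print(requests.get(f"{base}/list_knowledge_bases").json())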

src/mindpilot/app/configs/__init__.py (+25, -2)

@@ -2,6 +2,7 @@ from .system_config import HOST, PORT
from .model_config import MODEL_CONFIG
from .prompt_config import OPENAI_PROMPT, PROMPT_TEMPLATES
from .tool_config import TOOL_CONFIG
from .kb_config import *

__all__ = [
"HOST",
@@ -9,5 +10,27 @@ __all__ = [
"MODEL_CONFIG",
"OPENAI_PROMPT",
"TOOL_CONFIG",
"PROMPT_TEMPLATES"
]
"PROMPT_TEMPLATES",
"DEFAULT_KNOWLEDGE_BASE",
"DEFAULT_VS_TYPE",
"CACHED_VS_NUM",
"CACHED_MEMO_VS_NUM",
"CHUNK_SIZE",
"OVERLAP_SIZE",
"VECTOR_SEARCH_TOP_K",
"SCORE_THRESHOLD",
"DEFAULT_SEARCH_ENGINE",
"SEARCH_ENGINE_TOP_K",
"ZH_TITLE_ENHANCE",
# "PDF_OCR_THRESHOLD",
"KB_INFO",
"CHATCHAT_ROOT",
"KB_ROOT_PATH",
"DB_ROOT_PATH",
"SQLALCHEMY_DATABASE_URI",
"kbs_config",
"text_splitter_dict",
"TEXT_SPLITTER_NAME",
"EMBEDDING_KEYWORD_FILE",
"DEFAULT_EMBEDDING_MODEL",
]

src/mindpilot/app/configs/kb_config.py (+113, -0)

@@ -0,0 +1,113 @@
import os
from pathlib import Path

DEFAULT_KNOWLEDGE_BASE = "samples"

# Default vector store / full-text search engine type.
DEFAULT_VS_TYPE = "faiss"

# Number of cached vector stores (FAISS only)
CACHED_VS_NUM = 1

# Number of cached temporary vector stores (FAISS only), used for file chat
CACHED_MEMO_VS_NUM = 10

# Maximum length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250

# Overlap length between adjacent chunks (does not apply to MarkdownHeaderTextSplitter)
OVERLAP_SIZE = 50

# Number of matching vectors to retrieve from the knowledge base
VECTOR_SEARCH_TOP_K = 3

# Relevance threshold for knowledge base matching, in the range 0-1.
# A lower SCORE means higher relevance; 1 effectively disables filtering.
# A value around 0.5 is recommended.
SCORE_THRESHOLD = 1

# Default search engine. Options: bing, duckduckgo, metaphor
DEFAULT_SEARCH_ENGINE = "bing"

# Number of search engine results to match
SEARCH_ENGINE_TOP_K = 3

# Whether to enable Chinese title enhancement, and its related configuration.
# Adds a title-detection step that marks probable titles in the metadata,
# then concatenates each chunk with its parent title to enrich the text.
ZH_TITLE_ENHANCE = False

# PDF OCR control: only OCR images whose size exceeds the given fraction of
# the page (image width / page width, image height / page height).
# This filters out small decorative images and speeds up non-scanned PDFs.
PDF_OCR_THRESHOLD = (0.6, 0.6)

# Initial description of each knowledge base, shown at initialization and used
# for Agent calls. A knowledge base without a description is not called by the Agent.
KB_INFO = {
    "samples": "关于本项目issue的解答",
}

CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent.parent)

KB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "knowledge_base")

# Default database storage path.
DB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "mindpilot.db")
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"

# Supported vector store types and their configuration
kbs_config = {
    "faiss": {},
    "milvus": {
        "host": "127.0.0.1",
        "port": "19530",
        "user": "",
        "password": "",
        "secure": False,
    },
    "es": {
        "host": "127.0.0.1",
        "port": "9200",
        "index_name": "test_index",
        "user": "",
        "password": "",
    },
    "milvus_kwargs": {
        "search_params": {"metric_type": "L2"},  # add extra search_params here
        "index_params": {
            "metric_type": "L2",
            "index_type": "HNSW",
            "params": {"M": 8, "efConstruction": 64},
        },  # add extra index_params here
    },
}

# TextSplitter configuration. If you are unsure what these options mean, do not change them.
text_splitter_dict = {
    "ChineseRecursiveTextSplitter": {
        "source": "",  # "tiktoken" uses OpenAI's tokenizer; "huggingface" uses a HF tokenizer
        "tokenizer_name_or_path": "",
    },
    "SpacyTextSplitter": {
        "source": "huggingface",
        "tokenizer_name_or_path": "gpt2",
    },
    "RecursiveCharacterTextSplitter": {
        "source": "tiktoken",
        "tokenizer_name_or_path": "cl100k_base",
    },
    "MarkdownHeaderTextSplitter": {
        "headers_to_split_on": [
            ("#", "head1"),
            ("##", "head2"),
            ("###", "head3"),
            ("####", "head4"),
        ]
    },
}

# Name of the TEXT_SPLITTER to use
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"

# Vocabulary file of custom keywords for the embedding model
EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"

DEFAULT_EMBEDDING_MODEL = "bce-embedding-base_v1"
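A minimal sketch (not part of the commit) of how the settings above could drive a splitter, assuming the tiktoken package is installed:

from langchain.text_splitter import RecursiveCharacterTextSplitter

params = text_splitter_dict["RecursiveCharacterTextSplitter"]
splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name=params["tokenizer_name_or_path"],  # "cl100k_base"
    chunk_size=CHUNK_SIZE,
    chunk_overlap=OVERLAP_SIZE,
)
chunks = splitter.split_text("some long document text ...")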

src/mindpilot/app/knowledge_base/__init__.py (+0, -0)


src/mindpilot/app/knowledge_base/db/__init__.py (+0, -0)


src/mindpilot/app/knowledge_base/db/base.py (+17, -0)

@@ -0,0 +1,17 @@
import json

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from sqlalchemy.orm import sessionmaker

from ...configs.kb_config import SQLALCHEMY_DATABASE_URI


engine = create_engine(
SQLALCHEMY_DATABASE_URI,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base: DeclarativeMeta = declarative_base()
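The engine and declarative Base above are shared by every model in the package. A sketch (not in the commit) of initializing the SQLite schema; the model modules must be imported first so their tables register on Base.metadata:

from mindpilot.app.knowledge_base.db.base import Base, engine  # import path assumed
from mindpilot.app.knowledge_base.db.models import (           # import path assumed
    knowledge_base_model,   # noqa: F401 -- registers knowledge_base
    knowledge_file_model,   # noqa: F401 -- registers knowledge_file, file_doc
)

Base.metadata.create_all(bind=engine)  # CREATE TABLE for any missing tables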

src/mindpilot/app/knowledge_base/db/models/__init__.py (+0, -0)


src/mindpilot/app/knowledge_base/db/models/base.py (+17, -0)

@@ -0,0 +1,17 @@
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, String


class BaseModel:
"""
基础模型
"""

id = Column(Integer, primary_key=True, index=True, comment="主键ID")
create_time = Column(DateTime, default=datetime.utcnow, comment="创建时间")
update_time = Column(
DateTime, default=None, onupdate=datetime.utcnow, comment="更新时间"
)
create_by = Column(String, default=None, comment="创建者")
update_by = Column(String, default=None, comment="更新者")

src/mindpilot/app/knowledge_base/db/models/knowledge_base_model.py (+39, -0)

@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Optional

from pydantic import BaseModel
from sqlalchemy import Column, DateTime, Integer, String, func

from ..base import Base


class KnowledgeBaseModel(Base):
"""
知识库模型
"""

__tablename__ = "knowledge_base"
id = Column(Integer, primary_key=True, autoincrement=True, comment="知识库ID")
kb_name = Column(String(50), comment="知识库名称")
kb_info = Column(String(200), comment="知识库简介(用于Agent)")
vs_type = Column(String(50), comment="向量库类型")
embed_model = Column(String(50), comment="嵌入模型名称")
file_count = Column(Integer, default=0, comment="文件数量")
create_time = Column(DateTime, default=func.now(), comment="创建时间")

def __repr__(self):
return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}',kb_intro='{self.kb_info} vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"


# A matching Pydantic model for validating/serializing ORM rows
class KnowledgeBaseSchema(BaseModel):
id: int
kb_name: str
kb_info: Optional[str]
vs_type: Optional[str]
embed_model: Optional[str]
file_count: Optional[int]
create_time: Optional[datetime]

class Config:
from_attributes = True  # allow validation directly from ORM instances
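
A small usage sketch (not in the commit): with from_attributes enabled, Pydantic v2 validates straight from the ORM instance:

from mindpilot.app.knowledge_base.db.base import SessionLocal  # import path assumed

with SessionLocal() as session:
    row = session.query(KnowledgeBaseModel).first()
    if row is not None:
        schema = KnowledgeBaseSchema.model_validate(row)
        print(schema.model_dump())  # plain dict, safe to serialize in an API response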

src/mindpilot/app/knowledge_base/db/models/knowledge_file_model.py (+42, -0)

@@ -0,0 +1,42 @@
from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func

from ..base import Base


class KnowledgeFileModel(Base):
"""
知识文件模型
"""

__tablename__ = "knowledge_file"
id = Column(Integer, primary_key=True, autoincrement=True, comment="知识文件ID")
file_name = Column(String(255), comment="文件名")
file_ext = Column(String(10), comment="文件扩展名")
kb_name = Column(String(50), comment="所属知识库名称")
document_loader_name = Column(String(50), comment="文档加载器名称")
text_splitter_name = Column(String(50), comment="文本分割器名称")
file_version = Column(Integer, default=1, comment="文件版本")
file_mtime = Column(Float, default=0.0, comment="文件修改时间")
file_size = Column(Integer, default=0, comment="文件大小")
custom_docs = Column(Boolean, default=False, comment="是否自定义docs")
docs_count = Column(Integer, default=0, comment="切分文档数量")
create_time = Column(DateTime, default=func.now(), comment="创建时间")

def __repr__(self):
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"


class FileDocModel(Base):
"""
文件-向量库文档模型
"""

__tablename__ = "file_doc"
id = Column(Integer, primary_key=True, autoincrement=True, comment="ID")
kb_name = Column(String(50), comment="知识库名称")
file_name = Column(String(255), comment="文件名称")
doc_id = Column(String(50), comment="向量库文档ID")
meta_data = Column(JSON, default={})

def __repr__(self):
return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"

src/mindpilot/app/knowledge_base/db/models/knowledge_metadata_model.py (+31, -0)

@@ -0,0 +1,31 @@
from sqlalchemy import JSON, Boolean, Column, DateTime, Float, Integer, String, func

from ..base import Base


class SummaryChunkModel(Base):
"""
chunk summary模型,用于存储file_doc中每个doc_id的chunk 片段,
数据来源:
用户输入: 用户上传文件,可填写文件的描述,生成的file_doc中的doc_id,存入summary_chunk中
程序自动切分 对file_doc表meta_data字段信息中存储的页码信息,按每页的页码切分,自定义prompt生成总结文本,将对应页码关联的doc_id存入summary_chunk中
后续任务:
矢量库构建: 对数据库表summary_chunk中summary_context创建索引,构建矢量库,meta_data为矢量库的元数据(doc_ids)
语义关联: 通过用户输入的描述,自动切分的总结文本,计算
语义相似度

"""

__tablename__ = "summary_chunk"
id = Column(Integer, primary_key=True, autoincrement=True, comment="ID")
kb_name = Column(String(50), comment="知识库名称")
summary_context = Column(String(255), comment="总结文本")
summary_id = Column(String(255), comment="总结矢量id")
doc_ids = Column(String(1024), comment="向量库id关联列表")
meta_data = Column(JSON, default={})

def __repr__(self):
return (
f"<SummaryChunk(id='{self.id}', kb_name='{self.kb_name}', summary_context='{self.summary_context}',"
f" doc_ids='{self.doc_ids}', metadata='{self.metadata}')>"
)

src/mindpilot/app/knowledge_base/db/repository/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
from .knowledge_base_repository import *
from .knowledge_file_repository import *


src/mindpilot/app/knowledge_base/db/repository/knowledge_base_repository.py (+93, -0)

@@ -0,0 +1,93 @@
from ...db.models.knowledge_base_model import (
KnowledgeBaseModel,
KnowledgeBaseSchema,
)
from ...db.session import with_session


@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
# Fetch the knowledge base record, creating it if absent
kb = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
.first()
)
if not kb:
kb = KnowledgeBaseModel(
kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model
)
session.add(kb)
else: # update kb with new vs_type and embed_model
kb.kb_info = kb_info
kb.vs_type = vs_type
kb.embed_model = embed_model
return True


@with_session
def list_kbs_from_db(session, min_file_count: int = -1):
kbs = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.file_count > min_file_count)
.all()
)
kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs]
return kbs


@with_session
def kb_exists(session, kb_name):
kb = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
.first()
)
status = True if kb else False
return status


@with_session
def load_kb_from_db(session, kb_name):
kb = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
.first()
)
if kb:
kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model
else:
kb_name, vs_type, embed_model = None, None, None
return kb_name, vs_type, embed_model


@with_session
def delete_kb_from_db(session, kb_name):
kb = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
.first()
)
if kb:
session.delete(kb)
return True


@with_session
def get_kb_detail(session, kb_name: str) -> dict:
kb: KnowledgeBaseModel = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(kb_name))
.first()
)
if kb:
return {
"kb_name": kb.kb_name,
"kb_info": kb.kb_info,
"vs_type": kb.vs_type,
"embed_model": kb.embed_model,
"file_count": kb.file_count,
"create_time": kb.create_time,
}
else:
return {}
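
Because @with_session injects the session as the first argument, callers invoke these helpers without passing one. A brief sketch (not part of the commit):

from mindpilot.app.knowledge_base.db.repository.knowledge_base_repository import (
    add_kb_to_db,      # import path assumed
    get_kb_detail,
)

add_kb_to_db(kb_name="samples", kb_info="demo KB",
             vs_type="faiss", embed_model="bce-embedding-base_v1")
print(get_kb_detail("samples"))  # {'kb_name': 'samples', 'vs_type': 'faiss', ...}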

src/mindpilot/app/knowledge_base/db/repository/knowledge_file_repository.py (+245, -0)

@@ -0,0 +1,245 @@
from typing import Dict, List

from ...db.models.knowledge_base_model import KnowledgeBaseModel
from ...db.models.knowledge_file_model import (
FileDocModel,
KnowledgeFileModel,
)
from ...db.session import with_session
from ...utils import KnowledgeFile


@with_session
def list_file_num_docs_id_by_kb_name_and_file_name(
session,
kb_name: str,
file_name: str,
) -> List[int]:
"""
列出某知识库某文件对应的所有Document的id。
返回形式:[str, ...]
"""
doc_ids = (
session.query(FileDocModel.doc_id)
.filter_by(kb_name=kb_name, file_name=file_name)
.all()
)
return [int(_id[0]) for _id in doc_ids]


@with_session
def list_docs_from_db(
session,
kb_name: str,
file_name: str = None,
metadata: Dict = {},
) -> List[Dict]:
"""
列出某知识库某文件对应的所有Document。
返回形式:[{"id": str, "metadata": dict}, ...]
"""
docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
if file_name:
docs = docs.filter(FileDocModel.file_name.ilike(file_name))
for k, v in metadata.items():
docs = docs.filter(FileDocModel.meta_data[k].as_string() == str(v))

return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()]


@with_session
def delete_docs_from_db(
session,
kb_name: str,
file_name: str = None,
) -> List[Dict]:
"""
删除某知识库某文件对应的所有Document,并返回被删除的Document。
返回形式:[{"id": str, "metadata": dict}, ...]
"""
docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
if file_name:
query = query.filter(FileDocModel.file_name.ilike(file_name))
query.delete(synchronize_session=False)
session.commit()
return docs


@with_session
def add_docs_to_db(session, kb_name: str, file_name: str, doc_infos: List[Dict]):
"""
将某知识库某文件对应的所有Document信息添加到数据库。
doc_infos形式:[{"id": str, "metadata": dict}, ...]
"""
# ! 这里会出现doc_infos为None的情况,需要进一步排查
if doc_infos is None:
print(
"输入的server.db.repository.knowledge_file_repository.add_docs_to_db的doc_infos参数为None"
)
return False
for d in doc_infos:
obj = FileDocModel(
kb_name=kb_name,
file_name=file_name,
doc_id=d["id"],
meta_data=d["metadata"],
)
session.add(obj)
return True


@with_session
def count_files_from_db(session, kb_name: str) -> int:
return (
session.query(KnowledgeFileModel)
.filter(KnowledgeFileModel.kb_name.ilike(kb_name))
.count()
)


@with_session
def list_files_from_db(session, kb_name):
files = (
session.query(KnowledgeFileModel)
.filter(KnowledgeFileModel.kb_name.ilike(kb_name))
.all()
)
docs = [f.file_name for f in files]
return docs


@with_session
def add_file_to_db(
session,
kb_file: KnowledgeFile,
docs_count: int = 0,
custom_docs: bool = False,
doc_infos: List[Dict] = [],  # format: [{"id": str, "metadata": dict}, ...]
):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb:
# If the file already exists, update its info and bump the version number
existing_file: KnowledgeFileModel = (
session.query(KnowledgeFileModel)
.filter(
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
KnowledgeFileModel.file_name.ilike(kb_file.filename),
)
.first()
)
mtime = kb_file.get_mtime()
size = kb_file.get_size()

if existing_file:
existing_file.file_mtime = mtime
existing_file.file_size = size
existing_file.docs_count = docs_count
existing_file.custom_docs = custom_docs
existing_file.file_version += 1
# Otherwise, add it as a new file
else:
new_file = KnowledgeFileModel(
file_name=kb_file.filename,
file_ext=kb_file.ext,
kb_name=kb_file.kb_name,
document_loader_name=kb_file.document_loader_name,
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter",
file_mtime=mtime,
file_size=size,
docs_count=docs_count,
custom_docs=custom_docs,
)
kb.file_count += 1
session.add(new_file)
add_docs_to_db(
kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos
)
return True


@with_session
def delete_file_from_db(session, kb_file: KnowledgeFile):
existing_file = (
session.query(KnowledgeFileModel)
.filter(
KnowledgeFileModel.file_name.ilike(kb_file.filename),
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
)
.first()
)
if existing_file:
session.delete(existing_file)
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
session.commit()

kb = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name))
.first()
)
if kb:
kb.file_count -= 1
session.commit()
return True


@with_session
def delete_files_from_db(session, knowledge_base_name: str):
session.query(KnowledgeFileModel).filter(
KnowledgeFileModel.kb_name.ilike(knowledge_base_name)
).delete(synchronize_session=False)
session.query(FileDocModel).filter(
FileDocModel.kb_name.ilike(knowledge_base_name)
).delete(synchronize_session=False)
kb = (
session.query(KnowledgeBaseModel)
.filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name))
.first()
)
if kb:
kb.file_count = 0

session.commit()
return True


@with_session
def file_exists_in_db(session, kb_file: KnowledgeFile):
existing_file = (
session.query(KnowledgeFileModel)
.filter(
KnowledgeFileModel.file_name.ilike(kb_file.filename),
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
)
.first()
)
return True if existing_file else False


@with_session
def get_file_detail(session, kb_name: str, filename: str) -> dict:
file: KnowledgeFileModel = (
session.query(KnowledgeFileModel)
.filter(
KnowledgeFileModel.file_name.ilike(filename),
KnowledgeFileModel.kb_name.ilike(kb_name),
)
.first()
)
if file:
return {
"kb_name": file.kb_name,
"file_name": file.file_name,
"file_ext": file.file_ext,
"file_version": file.file_version,
"document_loader": file.document_loader_name,
"text_splitter": file.text_splitter_name,
"create_time": file.create_time,
"file_mtime": file.file_mtime,
"file_size": file.file_size,
"custom_docs": file.custom_docs,
"docs_count": file.docs_count,
}
else:
return {}

src/mindpilot/app/knowledge_base/db/repository/knowledge_metadata_repository.py (+77, -0)

@@ -0,0 +1,77 @@
from typing import Dict, List

from ...db.models.knowledge_metadata_model import SummaryChunkModel
from ...db.session import with_session


@with_session
def list_summary_from_db(
session,
kb_name: str,
metadata: Dict = {},
) -> List[Dict]:
"""
列出某知识库chunk summary。
返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
"""
docs = session.query(SummaryChunkModel).filter(
SummaryChunkModel.kb_name.ilike(kb_name)
)

for k, v in metadata.items():
docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))

return [
{
"id": x.id,
"summary_context": x.summary_context,
"summary_id": x.summary_id,
"doc_ids": x.doc_ids,
"metadata": x.metadata,
}
for x in docs.all()
]


@with_session
def delete_summary_from_db(session, kb_name: str) -> List[Dict]:
"""
删除知识库chunk summary,并返回被删除的Dchunk summary。
返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
"""
docs = list_summary_from_db(kb_name=kb_name)
query = session.query(SummaryChunkModel).filter(
SummaryChunkModel.kb_name.ilike(kb_name)
)
query.delete(synchronize_session=False)
session.commit()
return docs


@with_session
def add_summary_to_db(session, kb_name: str, summary_infos: List[Dict]):
"""
将总结信息添加到数据库。
summary_infos形式:[{"summary_context": str, "doc_ids": str}, ...]
"""
for summary in summary_infos:
obj = SummaryChunkModel(
kb_name=kb_name,
summary_context=summary["summary_context"],
summary_id=summary["summary_id"],
doc_ids=summary["doc_ids"],
meta_data=summary["metadata"],
)
session.add(obj)

session.commit()
return True


@with_session
def count_summary_from_db(session, kb_name: str) -> int:
return (
session.query(SummaryChunkModel)
.filter(SummaryChunkModel.kb_name.ilike(kb_name))
.count()
)

src/mindpilot/app/knowledge_base/db/session.py (+48, -0)

@@ -0,0 +1,48 @@
from contextlib import contextmanager
from functools import wraps

from sqlalchemy.orm import Session

from .base import SessionLocal


@contextmanager
def session_scope() -> Session:
"""上下文管理器用于自动获取 Session, 避免错误"""
session = SessionLocal()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()


def with_session(f):
@wraps(f)
def wrapper(*args, **kwargs):
with session_scope() as session:
try:
result = f(session, *args, **kwargs)
session.commit()
return result
except:
session.rollback()
raise

return wrapper


def get_db() -> SessionLocal:
db = SessionLocal()
try:
yield db
finally:
db.close()


def get_db0() -> SessionLocal:
db = SessionLocal()
return db
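
get_db follows FastAPI's dependency-injection pattern: the generator yields a session and closes it after the response. A sketch (the route is illustrative, not from the commit):

from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.orm import Session

router = APIRouter()

@router.get("/db_ping")  # illustrative route
def db_ping(db: Session = Depends(get_db)):
    db.execute(text("SELECT 1"))  # session is closed by get_db afterwards
    return {"ok": True}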

src/mindpilot/app/knowledge_base/embedding/localai_embeddings.py (+380, -0)

@@ -0,0 +1,380 @@
from __future__ import annotations

import logging
import warnings
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Union,
)

from langchain_community.utils.openai import is_openai_v1
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from tenacity import (
AsyncRetrying,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from ...utils.system_utils import run_in_thread_pool

logger = logging.getLogger(__name__)


def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], Any]:
import openai

min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.Timeout)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)


def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
import openai

min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
async_retrying = AsyncRetrying(
reraise=True,
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.Timeout)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)

def wrap(func: Callable) -> Callable:
async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
async for _ in async_retrying:
return await func(*args, **kwargs)
raise AssertionError("this is unreachable")

return wrapped_f

return wrap


# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
def _check_response(response: dict) -> dict:
if any([len(d.embedding) == 1 for d in response.data]):
import openai

raise openai.APIError("LocalAI API returned an empty embedding")
return response


def embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)

@retry_decorator
def _embed_with_retry(**kwargs: Any) -> Any:
response = embeddings.client.create(**kwargs)
return _check_response(response)

return _embed_with_retry(**kwargs)


async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the embedding call."""

@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.async_client.create(**kwargs)
return _check_response(response)

return await _async_embed_with_retry(**kwargs)


class LocalAIEmbeddings(BaseModel, Embeddings):
"""LocalAI embedding models.

Since LocalAI and OpenAI have 1:1 compatibility between APIs, this class
uses the ``openai`` Python package's ``openai.Embedding`` as its client.
Thus, you should have the ``openai`` python package installed, and set
the environment variable ``OPENAI_API_KEY`` to any placeholder string.
You also need to specify ``OPENAI_API_BASE`` to point to your LocalAI
service endpoint.

Example:
.. code-block:: python

from langchain_community.embeddings import LocalAIEmbeddings
openai = LocalAIEmbeddings(
openai_api_key="random-string",
openai_api_base="http://localhost:8080"
)

"""

client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "text-embedding-ada-002"
deployment: str = model
openai_api_version: Optional[str] = None
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
# to support explicit proxy for LocalAI
openai_proxy: Optional[str] = None
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_organization: Optional[str] = Field(default=None, alias="organization")
allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
chunk_size: int = 1000
"""Maximum number of texts to embed in each batch"""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
"""Timeout in seconds for the LocalAI request."""
headers: Any = None
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)

invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)

values["model_kwargs"] = extra
return values

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_base"] = get_from_dict_or_env(
values,
"openai_api_base",
"OPENAI_API_BASE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)

default_api_version = ""
values["openai_api_version"] = get_from_dict_or_env(
values,
"openai_api_version",
"OPENAI_API_VERSION",
default=default_api_version,
)
values["openai_organization"] = get_from_dict_or_env(
values,
"openai_organization",
"OPENAI_ORGANIZATION",
default="",
)
try:
import openai

if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
}

if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).embeddings
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(
**client_params
).embeddings
elif not values.get("client"):
values["client"] = openai.Embedding
else:
pass
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
return values

@property
def _invocation_params(self) -> Dict:
openai_args = {
"model": self.model,
"timeout": self.request_timeout,
"extra_headers": self.headers,
**self.model_kwargs,
}
if self.openai_proxy:
import openai

openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501
return openai_args

def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
# handle large input text
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
embed_with_retry(
self,
input=[text],
**self._invocation_params,
)
.data[0]
.embedding
)

async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint."""
# handle large input text
if self.model.endswith("001"):
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return (
(
await async_embed_with_retry(
self,
input=[text],
**self._invocation_params,
)
)
.data[0]
.embedding
)

def embed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint for embedding search docs.

Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.

Returns:
List of embeddings, one for each text.
"""

# call _embedding_func for each text with multithreads
def task(seq, text):
return (seq, self._embedding_func(text, engine=self.deployment))

params = [{"seq": i, "text": text} for i, text in enumerate(texts)]
result = list(run_in_thread_pool(func=task, params=params))
result = sorted(result, key=lambda x: x[0])
return [x[1] for x in result]

async def aembed_documents(
self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
"""Call out to LocalAI's embedding endpoint async for embedding search docs.

Args:
texts: The list of texts to embed.
chunk_size: The chunk size of embeddings. If None, will use the chunk size
specified by the class.

Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
response = await self._aembedding_func(text, engine=self.deployment)
embeddings.append(response)
return embeddings

def embed_query(self, text: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint for embedding query text.

Args:
text: The text to embed.

Returns:
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.deployment)
return embedding

async def aembed_query(self, text: str) -> List[float]:
"""Call out to LocalAI's embedding endpoint async for embedding query text.

Args:
text: The text to embed.

Returns:
Embedding for the text.
"""
embedding = await self._aembedding_func(text, engine=self.deployment)
return embedding
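
A usage sketch based on the class docstring above (the endpoint URL is illustrative):

embedder = LocalAIEmbeddings(
    openai_api_key="placeholder",             # LocalAI ignores the key itself
    openai_api_base="http://localhost:8080",  # your LocalAI endpoint
    model="text-embedding-ada-002",
)
vectors = embedder.embed_documents(["hello", "world"])  # threaded batch path
query_vec = embedder.embed_query("hello")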

src/mindpilot/app/knowledge_base/file_rag/__init__.py (+0, -0)


src/mindpilot/app/knowledge_base/file_rag/document_loaders/FilteredCSVloader.py (+87, -0)

@@ -0,0 +1,87 @@
## CSV loader that reads only the specified columns

import csv
from io import TextIOWrapper
from typing import Dict, List, Optional

from langchain.docstore.document import Document
from langchain_community.document_loaders import CSVLoader
from langchain_community.document_loaders.helpers import detect_file_encodings


class FilteredCSVLoader(CSVLoader):
def __init__(
self,
file_path: str,
columns_to_read: List[str],
source_column: Optional[str] = None,
metadata_columns: List[str] = [],
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
):
super().__init__(
file_path=file_path,
source_column=source_column,
metadata_columns=metadata_columns,
csv_args=csv_args,
encoding=encoding,
autodetect_encoding=autodetect_encoding,
)
self.columns_to_read = columns_to_read

def load(self) -> List[Document]:
"""Load data into document objects."""

docs = []
try:
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
docs = self.__read_file(csvfile)
except UnicodeDecodeError as e:
if self.autodetect_encoding:
detected_encodings = detect_file_encodings(self.file_path)
for encoding in detected_encodings:
try:
with open(
self.file_path, newline="", encoding=encoding.encoding
) as csvfile:
docs = self.__read_file(csvfile)
break
except UnicodeDecodeError:
continue
else:
raise RuntimeError(f"Error loading {self.file_path}") from e
except Exception as e:
raise RuntimeError(f"Error loading {self.file_path}") from e

return docs

def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
docs = []
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = []
for col in self.columns_to_read:
if col in row:
content.append(f"{col}:{str(row[col])}")
else:
raise ValueError(
f"Column '{col}' not found in CSV file."  # report the actual missing column
)
content = "\n".join(content)
# Extract the source if available
source = (
row.get(self.source_column, None)
if self.source_column is not None
else self.file_path
)
metadata = {"source": source, "row": i}

for col in self.metadata_columns:
if col in row:
metadata[col] = row[col]

doc = Document(page_content=content, metadata=metadata)
docs.append(doc)

return docs
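
A usage sketch (not in the commit; the file name and column names are illustrative):

loader = FilteredCSVLoader(
    file_path="faq.csv",
    columns_to_read=["question", "answer"],
    autodetect_encoding=True,
)
docs = loader.load()  # one Document per row, "question:...\nanswer:..." as content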

src/mindpilot/app/knowledge_base/file_rag/document_loaders/__init__.py (+4, -0)

@@ -0,0 +1,4 @@
from .mydocloader import RapidOCRDocLoader
from .myimgloader import RapidOCRLoader
from .mypdfloader import RapidOCRPDFLoader
from .mypptloader import RapidOCRPPTLoader

src/mindpilot/app/knowledge_base/file_rag/document_loaders/mydocloader.py (+79, -0)

@@ -0,0 +1,79 @@
from typing import List

import tqdm
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader


class RapidOCRDocLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def doc2text(filepath):
from io import BytesIO

import numpy as np
from docx import Document, ImagePart
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.table import Table, _Cell
from docx.text.paragraph import Paragraph
from PIL import Image
from rapidocr_onnxruntime import RapidOCR

ocr = RapidOCR()
doc = Document(filepath)
resp = ""

def iter_block_items(parent):
from docx.document import Document

if isinstance(parent, Document):
parent_elm = parent.element.body
elif isinstance(parent, _Cell):
parent_elm = parent._tc
else:
raise ValueError("RapidOCRDocLoader parse fail")

for child in parent_elm.iterchildren():
if isinstance(child, CT_P):
yield Paragraph(child, parent)
elif isinstance(child, CT_Tbl):
yield Table(child, parent)

b_unit = tqdm.tqdm(
total=len(doc.paragraphs) + len(doc.tables),
desc="RapidOCRDocLoader block index: 0",
)
for i, block in enumerate(iter_block_items(doc)):
b_unit.set_description("RapidOCRDocLoader block index: {}".format(i))
b_unit.refresh()
if isinstance(block, Paragraph):
resp += block.text.strip() + "\n"
images = block._element.xpath(".//pic:pic")  # all inline pictures
for image in images:
for img_id in image.xpath(".//a:blip/@r:embed"):  # picture relationship id
part = doc.part.related_parts[
img_id
]  # resolve the image part from its relationship id
if isinstance(part, ImagePart):
image = Image.open(BytesIO(part._blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif isinstance(block, Table):
for row in block.rows:
for cell in row.cells:
for paragraph in cell.paragraphs:
resp += paragraph.text.strip() + "\n"
b_unit.update(1)
return resp

text = doc2text(self.file_path)
from unstructured.partition.text import partition_text

return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == "__main__":
loader = RapidOCRDocLoader(file_path="../tests/samples/ocr_test.docx")
docs = loader.load()
print(docs)

src/mindpilot/app/knowledge_base/file_rag/document_loaders/myimgloader.py (+28, -0)

@@ -0,0 +1,28 @@
from typing import List

from langchain_community.document_loaders.unstructured import UnstructuredFileLoader

from .ocr import get_ocr


class RapidOCRLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def img2text(filepath):
resp = ""
ocr = get_ocr()
result, _ = ocr(filepath)
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
return resp

text = img2text(self.file_path)
from unstructured.partition.text import partition_text

return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == "__main__":
loader = RapidOCRLoader(file_path="../tests/samples/ocr_test.jpg")
docs = loader.load()
print(docs)

src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypdfloader.py (+102, -0)

@@ -0,0 +1,102 @@
from typing import List

import cv2
import numpy as np
import tqdm
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from PIL import Image

from ....configs.kb_config import PDF_OCR_THRESHOLD
from .ocr import get_ocr


class RapidOCRPDFLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def rotate_img(img, angle):
"""
img --image
angle --rotation angle
return--rotated img
"""

h, w = img.shape[:2]
rotate_center = (w / 2, h / 2)
# Build the rotation matrix:
#   arg 1: rotation center
#   arg 2: rotation angle (positive = counterclockwise, negative = clockwise)
#   arg 3: isotropic scale factor (1.0 = original size, 2.0 = double, 0.5 = half)
M = cv2.getRotationMatrix2D(rotate_center, angle, 1.0)
# compute the new image bounds
new_w = int(h * np.abs(M[0, 1]) + w * np.abs(M[0, 0]))
new_h = int(h * np.abs(M[0, 0]) + w * np.abs(M[0, 1]))
# adjust the rotation matrix to account for the translation
M[0, 2] += (new_w - w) / 2
M[1, 2] += (new_h - h) / 2

rotated_img = cv2.warpAffine(img, M, (new_w, new_h))
return rotated_img

def pdf2text(filepath):
import fitz  # the fitz module bundled with PyMuPDF; not the unrelated "fitz" package on PyPI
import numpy as np

ocr = get_ocr()
doc = fitz.open(filepath)
resp = ""

b_unit = tqdm.tqdm(
total=doc.page_count, desc="RapidOCRPDFLoader context page index: 0"
)
for i, page in enumerate(doc):
b_unit.set_description(
"RapidOCRPDFLoader context page index: {}".format(i)
)
b_unit.refresh()
text = page.get_text("")
resp += text + "\n"

img_list = page.get_image_info(xrefs=True)
for img in img_list:
if xref := img.get("xref"):
bbox = img["bbox"]
# skip images whose width/height fall below the configured page-fraction threshold
if (bbox[2] - bbox[0]) / (page.rect.width) < PDF_OCR_THRESHOLD[
0
] or (bbox[3] - bbox[1]) / (
page.rect.height
) < PDF_OCR_THRESHOLD[1]:
continue
pix = fitz.Pixmap(doc, xref)
samples = pix.samples
if int(page.rotation) != 0:  # if the page is rotated, rotate the image to match
img_array = np.frombuffer(
pix.samples, dtype=np.uint8
).reshape(pix.height, pix.width, -1)
tmp_img = Image.fromarray(img_array)
ori_img = cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR)
rot_img = rotate_img(img=ori_img, angle=360 - page.rotation)
img_array = cv2.cvtColor(rot_img, cv2.COLOR_RGB2BGR)
else:
img_array = np.frombuffer(
pix.samples, dtype=np.uint8
).reshape(pix.height, pix.width, -1)

result, _ = ocr(img_array)
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)

# update the progress bar
b_unit.update(1)
return resp

text = pdf2text(self.file_path)
from unstructured.partition.text import partition_text

return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == "__main__":
loader = RapidOCRPDFLoader(file_path="/Users/tonysong/Desktop/test.pdf")
docs = loader.load()
print(docs)

src/mindpilot/app/knowledge_base/file_rag/document_loaders/mypptloader.py (+66, -0)

@@ -0,0 +1,66 @@
from typing import List

import tqdm
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader


class RapidOCRPPTLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
def ppt2text(filepath):
from io import BytesIO

import numpy as np
from PIL import Image
from pptx import Presentation
from rapidocr_onnxruntime import RapidOCR

ocr = RapidOCR()
prs = Presentation(filepath)
resp = ""

def extract_text(shape):
nonlocal resp
if shape.has_text_frame:
resp += shape.text.strip() + "\n"
if shape.has_table:
for row in shape.table.rows:
for cell in row.cells:
for paragraph in cell.text_frame.paragraphs:
resp += paragraph.text.strip() + "\n"
if shape.shape_type == 13:  # 13 = picture
image = Image.open(BytesIO(shape.image.blob))
result, _ = ocr(np.array(image))
if result:
ocr_result = [line[1] for line in result]
resp += "\n".join(ocr_result)
elif shape.shape_type == 6:  # 6 = group shape
for child_shape in shape.shapes:
extract_text(child_shape)

b_unit = tqdm.tqdm(
total=len(prs.slides), desc="RapidOCRPPTLoader slide index: 1"
)
# iterate over all slides
for slide_number, slide in enumerate(prs.slides, start=1):
b_unit.set_description(
"RapidOCRPPTLoader slide index: {}".format(slide_number)
)
b_unit.refresh()
sorted_shapes = sorted(
slide.shapes, key=lambda x: (x.top, x.left)
)  # traverse top-to-bottom, left-to-right
for shape in sorted_shapes:
extract_text(shape)
b_unit.update(1)
return resp

text = ppt2text(self.file_path)
from unstructured.partition.text import partition_text

return partition_text(text=text, **self.unstructured_kwargs)


if __name__ == "__main__":
loader = RapidOCRPPTLoader(file_path="../tests/samples/ocr_test.pptx")
docs = loader.load()
print(docs)

src/mindpilot/app/knowledge_base/file_rag/document_loaders/ocr.py (+21, -0)

@@ -0,0 +1,21 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
try:
from rapidocr_paddle import RapidOCR
except ImportError:
from rapidocr_onnxruntime import RapidOCR


def get_ocr(use_cuda: bool = True) -> "RapidOCR":
try:
from rapidocr_paddle import RapidOCR

ocr = RapidOCR(
det_use_cuda=use_cuda, cls_use_cuda=use_cuda, rec_use_cuda=use_cuda
)
except ImportError:
from rapidocr_onnxruntime import RapidOCR

ocr = RapidOCR()
return ocr
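
A usage sketch (not in the commit): get_ocr falls back to the onnxruntime backend when rapidocr_paddle is absent; the image path is illustrative:

ocr = get_ocr(use_cuda=False)
result, _ = ocr("ocr_test.jpg")  # returns (detections, timing info)
if result:
    print("\n".join(line[1] for line in result))  # recognized text lines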

src/mindpilot/app/knowledge_base/file_rag/retrievers/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
from .base import BaseRetrieverService
from .ensemble import EnsembleRetrieverService
from .vectorstore import VectorstoreRetrieverService

src/mindpilot/app/knowledge_base/file_rag/retrievers/base.py (+24, -0)

@@ -0,0 +1,24 @@
from abc import ABCMeta, abstractmethod
from typing import Union

from langchain.vectorstores import VectorStore


class BaseRetrieverService(metaclass=ABCMeta):
def __init__(self, **kwargs):
self.do_init(**kwargs)

@abstractmethod
def do_init(self, **kwargs):
pass

@staticmethod
@abstractmethod
def from_vectorstore(
vectorstore: VectorStore,
top_k: int,
score_threshold: Union[int, float],
):
pass

@abstractmethod
def get_relevant_documents(self, query: str):
pass

src/mindpilot/app/knowledge_base/file_rag/retrievers/ensemble.py (+46, -0)

@@ -0,0 +1,46 @@
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import VectorStore
from langchain_community.retrievers import BM25Retriever
from langchain_core.retrievers import BaseRetriever

from .base import BaseRetrieverService


class EnsembleRetrieverService(BaseRetrieverService):
def do_init(
self,
retriever: BaseRetriever = None,
top_k: int = 5,
):
self.vs = None
self.top_k = top_k
self.retriever = retriever

@staticmethod
def from_vectorstore(
vectorstore: VectorStore,
top_k: int,
score_threshold: Union[int, float],
):
faiss_retriever = vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": score_threshold, "k": top_k},
)
# TODO: replace with a tokenizer implementation that does not depend on torch
# from cutword.cutword import Cutter
import jieba

# cutter = Cutter()
docs = list(vectorstore.docstore._dict.values())
bm25_retriever = BM25Retriever.from_documents(
docs,
preprocess_func=jieba.lcut_for_search,
)
bm25_retriever.k = top_k
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
)
return EnsembleRetrieverService(retriever=ensemble_retriever)

def get_relevant_documents(self, query: str):
return self.retriever.get_relevant_documents(query)[: self.top_k]
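
A hybrid-retrieval sketch (not in the commit); the embedding model and FAISS store construction are assumptions for illustration:

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embedder = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh")  # assumed model
vs = FAISS.from_texts(["知识库管理", "向量检索", "文件上传"], embedding=embedder)
service = EnsembleRetrieverService.from_vectorstore(
    vectorstore=vs, top_k=2, score_threshold=0.5
)
print(service.get_relevant_documents("如何管理知识库?"))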

src/mindpilot/app/knowledge_base/file_rag/retrievers/vectorstore.py (+30, -0)

@@ -0,0 +1,30 @@
from typing import Union

from langchain.vectorstores import VectorStore
from langchain_core.retrievers import BaseRetriever

from .base import BaseRetrieverService


class VectorstoreRetrieverService(BaseRetrieverService):
def do_init(
self,
retriever: BaseRetriever = None,
top_k: int = 5,
):
self.vs = None
self.top_k = top_k
self.retriever = retriever

@staticmethod
def from_vectorstore(
vectorstore: VectorStore,
top_k: int,
score_threshold: Union[int, float],
):
retriever = vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": score_threshold, "k": top_k},
)
return VectorstoreRetrieverService(retriever=retriever)

def get_relevant_documents(self, query: str):
return self.retriever.get_relevant_documents(query)[: self.top_k]

src/mindpilot/app/knowledge_base/file_rag/text_splitter/__init__.py (+4, -0)

@@ -0,0 +1,4 @@
from .ali_text_splitter import AliTextSplitter
from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter
from .chinese_text_splitter import ChineseTextSplitter
from .zh_title_enhance import zh_title_enhance

src/mindpilot/app/knowledge_base/file_rag/text_splitter/ali_text_splitter.py (+35, -0)

@@ -0,0 +1,35 @@
import re
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class AliTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf

def split_text(self, text: str) -> List[str]:
# use_document_segmentation selects semantic document splitting, using DAMO
# Academy's open-source model nlp_bert_document-segmentation_chinese-base
# (paper: https://arxiv.org/abs/2107.09278).
# Semantic splitting requires modelscope[nlp]: pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
# Since three models are involved, this can be heavy for low-end GPUs, so the
# model is loaded on CPU; replace device with your GPU id if needed.
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub(r"\s", " ", text)
text = re.sub("\n\n", "", text)
try:
from modelscope.pipelines import pipeline
except ImportError:
raise ImportError(
"Could not import modelscope python package. "
"Please install modelscope with `pip install modelscope`. "
)

p = pipeline(
task="document-segmentation",
model="damo/nlp_bert_document-segmentation_chinese-base",
device="cpu",
)
result = p(documents=text)
sent_list = [i for i in result["text"].split("\n\t") if i]
return sent_list
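
A usage sketch (not in the commit); the first call downloads the segmentation model from ModelScope:

splitter = AliTextSplitter(pdf=False)
for chunk in splitter.split_text("文档的第一部分。文档的第二部分。"):
    print(chunk)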

src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_recursive_text_splitter.py (+106, -0)

@@ -0,0 +1,106 @@
import logging
import re
from typing import Any, List, Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

logger = logging.getLogger(__name__)


def _split_text_with_regex_from_end(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
if len(_splits) % 2 == 1:
splits += _splits[-1:]
# splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]


class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
is_separator_regex: bool = True,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or [
"\n\n",
"\n",
"。|!|?",
r"\.\s|\!\s|\?\s",
r";|;\s",
r",|,\s",
]
self._is_separator_regex = is_separator_regex

def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1 :]
break

_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)

# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return [
re.sub(r"\n{2,}", "\n", chunk.strip())
for chunk in final_chunks
if chunk.strip() != ""
]


if __name__ == "__main__":
text_splitter = ChineseRecursiveTextSplitter(
keep_separator=True, is_separator_regex=True, chunk_size=50, chunk_overlap=0
)
ls = [
"""中国对外贸易形势报告(75页)。前 10 个月,一般贸易进出口 19.5 万亿元,增长 25.1%, 比整体进出口增速高出 2.9 个百分点,占进出口总额的 61.7%,较去年同期提升 1.6 个百分点。其中,一般贸易出口 10.6 万亿元,增长 25.3%,占出口总额的 60.9%,提升 1.5 个百分点;进口8.9万亿元,增长24.9%,占进口总额的62.7%, 提升 1.8 个百分点。加工贸易进出口 6.8 万亿元,增长 11.8%, 占进出口总额的 21.5%,减少 2.0 个百分点。其中,出口增 长 10.4%,占出口总额的 24.3%,减少 2.6 个百分点;进口增 长 14.2%,占进口总额的 18.0%,减少 1.2 个百分点。此外, 以保税物流方式进出口 3.96 万亿元,增长 27.9%。其中,出 口 1.47 万亿元,增长 38.9%;进口 2.49 万亿元,增长 22.2%。前三季度,中国服务贸易继续保持快速增长态势。服务 进出口总额 37834.3 亿元,增长 11.6%;其中服务出口 17820.9 亿元,增长 27.3%;进口 20013.4 亿元,增长 0.5%,进口增 速实现了疫情以来的首次转正。服务出口增幅大于进口 26.8 个百分点,带动服务贸易逆差下降 62.9%至 2192.5 亿元。服 务贸易结构持续优化,知识密集型服务进出口 16917.7 亿元, 增长 13.3%,占服务进出口总额的比重达到 44.7%,提升 0.7 个百分点。 二、中国对外贸易发展环境分析和展望 全球疫情起伏反复,经济复苏分化加剧,大宗商品价格 上涨、能源紧缺、运力紧张及发达经济体政策调整外溢等风 险交织叠加。同时也要看到,我国经济长期向好的趋势没有 改变,外贸企业韧性和活力不断增强,新业态新模式加快发 展,创新转型步伐提速。产业链供应链面临挑战。美欧等加快出台制造业回迁计 划,加速产业链供应链本土布局,跨国公司调整产业链供应 链,全球双链面临新一轮重构,区域化、近岸化、本土化、 短链化趋势凸显。疫苗供应不足,制造业“缺芯”、物流受限、 运价高企,全球产业链供应链面临压力。 全球通胀持续高位运行。能源价格上涨加大主要经济体 的通胀压力,增加全球经济复苏的不确定性。世界银行今年 10 月发布《大宗商品市场展望》指出,能源价格在 2021 年 大涨逾 80%,并且仍将在 2022 年小幅上涨。IMF 指出,全 球通胀上行风险加剧,通胀前景存在巨大不确定性。""",
]
# text = """"""
for inum, text in enumerate(ls):
print(inum)
chunks = text_splitter.split_text(text)
for chunk in chunks:
print(chunk)

+ 77
- 0
src/mindpilot/app/knowledge_base/file_rag/text_splitter/chinese_text_splitter.py View File

@@ -0,0 +1,77 @@
import re
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class ChineseTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size

def split_text1(self, text: str) -> List[str]:
if self.pdf:
text = re.sub(r"\n{3,}", "\n", text)
text = re.sub("\s", " ", text)
text = text.replace("\n\n", "")
sent_sep_pattern = re.compile(
'([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))'
) # del :;
sent_list = []
for ele in sent_sep_pattern.split(text):
if sent_sep_pattern.match(ele) and sent_list:
sent_list[-1] += ele
elif ele:
sent_list.append(ele)
return sent_list

def split_text(self, text: str) -> List[str]:  # TODO: this logic needs further refinement
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub(r"\s", " ", text)
text = re.sub("\n\n", "", text)

text = re.sub(r"([;;.!?。!?\?])([^”’])", r"\1\n\2", text) # 单字符断句符
text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text) # 英文省略号
text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text) # 中文省略号
text = re.sub(
r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r"\1\n\2", text
)
# 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
text = text.rstrip() # 段尾如果有多余的\n就去掉它
# 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
ls = [i for i in text.split("\n") if i]
for ele in ls:
if len(ele) > self.sentence_size:
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r"\1\n\2", ele)
ele1_ls = ele1.split("\n")
for ele_ele1 in ele1_ls:
if len(ele_ele1) > self.sentence_size:
ele_ele2 = re.sub(
r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])',
r"\1\n\2",
ele_ele1,
)
ele2_ls = ele_ele2.split("\n")
for ele_ele2 in ele2_ls:
if len(ele_ele2) > self.sentence_size:
ele_ele3 = re.sub(
'( ["’”」』]{0,2})([^ ])', r"\1\n\2", ele_ele2
)
ele2_id = ele2_ls.index(ele_ele2)
ele2_ls = (
ele2_ls[:ele2_id]
+ [i for i in ele_ele3.split("\n") if i]
+ ele2_ls[ele2_id + 1 :]
)
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = (
ele1_ls[:ele_id]
+ [i for i in ele2_ls if i]
+ ele1_ls[ele_id + 1 :]
)

idx = ls.index(ele)
ls = ls[:idx] + [i for i in ele1_ls if i] + ls[idx + 1 :]
return ls
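
# A minimal usage sketch (the sample sentence and sentence_size below are
# illustrative values, not taken from this commit):
if __name__ == "__main__":
    splitter = ChineseTextSplitter(pdf=False, sentence_size=30)
    sample = "今天天气很好。我们去公园散步吧!你觉得怎么样?"
    for sentence in splitter.split_text(sample):
        print(sentence)
    # Expected: one sentence per line, split at the 。!? terminators.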

+ 100
- 0
src/mindpilot/app/knowledge_base/file_rag/text_splitter/zh_title_enhance.py View File

@@ -0,0 +1,100 @@
import re
from typing import List

from langchain.docstore.document import Document


def under_non_alpha_ratio(text: str, threshold: float = 0.5):
"""Checks whether the proportion of non-alpha characters in the text snippet exceeds a given
threshold. This helps prevent text like "-----------BREAK---------" from being tagged
as a title or narrative text. The ratio does not count spaces.

Parameters
----------
text
The input string to test
threshold
If the proportion of alpha characters falls below this threshold, the function
returns True
"""
if len(text) == 0:
return False

alpha_count = len([char for char in text if char.strip() and char.isalpha()])
total_count = len([char for char in text if char.strip()])
try:
ratio = alpha_count / total_count
return ratio < threshold
except ZeroDivisionError:
return False


def is_possible_title(
text: str,
title_max_word_length: int = 20,
non_alpha_threshold: float = 0.5,
) -> bool:
"""Checks to see if the text passes all of the checks for a valid title.

Parameters
----------
text
The input text to check
title_max_word_length
The maximum number of words a title can contain
non_alpha_threshold
The minimum ratio of alpha characters the text needs to be considered a title
"""

# Empty text is never a title
if len(text) == 0:
print("Not a title. Text is empty.")
return False

# Text that ends with punctuation is not a title
ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
if ENDS_IN_PUNCT_RE.search(text) is not None:
return False

# The text must not exceed the configured length, 20 characters by default
# NOTE(robinson) - a plain character-length check is used because it is cheap and
# actual tokenization doesn't add much value for the length check
if len(text) > title_max_word_length:
return False

# Text with too high a proportion of non-alpha characters is not a title
if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
return False

# NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
if text.endswith((",", ".", ",", "。")):
return False

if text.isnumeric():
print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore
return False

# The opening characters (5 by default) should contain a digit
if len(text) < 5:
text_5 = text
else:
text_5 = text[:5]
digit_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
if not digit_in_text_5:
return False

return True


def zh_title_enhance(docs: List[Document]) -> List[Document]:
title = None
if len(docs) > 0:
for doc in docs:
if is_possible_title(doc.page_content):
doc.metadata["category"] = "cn_Title"
title = doc.page_content
elif title:
doc.page_content = f"下文与({title})有关。{doc.page_content}"
return docs
else:
print("文档列表为空")

+ 14
- 0
src/mindpilot/app/knowledge_base/file_rag/utils.py View File

@@ -0,0 +1,14 @@
from .retrievers import (
BaseRetrieverService,
EnsembleRetrieverService,
VectorstoreRetrieverService,
)

Retrievals = {
"vectorstore": VectorstoreRetrieverService,
"ensemble": EnsembleRetrieverService,
}


def get_Retriever(type: str = "vectorstore") -> BaseRetrieverService:
return Retrievals[type]
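
# Usage sketch: from_vectorstore mirrors how the services are invoked elsewhere
# in this commit (see faiss_kb_service.do_search); the vectorstore object is
# assumed to have been created already.
#
# retriever = get_Retriever("ensemble").from_vectorstore(
#     vectorstore, top_k=3, score_threshold=0.5
# )
# docs = retriever.get_relevant_documents("查询文本")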

+ 72
- 0
src/mindpilot/app/knowledge_base/kb_api.py View File

@@ -0,0 +1,72 @@
import urllib.parse

from fastapi import Body

from .db.repository.knowledge_base_repository import list_kbs_from_db
from .kb_service.base import KBServiceFactory
from .utils import validate_kb_name
from ..utils.system_utils import BaseResponse, ListResponse, logger


def list_kbs():
# Get List of Knowledge Base
return ListResponse(data=list_kbs_from_db())


def create_kb(
knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
embed_model: str = Body(..., examples=["bce-embedding-base_v1", "bge-large-zh-v1.5"]),
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
if knowledge_base_name is None or knowledge_base_name.strip() == "":
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")

kb = KBServiceFactory.get_service(
knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info
)
try:
kb.create_kb()
except Exception as e:
msg = f"创建知识库出错: {e}"
logger.error(
f"{e.__class__.__name__}: {msg}", exc_info=e
)
return BaseResponse(code=500, msg=msg)

return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")


def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"]),
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
knowledge_base_name = urllib.parse.unquote(knowledge_base_name)

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

try:
kb.clear_vs()
status = kb.drop_kb()
if status:
return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
except Exception as e:
msg = f"删除知识库时出现意外: {e}"
logger.error(
f"{e.__class__.__name__}: {msg}", exc_info=e
)
return BaseResponse(code=500, msg=msg)

return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

+ 94
- 0
src/mindpilot/app/knowledge_base/kb_cache/base.py View File

@@ -0,0 +1,94 @@
import threading
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, List, Tuple, Union

from langchain_community.vectorstores import FAISS


class ThreadSafeObject:
def __init__(
self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None
):
self._obj = obj
self._key = key
self._pool = pool
self._lock = threading.RLock()
self._loaded = threading.Event()

def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self.key}, obj: {self._obj}>"

@property
def key(self):
return self._key

@contextmanager
def acquire(self, owner: str = "", msg: str = "") -> FAISS:
owner = owner or f"thread {threading.get_native_id()}"
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self.key)
yield self._obj
finally:
self._lock.release()

def start_loading(self):
self._loaded.clear()

def finish_loading(self):
self._loaded.set()

def wait_for_loading(self):
self._loaded.wait()

@property
def obj(self):
return self._obj

@obj.setter
def obj(self, val: Any):
self._obj = val


class CachePool:
def __init__(self, cache_num: int = -1):
self._cache_num = cache_num
self._cache = OrderedDict()
self.atomic = threading.RLock()

def keys(self) -> List[str]:
return list(self._cache.keys())

def _check_count(self):
if isinstance(self._cache_num, int) and self._cache_num > 0:
while len(self._cache) > self._cache_num:
self._cache.popitem(last=False)

def get(self, key: str) -> ThreadSafeObject:
if cache := self._cache.get(key):
cache.wait_for_loading()
return cache

def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
self._cache[key] = obj
self._check_count()
return obj

def pop(self, key: str = None) -> ThreadSafeObject:
if key is None:
return self._cache.popitem(last=False)
else:
return self._cache.pop(key, None)

def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
cache = self.get(key)
if cache is None:
raise RuntimeError(f"请求的资源 {key} 不存在")
elif isinstance(cache, ThreadSafeObject):
self._cache.move_to_end(key)
return cache.acquire(owner=owner, msg=msg)
else:
return cache
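
# A minimal sketch of the intended acquire pattern (values are illustrative):
if __name__ == "__main__":
    pool = CachePool(cache_num=2)
    item = ThreadSafeObject("kb1", obj={"docs": []}, pool=pool)
    pool.set("kb1", item)
    item.finish_loading()  # mark as loaded so get() does not block on the Event

    with pool.acquire("kb1", owner="worker-1") as obj:
        obj["docs"].append("hello")  # exclusive access under the item's RLock
    print(pool.keys())  # ['kb1']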

+ 211
- 0
src/mindpilot/app/knowledge_base/kb_cache/faiss_cache.py View File

@@ -0,0 +1,211 @@
import logging
import os

from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document

from ...configs import CACHED_MEMO_VS_NUM, CACHED_VS_NUM, DEFAULT_EMBEDDING_MODEL
from .base import *
from ..utils import get_vs_path
from ...utils.system_utils import get_Embeddings

logger = logging.getLogger()

# patch FAISS to include doc id in Document.metadata
def _new_ds_search(self, search: str) -> Union[str, Document]:
if search not in self._dict:
return f"ID {search} not found."
else:
doc = self._dict[search]
if isinstance(doc, Document):
doc.metadata["id"] = search
return doc


InMemoryDocstore.search = _new_ds_search


class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

def docs_count(self) -> int:
return len(self._obj.docstore._dict)

def save(self, path: str, create_path: bool = True):
with self.acquire():
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self.key} 保存到磁盘")
return ret

def clear(self):
ret = []
with self.acquire():
ids = list(self._obj.docstore._dict.keys())
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self.key} 清空")
return ret


class _FaissPool(CachePool):
def new_vector_store(
self,
kb_name: str,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> FAISS:
# create an empty vector store
embeddings = get_Embeddings(embed_model=embed_model)
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True)
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store

def new_temp_vector_store(
self,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> FAISS:
# same empty store as new_vector_store, just not tied to a knowledge base
return self.new_vector_store(kb_name="", embed_model=embed_model)

def save_vector_store(self, kb_name: str, path: str = None):
if cache := self.get(kb_name):
return cache.save(path)

def unload_vector_store(self, kb_name: str):
if cache := self.get(kb_name):
self.pop(kb_name)
logger.info(f"成功释放向量库:{kb_name}")


class KBFaissPool(_FaissPool):
def load_vector_store(
self,
kb_name: str,
vector_name: str = None,
create: bool = True,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> ThreadSafeFaiss:
self.atomic.acquire()
locked = True
vector_name = vector_name or embed_model
cache = self.get((kb_name, vector_name))  # a tuple key is cleaner than string concatenation
try:
if cache is None:
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
locked = False
logger.info(
f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk."
)
vs_path = get_vs_path(kb_name, vector_name)

if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = get_Embeddings(embed_model=embed_model)
vector_store = FAISS.load_local(
vs_path,
embeddings,
normalize_L2=True,
allow_dangerous_deserialization=True,
)
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
os.makedirs(vs_path)
vector_store = self.new_vector_store(
kb_name=kb_name, embed_model=embed_model
)
vector_store.save_local(vs_path)
else:
raise RuntimeError(f"knowledge base {kb_name} not exist.")
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
locked = False
except Exception as e:
if locked: # we don't know exception raised before or after atomic.release
self.atomic.release()
logger.error(e, exc_info=True)
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
return self.get((kb_name, vector_name))


class MemoFaissPool(_FaissPool):
r"""
临时向量库的缓存池
"""

def load_vector_store(
self,
kb_name: str,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name)
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name, item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}' to memory.")
# create an empty vector store
vector_store = self.new_temp_vector_store(embed_model=embed_model)
item.obj = vector_store
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name)


kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
#
#
# if __name__ == "__main__":
# import time, random
# from pprint import pprint
#
# kb_names = ["vs1", "vs2", "vs3"]
# # for name in kb_names:
# # memo_faiss_pool.load_vector_store(name)
#
# def worker(vs_name: str, name: str):
# vs_name = "samples"
# time.sleep(random.randint(1, 5))
# embeddings = load_local_embeddings()
# r = random.randint(1, 3)
#
# with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
# if r == 1: # add docs
# ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
# pprint(ids)
# elif r == 2: # search docs
# docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
# pprint(docs)
# if r == 3: # delete docs
# logger.warning(f"清除 {vs_name} by {name}")
# kb_faiss_pool.get(vs_name).clear()
#
# threads = []
# for n in range(1, 30):
# t = threading.Thread(target=worker,
# kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
# daemon=True)
# t.start()
# threads.append(t)
#
# for t in threads:
# t.join()
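
# Typical read path through the cache (kb name and query are illustrative and
# assume a FAISS store already exists on disk):
#
# with kb_faiss_pool.load_vector_store("samples").acquire() as vs:
#     docs = vs.similarity_search("如何启动api服务", k=3)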

+ 451
- 0
src/mindpilot/app/knowledge_base/kb_doc_api.py View File

@@ -0,0 +1,451 @@
import json
import os
import urllib.parse
from typing import Dict, List

from fastapi import Body, File, Form, Query, UploadFile
from fastapi.responses import FileResponse
from langchain.docstore.document import Document
from sse_starlette import EventSourceResponse

from ..configs import (
CHUNK_SIZE,
DEFAULT_VS_TYPE,
OVERLAP_SIZE,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
ZH_TITLE_ENHANCE,
)
from .db.repository.knowledge_file_repository import get_file_detail
from .kb_service.base import (
KBServiceFactory,
get_kb_file_details,
)
from .model.kb_document_model import DocumentWithVSId
from .utils import (
KnowledgeFile,
files2docs_in_thread,
get_file_path,
list_files_from_folder,
validate_kb_name,
)
from ..utils.system_utils import (
BaseResponse,
ListResponse,
run_in_thread_pool,
)


def search_docs(
query: str = Body("", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(
..., description="知识库名称", examples=["samples"]
),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(
SCORE_THRESHOLD,
description="知识库匹配相关度阈值,取值范围在0-1之间,"
"SCORE越小,相关度越高,"
"取到1相当于不筛选,建议设置在0.5左右",
ge=0.0,
le=1.0,
),
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
data = []
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
# data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
data = [DocumentWithVSId(**x.dict(), id=x.metadata.get("id")) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
for d in data:
if "vector" in d.metadata:
del d.metadata["vector"]
return [x.dict() for x in data]
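
# In-process call sketch (bypassing FastAPI, so the Body defaults act as plain
# defaults; values are illustrative):
#
# hits = search_docs(query="你好", knowledge_base_name="samples", top_k=3,
#                    score_threshold=0.5, file_name="", metadata={})
# for hit in hits:
#     print(hit["id"], hit["page_content"][:30])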


def list_files(knowledge_base_name: str) -> ListResponse:
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])

knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return ListResponse(
code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]
)
else:
all_docs = get_kb_file_details(knowledge_base_name)
return ListResponse(data=all_docs)


def _save_files_in_thread(
files: List[UploadFile], knowledge_base_name: str, override: bool
):
"""
通过多线程将上传的文件保存到对应知识库目录内。
生成器返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
"""

def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
"""
Save a single file.
"""
filename = file.filename
data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
try:
file_path = get_file_path(
knowledge_base_name=knowledge_base_name, doc_name=filename
)

file_content = file.file.read()  # read the uploaded file's contents
if (
os.path.isfile(file_path)
and not override
and os.path.getsize(file_path) == len(file_content)
):
file_status = f"文件 {filename} 已存在。"
return dict(code=404, msg=file_status, data=data)

if not os.path.isdir(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, "wb") as f:
f.write(file_content)
return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
except Exception as e:
msg = f"{filename} 文件上传失败,报错信息为: {e}"
return dict(code=500, msg=msg, data=data)

params = [
{"file": file, "knowledge_base_name": knowledge_base_name, "override": override}
for file in files
]
for result in run_in_thread_pool(save_file, params=params):
yield result


# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件"),
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
# def save_files(files, knowledge_base_name, override):
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
# yield json.dumps(result, ensure_ascii=False)

# def files_to_docs(files):
# for result in files2docs_in_thread(files):
# yield json.dumps(result, ensure_ascii=False)


def upload_docs(
files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
knowledge_base_name: str = Form(
..., description="知识库名称", examples=["samples"]
),
override: bool = Form(False, description="覆盖已有文件"),
to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
docs: str = Form("", description="自定义的docs,需要转为json字符串"),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
"""
API接口:上传文件,并/或向量化
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

docs = json.loads(docs) if docs else {}
failed_files = {}
file_names = list(docs.keys())

# First save the uploaded files to disk
for result in _save_files_in_thread(
files, knowledge_base_name=knowledge_base_name, override=override
):
filename = result["data"]["file_name"]
if result["code"] != 200:
failed_files[filename] = result["msg"]

if filename not in file_names:
file_names.append(filename)

# Vectorize the saved files
if to_vector_store:
result = update_docs(
knowledge_base_name=knowledge_base_name,
file_names=file_names,
override_custom_docs=True,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance,
docs=docs,
not_refresh_vs_cache=True,
)
failed_files.update(result.data["failed_files"])
if not not_refresh_vs_cache:
kb.save_vector_store()

return BaseResponse(
code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}
)


def delete_docs(
knowledge_base_name: str = Body(..., examples=["samples"]),
file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
delete_content: bool = Body(False),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")

knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

failed_files = {}
for file_name in file_names:
if not kb.exist_doc(file_name):
failed_files[file_name] = f"未找到文件 {file_name}"

try:
kb_file = KnowledgeFile(
filename=file_name, knowledge_base_name=knowledge_base_name
)
kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
except Exception as e:
msg = f"{file_name} 文件删除失败,错误信息:{e}"
failed_files[file_name] = msg

if not not_refresh_vs_cache:
kb.save_vector_store()

return BaseResponse(
code=200, msg="文件删除完成", data={"failed_files": failed_files}
)


def update_info(
knowledge_base_name: str = Body(
..., description="知识库名称", examples=["samples"]
),
kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb.update_info(kb_info)

return BaseResponse(code=200, msg=f"知识库介绍修改完成", data={"kb_info": kb_info})


def update_docs(
knowledge_base_name: str = Body(
..., description="知识库名称", examples=["samples"]
),
file_names: List[str] = Body(
..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]
),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
docs: str = Body("", description="自定义的docs,需要转为json字符串"),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
"""
更新知识库文档
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

failed_files = {}
kb_files = []
docs = json.loads(docs) if docs else {}

# Build the list of files whose docs need loading
for file_name in file_names:
file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
# If the file previously used custom docs, skip or override it depending on the flag
if file_detail.get("custom_docs") and not override_custom_docs:
continue
if file_name not in docs:
try:
kb_files.append(
KnowledgeFile(
filename=file_name, knowledge_base_name=knowledge_base_name
)
)
except Exception as e:
msg = f"加载文档 {file_name} 时出错:{e}"
failed_files[file_name] = msg

# Generate docs from the files and vectorize them.
# This relies on KnowledgeFile's caching: Documents are loaded in worker threads, then handed back to KnowledgeFile.
for status, result in files2docs_in_thread(
kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance,
):
if status:
kb_name, file_name, new_docs = result
kb_file = KnowledgeFile(
filename=file_name, knowledge_base_name=knowledge_base_name
)
kb_file.splited_docs = new_docs
kb.update_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
failed_files[file_name] = error

# Vectorize the custom docs
for file_name, v in docs.items():
try:
v = [x if isinstance(x, Document) else Document(**x) for x in v]
kb_file = KnowledgeFile(
filename=file_name, knowledge_base_name=knowledge_base_name
)
kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
except Exception as e:
msg = f"为 {file_name} 添加自定义docs时出错:{e}"
failed_files[file_name] = msg

if not not_refresh_vs_cache:
kb.save_vector_store()

return BaseResponse(
code=200, msg="更新文档完成", data={"failed_files": failed_files}
)


def download_doc(
knowledge_base_name: str = Query(
..., description="知识库名称", examples=["samples"]
),
file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
):
"""
下载知识库文档
"""
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")

kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

if preview:
content_disposition_type = "inline"
else:
content_disposition_type = None

try:
kb_file = KnowledgeFile(
filename=file_name, knowledge_base_name=knowledge_base_name
)

if os.path.exists(kb_file.filepath):
return FileResponse(
path=kb_file.filepath,
filename=kb_file.filename,
media_type="multipart/form-data",
content_disposition_type=content_disposition_type,
)
except Exception as e:
msg = f"{file_name} 读取文件失败,错误信息是:{e}"
return BaseResponse(code=500, msg=msg)

return BaseResponse(code=500, msg=f"{file_name} 读取文件失败")


def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body(DEFAULT_VS_TYPE),
embed_model: str = Body(...),
chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
):
"""
recreate vector store from the content.
this is usefull when user can copy files to content folder directly instead of upload through network.
by default, get_service_by_name only return knowledge base in the info.db and having document files in it.
set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents.
"""

def output():
kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
else:
error_msg = "could not recreate vector store because failed to access embed model."
if not kb.check_embed_model(error_msg):
yield {"code": 404, "msg": error_msg}
else:
if kb.exists():
kb.clear_vs()
kb.create_kb()
files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files]
i = 0
for status, result in files2docs_in_thread(
kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance,
):
if status:
kb_name, file_name, docs = result
kb_file = KnowledgeFile(
filename=file_name, knowledge_base_name=kb_name
)
kb_file.splited_docs = docs
yield json.dumps(
{
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
},
ensure_ascii=False,
)
kb.add_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
yield json.dumps(
{
"code": 500,
"msg": msg,
}
)
i += 1
if not not_refresh_vs_cache:
kb.save_vector_store()

return EventSourceResponse(output())
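
# Sketch: consuming the recreate_vector_store event stream from a client.
# The endpoint path and port are assumptions, not confirmed by this diff.
#
# import requests
# with requests.post(
#     "http://127.0.0.1:7861/knowledge_base/recreate_vector_store",  # hypothetical route
#     json={"knowledge_base_name": "samples", "embed_model": "bge-large-zh-v1.5"},
#     stream=True,
# ) as resp:
#     for line in resp.iter_lines():
#         if line.startswith(b"data:"):
#             print(json.loads(line[5:]))  # per-file progress: code/msg/total/finished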

+ 0
- 0
src/mindpilot/app/knowledge_base/kb_service/__init__.py View File


+ 500
- 0
src/mindpilot/app/knowledge_base/kb_service/base.py View File

@@ -0,0 +1,500 @@
import operator
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from langchain.docstore.document import Document

from ...configs import (
KB_INFO,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
kbs_config,
DEFAULT_EMBEDDING_MODEL
)
from ..db.models.knowledge_base_model import KnowledgeBaseSchema
from ..db.repository.knowledge_base_repository import (
add_kb_to_db,
delete_kb_from_db,
kb_exists,
list_kbs_from_db,
load_kb_from_db,
)
from ..db.repository.knowledge_file_repository import (
add_file_to_db,
count_files_from_db,
delete_file_from_db,
delete_files_from_db,
file_exists_in_db,
get_file_detail,
list_docs_from_db,
list_files_from_db,
)
from ..model.kb_document_model import DocumentWithVSId
from ..utils import (
KnowledgeFile,
get_doc_path,
get_kb_path,
list_files_from_folder,
list_kbs_from_folder,
)
from ...utils.system_utils import check_embed_model as _check_embed_model

class SupportedVSType:
FAISS = "faiss"
MILVUS = "milvus"
DEFAULT = "default"
ES = "es"


class KBService(ABC):
def __init__(
self,
knowledge_base_name: str,
kb_info: str = None,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
):
self.kb_name = knowledge_base_name
self.kb_info = kb_info or KB_INFO.get(
knowledge_base_name, f"关于{knowledge_base_name}的知识库"
)
self.embed_model = embed_model
self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name)
self.do_init()

def __repr__(self) -> str:
return f"{self.kb_name} @ {self.embed_model}"

def save_vector_store(self):
"""
Persist the vector store: FAISS saves to disk, Milvus saves to its database. PGVector is not supported yet.
"""
pass

def check_embed_model(self, error_msg: str) -> bool:
return _check_embed_model(self.embed_model)

def create_kb(self):
"""
Create the knowledge base.
"""
if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path)

status = add_kb_to_db(
self.kb_name, self.kb_info, self.vs_type(), self.embed_model
)

if status:
self.do_create_kb()
return status

def clear_vs(self):
"""
Delete all content from the vector store.
"""
self.do_clear_vs()
status = delete_files_from_db(self.kb_name)
return status

def drop_kb(self):
"""
Drop the knowledge base.
"""
self.do_drop_kb()
status = delete_kb_from_db(self.kb_name)
return status

def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
Add a file to the knowledge base.
If docs is given, the file's text is not vectorized again and the DB entry is flagged custom_docs=True.
"""
if not self.check_embed_model(
"could not add docs because failed to access embed model."
):
return False

if docs:
custom_docs = True
else:
docs = kb_file.file2text()
custom_docs = False

if docs:
# rewrite metadata["source"] as a relative path
for doc in docs:
try:
doc.metadata.setdefault("source", kb_file.filename)
source = doc.metadata.get("source", "")
if os.path.isabs(source):
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
except Exception as e:
print(
f"cannot convert absolute path ({source}) to relative path. error is : {e}"
)
self.delete_doc(kb_file)
doc_infos = self.do_add_doc(docs, **kwargs)
status = add_file_to_db(
kb_file,
custom_docs=custom_docs,
docs_count=len(docs),
doc_infos=doc_infos,
)
else:
status = False
return status

def delete_doc(
self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs
):
"""
Delete a file from the knowledge base.
"""
self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status

def update_info(self, kb_info: str):
"""
Update the knowledge-base description.
"""
self.kb_info = kb_info
status = add_kb_to_db(
self.kb_name, self.kb_info, self.vs_type(), self.embed_model
)
return status

def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
Update the vector store from the file under content.
If docs is given, the custom docs are used and the DB entry is flagged custom_docs=True.
"""
if not self.check_embed_model(
"could not update docs because failed to access embed model."
):
return False

if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, docs=docs, **kwargs)

def exist_doc(self, file_name: str):
return file_exists_in_db(
KnowledgeFile(knowledge_base_name=self.kb_name, filename=file_name)
)

def list_files(self):
return list_files_from_db(self.kb_name)

def count_files(self):
return count_files_from_db(self.kb_name)

def search_docs(
self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Document]:
if not self.check_embed_model(
"could not search docs because failed to access embed model."
):
return []
docs = self.do_search(query, top_k, score_threshold)
return docs

def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return []

def del_doc_by_ids(self, ids: List[str]) -> bool:
raise NotImplementedError

def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
"""
Argument format: {doc_id: Document, ...}
If a doc_id maps to None, or its page_content is empty, that document is deleted.
"""
if not self.check_embed_model(
"could not update docs because failed to access embed model."
):
return False

self.del_doc_by_ids(list(docs.keys()))
pending_docs = []
ids = []
for _id, doc in docs.items():
if not doc or not doc.page_content.strip():
continue
ids.append(_id)
pending_docs.append(doc)
self.do_add_doc(docs=pending_docs, ids=ids)
return True

def list_docs(
self, file_name: str = None, metadata: Dict = {}
) -> List[DocumentWithVSId]:
"""
Look up Documents by file_name or metadata.
"""
doc_infos = list_docs_from_db(
kb_name=self.kb_name, file_name=file_name, metadata=metadata
)
docs = []
for x in doc_infos:
id_docs = self.get_doc_by_ids([x["id"]])
doc_info = id_docs[0] if id_docs else None
if doc_info is not None:
doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
docs.append(doc_with_id)
# entries missing from the vector store are simply skipped
return docs

def get_relative_source_path(self, filepath: str):
"""
Convert a file path to a relative path so that queries stay consistent.
"""
relative_path = Path(filepath)
if relative_path.is_absolute():
try:
relative_path = relative_path.relative_to(self.doc_path)
except Exception as e:
print(
f"cannot convert absolute path ({filepath}) to relative path. error is : {e}"
)

return str(relative_path.as_posix()).strip("/")

@abstractmethod
def do_create_kb(self):
"""
Create the knowledge base; subclasses implement their own logic.
"""
pass

@staticmethod
def list_kbs_type():
return list(kbs_config.keys())

@classmethod
def list_kbs(cls):
return list_kbs_from_db()

def exists(self, kb_name: str = None):
kb_name = kb_name or self.kb_name
return kb_exists(kb_name)

@abstractmethod
def vs_type(self) -> str:
pass

@abstractmethod
def do_init(self):
pass

@abstractmethod
def do_drop_kb(self):
"""
Drop the knowledge base; subclasses implement their own logic.
"""
pass

@abstractmethod
def do_search(
self,
query: str,
top_k: int,
score_threshold: float,
) -> List[Tuple[Document, float]]:
"""
搜索知识库子类实自己逻辑
"""
pass

@abstractmethod
def do_add_doc(
self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
"""
向知识库添加文档子类实自己逻辑
"""
pass

@abstractmethod
def do_delete_doc(self, kb_file: KnowledgeFile):
"""
从知识库删除文档子类实自己逻辑
"""
pass

@abstractmethod
def do_clear_vs(self):
"""
从知识库删除全部向量子类实自己逻辑
"""
pass


class KBServiceFactory:
@staticmethod
def get_service(
kb_name: str,
vector_store_type: Union[str, SupportedVSType],
embed_model: str = DEFAULT_EMBEDDING_MODEL,
kb_info: str = None,
) -> KBService:
if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
params = {
"knowledge_base_name": kb_name,
"embed_model": embed_model,
"kb_info": kb_info,
}
if SupportedVSType.FAISS == vector_store_type:
from ..kb_service.faiss_kb_service import (
FaissKBService,
)

return FaissKBService(**params)

elif SupportedVSType.MILVUS == vector_store_type:
from ..kb_service.milvus_kb_service import (
MilvusKBService,
)

return MilvusKBService(**params)

elif SupportedVSType.DEFAULT == vector_store_type:
from ..kb_service.milvus_kb_service import (
MilvusKBService,
)

return MilvusKBService(
**params
) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.ES == vector_store_type:
from ..kb_service.es_kb_service import (
ESKBService,
)

return ESKBService(**params)


@staticmethod
def get_service_by_name(kb_name: str) -> KBService:
_, vs_type, embed_model = load_kb_from_db(kb_name)
if _ is None: # kb not in db, just return None
return None
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

@staticmethod
def get_default():
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)


def get_kb_details() -> List[Dict]:
kbs_in_folder = list_kbs_from_folder()
kbs_in_db: List[KnowledgeBaseSchema] = KBService.list_kbs()
result = {}

for kb in kbs_in_folder:
result[kb] = {
"kb_name": kb,
"vs_type": "",
"kb_info": "",
"embed_model": "",
"file_count": 0,
"create_time": None,
"in_folder": True,
"in_db": False,
}

for kb_detail in kbs_in_db:
kb_detail = kb_detail.model_dump()
kb_name = kb_detail["kb_name"]
kb_detail["in_db"] = True
if kb_name in result:
result[kb_name].update(kb_detail)
else:
kb_detail["in_folder"] = False
result[kb_name] = kb_detail

data = []
for i, v in enumerate(result.values()):
v["No"] = i + 1
data.append(v)

return data


def get_kb_file_details(kb_name: str) -> List[Dict]:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is None:
return []

files_in_folder = list_files_from_folder(kb_name)
files_in_db = kb.list_files()
result = {}

for doc in files_in_folder:
result[doc] = {
"kb_name": kb_name,
"file_name": doc,
"file_ext": os.path.splitext(doc)[-1],
"file_version": 0,
"document_loader": "",
"docs_count": 0,
"text_splitter": "",
"create_time": None,
"in_folder": True,
"in_db": False,
}
lower_names = {x.lower(): x for x in result}
for doc in files_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True
if doc.lower() in lower_names:
result[lower_names[doc.lower()]].update(doc_detail)
else:
doc_detail["in_folder"] = False
result[doc] = doc_detail

data = []
for i, v in enumerate(result.values()):
v["No"] = i + 1
data.append(v)

return data


def score_threshold_process(score_threshold, k, docs):
if score_threshold is not None:
cmp = operator.le
docs = [
(doc, similarity)
for doc, similarity in docs
if cmp(similarity, score_threshold)
]
return docs[:k]
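
# Worked example: with L2-style scores where smaller means more similar,
# operator.le keeps hits at or below the threshold, then truncates to k (toy values):
#
# hits = [("doc-a", 0.21), ("doc-b", 0.48), ("doc-c", 0.93)]
# score_threshold_process(0.5, 2, hits)  # -> [("doc-a", 0.21), ("doc-b", 0.48)]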

+ 38
- 0
src/mindpilot/app/knowledge_base/kb_service/default_kb_service.py View File

@@ -0,0 +1,38 @@
from typing import List

from langchain.embeddings.base import Embeddings
from langchain.schema import Document

from ..kb_service.base import KBService


class DefaultKBService(KBService):
def do_create_kb(self):
pass

def do_drop_kb(self):
pass

def do_add_doc(self, docs: List[Document]):
pass

def do_clear_vs(self):
pass

def vs_type(self) -> str:
return "default"

def do_init(self):
pass

def do_search(self):
pass

def do_insert_multi_knowledge(self):
pass

def do_insert_one_knowledge(self):
pass

def do_delete_doc(self):
pass

+ 226
- 0
src/mindpilot/app/knowledge_base/kb_service/es_kb_service.py View File

@@ -0,0 +1,226 @@
import logging
import os
import shutil
from typing import List

from elasticsearch import BadRequestError, Elasticsearch
from langchain.schema import Document
from langchain_community.vectorstores.elasticsearch import (
ApproxRetrievalStrategy,
ElasticsearchStore,
)

from src.mindpilot.app.utils.system_utils import get_Embeddings
from ...configs import KB_ROOT_PATH, kbs_config
from ..file_rag.utils import get_Retriever
from ..kb_service.base import KBService, SupportedVSType
from ..utils import KnowledgeFile


logger = logging.getLogger()


class ESKBService(KBService):
def do_init(self):
self.kb_path = self.get_kb_path(self.kb_name)
self.index_name = os.path.split(self.kb_path)[-1]
self.IP = kbs_config[self.vs_type()]["host"]
self.PORT = kbs_config[self.vs_type()]["port"]
self.user = kbs_config[self.vs_type()].get("user", "")
self.password = kbs_config[self.vs_type()].get("password", "")
self.dims_length = kbs_config[self.vs_type()].get("dims_length", None)
self.embeddings_model = get_Embeddings(self.embed_model)
try:
# connect the ES Python client (connection only)
if self.user != "" and self.password != "":
self.es_client_python = Elasticsearch(
f"http://{self.IP}:{self.PORT}",
basic_auth=(self.user, self.password),
)
else:
logger.warning("ES未配置用户名和密码")
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}")
except ConnectionError:
logger.error("连接到 Elasticsearch 失败!")
raise
except Exception as e:
logger.error(f"Error 发生 : {e}")
raise e
try:
# first try to create the index via es_client_python
mappings = {
"properties": {
"dense_vector": {
"type": "dense_vector",
"dims": self.dims_length,
"index": True,
}
}
}
self.es_client_python.indices.create(
index=self.index_name, mappings=mappings
)
except BadRequestError as e:
logger.error("创建索引失败(可能已存在)")
logger.error(e)

try:
# langchain ES connection and index creation
params = dict(
es_url=f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
query_field="context",
vector_query_field="dense_vector",
embedding=self.embeddings_model,
strategy=ApproxRetrievalStrategy(),
es_params={
"timeout": 60,
},
)
if self.user != "" and self.password != "":
params.update(es_user=self.user, es_password=self.password)
self.db = ElasticsearchStore(**params)
except ConnectionError:
logger.error("### 初始化 Elasticsearch 失败!")
raise
except Exception as e:
logger.error(f"Error 发生 : {e}")
raise e
try:
# try to create the index via the langchain ElasticsearchStore
self.db._create_index_if_not_exists(
index_name=self.index_name, dims_length=self.dims_length
)
except Exception as e:
logger.error("创建索引失败...")
logger.error(e)
# raise e

@staticmethod
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)

@staticmethod
def get_vs_path(knowledge_base_name: str):
return os.path.join(
ESKBService.get_kb_path(knowledge_base_name), "vector_store"
)

def do_create_kb(self): ...

def vs_type(self) -> str:
return SupportedVSType.ES

def do_search(self, query: str, top_k: int, score_threshold: float):
# text similarity retrieval
retriever = get_Retriever("vectorstore").from_vectorstore(
self.db,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs

def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
results = []
for doc_id in ids:
try:
response = self.es_client_python.get(index=self.index_name, id=doc_id)
source = response["_source"]
# Assuming your document has "text" and "metadata" fields
text = source.get("context", "")
metadata = source.get("metadata", {})
results.append(Document(page_content=text, metadata=metadata))
except Exception as e:
logger.error(f"Error retrieving document from Elasticsearch! {e}")
return results

def del_doc_by_ids(self, ids: List[str]) -> bool:
for doc_id in ids:
try:
self.es_client_python.delete(
index=self.index_name, id=doc_id, refresh=True
)
except Exception as e:
logger.error(f"ES Docs Delete Error! {e}")

def do_delete_doc(self, kb_file, **kwargs):
if self.es_client_python.indices.exists(index=self.index_name):
# delete the file's docs from the index (the source file name is a keyword field)
query = {
"query": {
"term": {
"metadata.source.keyword": self.get_relative_source_path(
kb_file.filepath
)
}
},
"track_total_hits": True,
}
# size must be set explicitly: ES returns 10 hits by default; track_total_hits=True makes ES report the true total.
size = self.es_client_python.search(body=query)["hits"]["total"]["value"]
search_results = self.es_client_python.search(body=query, size=size)
delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
if len(delete_list) == 0:
return None
else:
for doc_id in delete_list:
try:
self.es_client_python.delete(
index=self.index_name, id=doc_id, refresh=True
)
except Exception as e:
logger.error(f"ES Docs Delete Error! {e}")

# self.db.delete(ids=delete_list)
# self.es_client_python.indices.refresh(index=self.index_name)

def do_add_doc(self, docs: List[Document], **kwargs):
"""向知识库添加文件"""

print(
f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}"
)
print("*" * 100)

self.db.add_documents(documents=docs)
# 获取 id 和 source , 格式:[{"id": str, "metadata": dict}, ...]
print("写入数据成功.")
print("*" * 100)

if self.es_client_python.indices.exists(index=self.index_name):
file_path = docs[0].metadata.get("source")
# the original dict repeated the "term" key, which both breaks the ES DSL and
# silently drops the source filter; scope the search via index= instead
query = {
"query": {
"term": {"metadata.source.keyword": file_path},
}
}
# size must be set explicitly; ES returns only 10 hits by default.
search_results = self.es_client_python.search(
index=self.index_name, body=query, size=50
)
if len(search_results["hits"]["hits"]) == 0:
raise ValueError("召回元素个数为0")
info_docs = [
{"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
for hit in search_results["hits"]["hits"]
]
return info_docs

def do_clear_vs(self):
"""Delete all vectors from the knowledge base."""
if self.es_client_python.indices.exists(index=self.index_name):
self.es_client_python.indices.delete(index=self.index_name)

def do_drop_kb(self):
"""Drop the knowledge base."""
# self.kb_path: path of the knowledge base on disk
if os.path.exists(self.kb_path):
shutil.rmtree(self.kb_path)


if __name__ == "__main__":
esKBService = ESKBService("test")
# esKBService.clear_vs()
# esKBService.create_kb()
esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
print(esKBService.search_docs("如何启动api服务"))

+ 136
- 0
src/mindpilot/app/knowledge_base/kb_service/faiss_kb_service.py View File

@@ -0,0 +1,136 @@
import os
import shutil
from typing import Dict, List, Tuple

from langchain.docstore.document import Document

from ...configs import SCORE_THRESHOLD
from ..file_rag.utils import get_Retriever
from ..kb_cache.faiss_cache import (
ThreadSafeFaiss,
kb_faiss_pool,
)
from ..kb_service.base import KBService, SupportedVSType
from ..utils import KnowledgeFile, get_kb_path, get_vs_path


class FaissKBService(KBService):
vs_path: str
kb_path: str
vector_name: str = None

def vs_type(self) -> str:
return SupportedVSType.FAISS

def get_vs_path(self):
return get_vs_path(self.kb_name, self.vector_name)

def get_kb_path(self):
return get_kb_path(self.kb_name)

def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(
kb_name=self.kb_name,
vector_name=self.vector_name,
embed_model=self.embed_model,
)

def save_vector_store(self):
self.load_vector_store().save(self.vs_path)

def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
with self.load_vector_store().acquire() as vs:
return [vs.docstore._dict.get(id) for id in ids]

def del_doc_by_ids(self, ids: List[str]) -> bool:
with self.load_vector_store().acquire() as vs:
vs.delete(ids)

def do_init(self):
self.vector_name = self.vector_name or self.embed_model
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()

def do_create_kb(self):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
self.load_vector_store()

def do_drop_kb(self):
self.clear_vs()
try:
shutil.rmtree(self.kb_path)
except Exception:
pass

def do_search(
self,
query: str,
top_k: int,
score_threshold: float = SCORE_THRESHOLD,
) -> List[Tuple[Document, float]]:
with self.load_vector_store().acquire() as vs:
retriever = get_Retriever("ensemble").from_vectorstore(
vs,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs

def do_add_doc(
self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
texts = [x.page_content for x in docs]
metadatas = [x.metadata for x in docs]
with self.load_vector_store().acquire() as vs:
embeddings = vs.embeddings.embed_documents(texts)
ids = vs.add_embeddings(
text_embeddings=zip(texts, embeddings), metadatas=metadatas
)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos

def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.load_vector_store().acquire() as vs:
ids = [
k
for k, v in vs.docstore._dict.items()
if v.metadata.get("source").lower() == kb_file.filename.lower()
]
if len(ids) > 0:
vs.delete(ids)
if not kwargs.get("not_refresh_vs_cache"):
vs.save_local(self.vs_path)
return ids

def do_clear_vs(self):
with kb_faiss_pool.atomic:
kb_faiss_pool.pop((self.kb_name, self.vector_name))
try:
shutil.rmtree(self.vs_path)
except Exception:
...
os.makedirs(self.vs_path, exist_ok=True)

def exist_doc(self, file_name: str):
if super().exist_doc(file_name):
return "in_db"

content_path = os.path.join(self.kb_path, "content")
if os.path.isfile(os.path.join(content_path, file_name)):
return "in_folder"
else:
return False


if __name__ == "__main__":
faissService = FaissKBService("test")
faissService.add_doc(KnowledgeFile("README.md", "test"))
faissService.delete_doc(KnowledgeFile("README.md", "test"))
faissService.do_drop_kb()
print(faissService.search_docs("如何启动api服务"))

+ 125
- 0
src/mindpilot/app/knowledge_base/kb_service/milvus_kb_service.py View File

@@ -0,0 +1,125 @@
import os
from typing import Dict, List, Optional

from langchain.schema import Document
from langchain.vectorstores.milvus import Milvus

from ...configs import kbs_config
from ..db.repository import list_file_num_docs_id_by_kb_name_and_file_name
from ...utils.system_utils import get_Embeddings
from ..file_rag.utils import get_Retriever
from ..kb_service.base import (
KBService,
SupportedVSType,
score_threshold_process,
)
from ..utils import KnowledgeFile


class MilvusKBService(KBService):
milvus: Milvus

@staticmethod
def get_collection(milvus_name):
from pymilvus import Collection

return Collection(milvus_name)

def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
result = []
if self.milvus.col:
# ids = [int(id) for id in ids] # for milvus if needed #pr 2725
data_list = self.milvus.col.query(
expr=f"pk in {[int(_id) for _id in ids]}", output_fields=["*"]
)
for data in data_list:
text = data.pop("text")
result.append(Document(page_content=text, metadata=data))
return result

def del_doc_by_ids(self, ids: List[str]) -> bool:
# pks are ints when auto_id=True, so convert as get_doc_by_ids does
self.milvus.col.delete(expr=f"pk in {[int(_id) for _id in ids]}")
return True

@staticmethod
def search(milvus_name, content, limit=3):
search_params = {
"metric_type": "L2",
"params": {"nprobe": 10},
}
c = MilvusKBService.get_collection(milvus_name)
return c.search(
content, "embeddings", search_params, limit=limit, output_fields=["content"]
)

def do_create_kb(self):
pass

def vs_type(self) -> str:
return SupportedVSType.MILVUS

def _load_milvus(self):
self.milvus = Milvus(
embedding_function=get_Embeddings(self.embed_model),
collection_name=self.kb_name,
connection_args=kbs_config.get("milvus"),
index_params=kbs_config.get("milvus_kwargs")["index_params"],
search_params=kbs_config.get("milvus_kwargs")["search_params"],
auto_id=True,
)

def do_init(self):
self._load_milvus()

def do_drop_kb(self):
if self.milvus.col:
self.milvus.col.release()
self.milvus.col.drop()

def do_search(self, query: str, top_k: int, score_threshold: float):
self._load_milvus()
# embed_func = get_Embeddings(self.embed_model)
# embeddings = embed_func.embed_query(query)
# docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
retriever = get_Retriever("vectorstore").from_vectorstore(
self.milvus,
top_k=top_k,
score_threshold=score_threshold,
)
docs = retriever.get_relevant_documents(query)
return docs

def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs:
for k, v in doc.metadata.items():
doc.metadata[k] = str(v)
for field in self.milvus.fields:
doc.metadata.setdefault(field, "")
doc.metadata.pop(self.milvus._text_field, None)
doc.metadata.pop(self.milvus._vector_field, None)

ids = self.milvus.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
return doc_infos

def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
id_list = list_file_num_docs_id_by_kb_name_and_file_name(
kb_file.kb_name, kb_file.filename
)
if self.milvus.col:
self.milvus.col.delete(expr=f"pk in {id_list}")

# Issue 2846, for windows
# if self.milvus.col:
# file_path = kb_file.filepath.replace("\\", "\\\\")
# file_name = os.path.basename(file_path)
# id_list = [item.get("pk") for item in
# self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
# self.milvus.col.delete(expr=f'pk in {id_list}')

def do_clear_vs(self):
if self.milvus.col:
self.do_drop_kb()
self.do_init()




+ 234
- 0
src/mindpilot/app/knowledge_base/migrate.py View File

@@ -0,0 +1,234 @@
import os
from datetime import datetime
from typing import List, Literal

from dateutil.parser import parse

from ..configs import (
CHUNK_SIZE,
DEFAULT_EMBEDDING_MODEL,
DEFAULT_VS_TYPE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
)

from .db.base import Base, engine
from .db.repository.knowledge_file_repository import (
add_file_to_db,
)

from .db.session import session_scope
from .kb_service.base import (
KBServiceFactory,
SupportedVSType,
)
from .utils import (
KnowledgeFile,
files2docs_in_thread,
get_file_path,
list_files_from_folder,
list_kbs_from_folder,
)


def create_tables():
Base.metadata.create_all(bind=engine)


def reset_tables():
Base.metadata.drop_all(bind=engine)
create_tables()


def import_from_db(
sqlite_path: str = None,
# csv_path: str = None,
) -> bool:
"""
在知识库与向量库无变化的情况下,从备份数据库中导入数据到 info.db。
适用于版本升级时,info.db 结构变化,但无需重新向量化的情况。
请确保两边数据库表名一致,需要导入的字段名一致
当前仅支持 sqlite
"""
import sqlite3 as sql
from pprint import pprint

models = list(Base.registry.mappers)

try:
con = sql.connect(sqlite_path)
con.row_factory = sql.Row
cur = con.cursor()
tables = [
x["name"]
for x in cur.execute(
"select name from sqlite_master where type='table'"
).fetchall()
]
for model in models:
table = model.local_table.fullname
if table not in tables:
continue
print(f"processing table: {table}")
with session_scope() as session:
for row in cur.execute(f"select * from {table}").fetchall():
data = {k: row[k] for k in row.keys() if k in model.columns}
if "create_time" in data:
data["create_time"] = parse(data["create_time"])
pprint(data)
session.add(model.class_(**data))
con.close()
return True
except Exception as e:
print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
return False


def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
kb_files = []
for file in files:
try:
kb_file = KnowledgeFile(filename=file, knowledge_base_name=kb_name)
kb_files.append(kb_file)
except Exception as e:
print(f"{e},已跳过")
return kb_files


def folder2db(
kb_names: List[str],
mode: Literal["recreate_vs", "update_in_db", "increment"],
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
"""
use existed files in local folder to populate database and/or vector store.
set parameter `mode` to:
recreate_vs: recreate all vector store and fill info to database using existed files in local folder
fill_info_only(disabled): do not create vector store, fill info to db using existed files only
update_in_db: update vector store and database info using local files that existed in database only
increment: create vector store and database info for local files that not existed in database only
"""

def files2vs(kb_name: str, kb_files: List[KnowledgeFile]) -> List:
result = []
for success, res in files2docs_in_thread(
kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance,
):
if success:
_, filename, docs = res
print(
f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档"
)
kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kb_file.splited_docs = docs
kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)
result.append({"kb_name": kb_name, "file": filename, "docs": docs})
else:
print(res)
return result

kb_names = kb_names or list_kbs_from_folder()
for kb_name in kb_names:
start = datetime.now()
kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
if not kb.exists():
kb.create_kb()

# Clear the vector store and rebuild it from local files
if mode == "recreate_vs":
kb.clear_vs()
kb.create_kb()
kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
result = files2vs(kb_name, kb_files)
kb.save_vector_store()
# # Store only file metadata in the database without vectorizing the content.
# # Since the database now holds much text-splitting info, storing bare file info is of little value, so this mode was removed.
# files = list_files_from_folder(kb_name)
# kb_files = file_to_kbfile(kb_name, files)
# for kb_file in kb_files:
# add_file_to_db(kb_file)
# print(f"已将 {kb_name}/{kb_file.filename} 添加到数据库")
# 以数据库中文件列表为基准,利用本地文件更新向量库
elif mode == "update_in_db":
files = kb.list_files()
kb_files = file_to_kbfile(kb_name, files)
result = files2vs(kb_name, kb_files)
kb.save_vector_store()
        # compare the local folder with the database file list and vectorize the increment
elif mode == "increment":
db_files = kb.list_files()
folder_files = list_files_from_folder(kb_name)
files = list(set(folder_files) - set(db_files))
kb_files = file_to_kbfile(kb_name, files)
result = files2vs(kb_name, kb_files)
kb.save_vector_store()
else:
print(f"unsupported migrate mode: {mode}")
end = datetime.now()
        kb_path = (
            f"KB path\t\t:{kb.kb_path}\n"
            if kb.vs_type() == SupportedVSType.FAISS
            else ""
        )
        file_count = len(kb_files)
        success_count = len(result)
        docs_count = sum([len(x["docs"]) for x in result])
        print("\n" + "-" * 100)
        print(
            (
                f"KB name\t\t:{kb_name}\n"
                f"VS type\t\t:{kb.vs_type()}\n"
                f"Embed model\t:{kb.embed_model}\n"
            )
            + kb_path
            + (
                f"Total files\t:{file_count}\n"
                f"Files stored\t:{success_count}\n"
                f"Doc chunks\t:{docs_count}\n"
                f"Elapsed\t\t:{end - start}"
            )
        )
        print("-" * 100 + "\n")
return result
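
# Usage sketch (the kb name "samples" is an assumption): a full rebuild,
# followed by an incremental pass that picks up files added to the folder later.
folder2db(kb_names=["samples"], mode="recreate_vs")
folder2db(kb_names=["samples"], mode="increment")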


def prune_db_docs(kb_names: List[str]):
"""
delete docs in database that not existed in local folder.
it is used to delete database docs after user deleted some doc files in file browser
"""
for kb_name in kb_names:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is not None:
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_db) - set(files_in_folder))
kb_files = file_to_kbfile(kb_name, files)
for kb_file in kb_files:
kb.delete_doc(kb_file, not_refresh_vs_cache=True)
print(f"success to delete docs for file: {kb_name}/{kb_file.filename}")
kb.save_vector_store()


def prune_folder_files(kb_names: List[str]):
"""
delete doc files in local folder that not existed in database.
it is used to free local disk space by delete unused doc files.
"""
for kb_name in kb_names:
kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is not None:
files_in_db = kb.list_files()
files_in_folder = list_files_from_folder(kb_name)
files = list(set(files_in_folder) - set(files_in_db))
for file in files:
os.remove(get_file_path(kb_name, file))
print(f"success to delete file: {kb_name}/{file}")

+10 -0 src/mindpilot/app/knowledge_base/model/kb_document_model.py

@@ -0,0 +1,10 @@
from typing import Optional

from langchain.docstore.document import Document


class DocumentWithVSId(Document):
"""
矢量化后的文档
"""

    id: Optional[str] = None
score: float = 3.0

+483 -0 src/mindpilot/app/knowledge_base/utils.py

@@ -0,0 +1,483 @@
import importlib
import json
import logging
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, Generator, List, Tuple, Union

import chardet
import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import MarkdownHeaderTextSplitter, TextSplitter
from langchain_community.document_loaders import JSONLoader, TextLoader

from ..configs import (
CHUNK_SIZE,
KB_ROOT_PATH,
OVERLAP_SIZE,
TEXT_SPLITTER_NAME,
ZH_TITLE_ENHANCE,
text_splitter_dict,
)
from .file_rag.text_splitter import (
zh_title_enhance as func_zh_title_enhance,
)
from ..utils.system_utils import run_in_thread_pool

logger = logging.getLogger()


def validate_kb_name(knowledge_base_id: str) -> bool:
    # check for unexpected characters or path-traversal patterns
if "../" in knowledge_base_id:
return False
return True


def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str, vector_name: str):
return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)


def get_file_path(knowledge_base_name: str, doc_name: str):
doc_path = Path(get_doc_path(knowledge_base_name)).resolve()
file_path = (doc_path / doc_name).resolve()
if str(file_path).startswith(str(doc_path)):
return str(file_path)
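
# Illustration of the traversal guard above (kb name assumed): a name that
# resolves outside the content directory returns None instead of a path.
assert get_file_path("samples", "../../etc/passwd") is None
assert get_file_path("samples", "notes/readme.md") is not None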


def list_kbs_from_folder():
return [
f
for f in os.listdir(KB_ROOT_PATH)
if os.path.isdir(os.path.join(KB_ROOT_PATH, f))
]


def list_files_from_folder(kb_name: str):
doc_path = get_doc_path(kb_name)
result = []

    def is_skipped_path(path: str):
tail = os.path.basename(path).lower()
for x in ["temp", "tmp", ".", "~$"]:
if tail.startswith(x):
return True
return False

def process_entry(entry):
        if is_skipped_path(entry.path):
return

if entry.is_symlink():
target_path = os.path.realpath(entry.path)
with os.scandir(target_path) as target_it:
for target_entry in target_it:
process_entry(target_entry)
elif entry.is_file():
            file_path = Path(
                os.path.relpath(entry.path, doc_path)
            ).as_posix()  # normalize paths to posix format
result.append(file_path)
elif entry.is_dir():
with os.scandir(entry.path) as it:
for sub_entry in it:
process_entry(sub_entry)

with os.scandir(doc_path) as it:
for entry in it:
process_entry(entry)

return result


LOADER_DICT = {
"UnstructuredHTMLLoader": [".html", ".htm"],
"MHTMLLoader": [".mhtml"],
"TextLoader": [".md"],
"UnstructuredMarkdownLoader": [".md"],
"JSONLoader": [".json"],
"JSONLinesLoader": [".jsonl"],
"CSVLoader": [".csv"],
# "FilteredCSVLoader": [".csv"], 如果使用自定义分割csv
"RapidOCRPDFLoader": [".pdf"],
"RapidOCRDocLoader": [".docx", ".doc"],
"RapidOCRPPTLoader": [
".ppt",
".pptx",
],
"RapidOCRLoader": [".png", ".jpg", ".jpeg", ".bmp"],
"UnstructuredFileLoader": [
".eml",
".msg",
".rst",
".rtf",
".txt",
".xml",
".epub",
".odt",
".tsv",
],
"UnstructuredEmailLoader": [".eml", ".msg"],
"UnstructuredEPubLoader": [".epub"],
"UnstructuredExcelLoader": [".xlsx", ".xls", ".xlsd"],
"NotebookLoader": [".ipynb"],
"UnstructuredODTLoader": [".odt"],
"PythonLoader": [".py"],
"UnstructuredRSTLoader": [".rst"],
"UnstructuredRTFLoader": [".rtf"],
"SRTLoader": [".srt"],
"TomlLoader": [".toml"],
"UnstructuredTSVLoader": [".tsv"],
"UnstructuredWordDocumentLoader": [".docx", ".doc"],
"UnstructuredXMLLoader": [".xml"],
"UnstructuredPowerPointLoader": [".ppt", ".pptx"],
"EverNoteLoader": [".enex"],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]


# patch json.dumps to disable ensure_ascii
def _new_json_dumps(obj, **kwargs):
kwargs["ensure_ascii"] = False
return _origin_json_dumps(obj, **kwargs)


if json.dumps is not _new_json_dumps:
_origin_json_dumps = json.dumps
json.dumps = _new_json_dumps
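
# Effect of the patch: non-ASCII text is emitted verbatim instead of \uXXXX
# escapes, keeping Chinese metadata readable in logs and API payloads.
print(json.dumps({"title": "知识库"}))  # -> {"title": "知识库"}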


class JSONLinesLoader(JSONLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._json_lines = True


langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader


def get_LoaderClass(file_extension):
for LoaderClass, extensions in LOADER_DICT.items():
if file_extension in extensions:
return LoaderClass


def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
"""
根据loader_name和文件路径或内容返回文档加载器。
"""
loader_kwargs = loader_kwargs or {}
try:
if loader_name in [
"RapidOCRPDFLoader",
"RapidOCRLoader",
"FilteredCSVLoader",
"RapidOCRDocLoader",
"RapidOCRPPTLoader",
]:
            # assumes the custom loaders ship in this package's
            # file_rag.document_loaders (see the file list in this commit)
            document_loaders_module = importlib.import_module(
                ".file_rag.document_loaders", package=__package__
            )
else:
document_loaders_module = importlib.import_module(
"langchain_community.document_loaders"
)
DocumentLoader = getattr(document_loaders_module, loader_name)
except Exception as e:
msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
logger.error(
f"{e.__class__.__name__}: {msg}", exc_info=e
)
document_loaders_module = importlib.import_module(
"langchain_community.document_loaders"
)
DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

if loader_name == "UnstructuredFileLoader":
loader_kwargs.setdefault("autodetect_encoding", True)
elif loader_name == "CSVLoader":
if not loader_kwargs.get("encoding"):
            # if encoding is not specified, auto-detect it to avoid langchain loader encoding errors
with open(file_path, "rb") as struct_file:
encode_detect = chardet.detect(struct_file.read())
if encode_detect is None:
encode_detect = {"encoding": "utf-8"}
loader_kwargs["encoding"] = encode_detect["encoding"]

elif loader_name == "JSONLoader":
loader_kwargs.setdefault("jq_schema", ".")
loader_kwargs.setdefault("text_content", False)
elif loader_name == "JSONLinesLoader":
loader_kwargs.setdefault("jq_schema", ".")
loader_kwargs.setdefault("text_content", False)

loader = DocumentLoader(file_path, **loader_kwargs)
return loader
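
# Usage sketch (hypothetical file path): loader names are normally obtained
# from get_LoaderClass via LOADER_DICT rather than hard-coded.
loader = get_loader(
    loader_name="TextLoader",
    file_path="docs/notes.md",
    loader_kwargs={"autodetect_encoding": True},
)
docs = loader.load()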


@lru_cache()
def make_text_splitter(splitter_name, chunk_size, chunk_overlap):
"""
根据参数获取特定的分词器
"""
splitter_name = splitter_name or "SpacyTextSplitter"
try:
        if (
            splitter_name == "MarkdownHeaderTextSplitter"
        ):  # special-case MarkdownHeaderTextSplitter
headers_to_split_on = text_splitter_dict[splitter_name][
"headers_to_split_on"
]
text_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on, strip_headers=False
)
else:
            try:  # prefer a user-defined text splitter from this package
                # assumes this module lives in the knowledge_base package,
                # whose file_rag.text_splitter exports the custom splitters
                text_splitter_module = importlib.import_module(
                    ".file_rag.text_splitter", package=__package__
                )
                TextSplitter = getattr(text_splitter_module, splitter_name)
            except Exception:  # otherwise fall back to langchain's text splitters
                text_splitter_module = importlib.import_module(
                    "langchain.text_splitter"
                )
                TextSplitter = getattr(text_splitter_module, splitter_name)

            if (
                text_splitter_dict[splitter_name]["source"] == "tiktoken"
            ):  # load via tiktoken
try:
text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name][
"tokenizer_name_or_path"
],
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
                except Exception:
text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name][
"tokenizer_name_or_path"
],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
            elif (
                text_splitter_dict[splitter_name]["source"] == "huggingface"
            ):  # load via a huggingface tokenizer
if (
text_splitter_dict[splitter_name]["tokenizer_name_or_path"]
== "gpt2"
):
from langchain.text_splitter import CharacterTextSplitter
from mindnlp.transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
                else:  # load by character length
from mindnlp.transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
trust_remote_code=True,
)
text_splitter = TextSplitter.from_huggingface_tokenizer(
tokenizer=tokenizer,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
else:
try:
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
            except Exception:
text_splitter = TextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
except Exception as e:
print(e)
text_splitter_module = importlib.import_module("langchain.text_splitter")
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # If you use SpacyTextSplitter you can use a GPU for splitting, as in Issue #1287
# text_splitter._tokenizer.max_length = 37016792
# text_splitter._tokenizer.prefer_gpu()
return text_splitter
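
# Usage sketch: the arguments are hashable, so lru_cache returns a shared
# splitter instance for identical (name, chunk_size, chunk_overlap) triples.
ts_a = make_text_splitter("RecursiveCharacterTextSplitter", 250, 50)
ts_b = make_text_splitter("RecursiveCharacterTextSplitter", 250, 50)
assert ts_a is ts_b
chunks = ts_a.split_text("A long passage of text to be chunked ...")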


class KnowledgeFile:
    def __init__(
        self,
        filename: str,
        knowledge_base_name: str,
        loader_kwargs: Dict = None,
    ):
        """
        Corresponds to a file in the knowledge base directory. The file must
        exist on disk before vectorization or related operations.
        """
        self.kb_name = knowledge_base_name
        self.filename = str(Path(filename).as_posix())
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"Unsupported file format: {self.filename}")
        self.loader_kwargs = loader_kwargs or {}
self.filepath = get_file_path(knowledge_base_name, filename)
self.docs = None
self.splited_docs = None
self.document_loader_name = get_LoaderClass(self.ext)
self.text_splitter_name = TEXT_SPLITTER_NAME

def file2docs(self, refresh: bool = False):
if self.docs is None or refresh:
logger.info(f"{self.document_loader_name} used for {self.filepath}")
loader = get_loader(
loader_name=self.document_loader_name,
file_path=self.filepath,
loader_kwargs=self.loader_kwargs,
)
if isinstance(loader, TextLoader):
loader.encoding = "utf8"
self.docs = loader.load()
else:
self.docs = loader.load()
return self.docs

def docs2texts(
self,
docs: List[Document] = None,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
docs = docs or self.file2docs(refresh=refresh)
if not docs:
return []
if self.ext not in [".csv"]:
if text_splitter is None:
text_splitter = make_text_splitter(
splitter_name=self.text_splitter_name,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
docs = text_splitter.split_text(docs[0].page_content)
else:
docs = text_splitter.split_documents(docs)

if not docs:
return []

print(f"文档切分示例:{docs[0]}")
if zh_title_enhance:
docs = func_zh_title_enhance(docs)
self.splited_docs = docs
return self.splited_docs

def file2text(
self,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
refresh: bool = False,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
text_splitter: TextSplitter = None,
):
if self.splited_docs is None or refresh:
docs = self.file2docs()
self.splited_docs = self.docs2texts(
docs=docs,
zh_title_enhance=zh_title_enhance,
refresh=refresh,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
text_splitter=text_splitter,
)
return self.splited_docs

def file_exist(self):
return os.path.isfile(self.filepath)

def get_mtime(self):
return os.path.getmtime(self.filepath)

def get_size(self):
return os.path.getsize(self.filepath)


def files2docs_in_thread_file2docs(
*, file: KnowledgeFile, **kwargs
) -> Tuple[bool, Tuple[str, str, List[Document]]]:
try:
return True, (file.kb_name, file.filename, file.file2text(**kwargs))
except Exception as e:
msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
logger.error(
f"{e.__class__.__name__}: {msg}", exc_info=e
)
return False, (file.kb_name, file.filename, msg)


def files2docs_in_thread(
files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
) -> Generator:
"""
利用多线程批量将磁盘文件转化成langchain Document.
如果传入参数是Tuple,形式为(filename, kb_name)
生成器返回值为 status, (kb_name, file_name, docs | error)
"""

kwargs_list = []
for i, file in enumerate(files):
kwargs = {}
try:
if isinstance(file, tuple) and len(file) >= 2:
filename = file[0]
kb_name = file[1]
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
elif isinstance(file, dict):
filename = file.pop("filename")
kb_name = file.pop("kb_name")
kwargs.update(file)
file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
kwargs["file"] = file
kwargs["chunk_size"] = chunk_size
kwargs["chunk_overlap"] = chunk_overlap
kwargs["zh_title_enhance"] = zh_title_enhance
kwargs_list.append(kwargs)
except Exception as e:
yield False, (kb_name, filename, str(e))

for result in run_in_thread_pool(
func=files2docs_in_thread_file2docs, params=kwargs_list
):
yield result
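
# Consumption sketch (file/kb names assumed): branch on the yielded status.
for success, res in files2docs_in_thread([("readme.md", "samples")]):
    if success:
        kb_name, filename, docs = res
        print(f"{kb_name}/{filename}: {len(docs)} chunks")
    else:
        print("failed:", res)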


if __name__ == "__main__":
from pprint import pprint

kb_file = KnowledgeFile(
filename="E:\\LLM\\Data\\Test.md", knowledge_base_name="samples"
)
# kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
kb_file.text_splitter_name = "MarkdownHeaderTextSplitter"
docs = kb_file.file2docs()
# pprint(docs[-1])
texts = kb_file.docs2texts(docs)
for text in texts:
print(text)

+56 -23 src/mindpilot/app/utils/system_utils.py

@@ -1,12 +1,7 @@
import asyncio
import logging
import multiprocessing as mp
import os
import socket
import sqlite3
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
Any,
Awaitable,
@@ -14,29 +9,14 @@ from typing import (
Dict,
Generator,
List,
Literal,
Optional,
Tuple,
Union,
)

import httpx
import openai
from fastapi import FastAPI
from langchain.tools import BaseTool
from langchain_core.embeddings import Embeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.llms import OpenAI

# from chatchat.configs import (
# DEFAULT_EMBEDDING_MODEL,
# DEFAULT_LLM_MODEL,
# HTTPX_DEFAULT_TIMEOUT,
# MODEL_PLATFORMS,
# TEMPERATURE,
# log_verbose,
# )
from .pydantic_v2 import BaseModel, Field
from ..configs import DEFAULT_EMBEDDING_MODEL
from langchain_core.embeddings import Embeddings

logger = logging.getLogger()

@@ -209,7 +189,60 @@ class ListResponse(BaseResponse):
}
}


def get_mindpilot_db_connection():
conn = sqlite3.connect('mindpilot.db')
conn.row_factory = sqlite3.Row
return conn


def run_in_thread_pool(
    func: Callable,
    params: List[Dict] = None,
) -> Generator:
    """
    Run tasks in batch in a thread pool and yield the results as a generator.
    Make sure every operation inside a task is thread-safe; task functions
    should take keyword arguments only.
    """
    tasks = []
    with ThreadPoolExecutor() as pool:
        for kwargs in params or []:
            tasks.append(pool.submit(func, **kwargs))

for obj in as_completed(tasks):
try:
yield obj.result()
except Exception as e:
logger.error(f"error in sub thread: {e}", exc_info=True)

def get_Embeddings(
embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> Embeddings:

from ..knowledge_base.embedding.localai_embeddings import (
LocalAIEmbeddings,
)

    params = dict(model=embed_model)
    try:
        params.update(
            openai_api_base="http://127.0.0.1:7890/v1",
            openai_api_key="EMPTY",
        )
        return LocalAIEmbeddings(**params)
except Exception as e:
logger.error(
f"failed to create Embeddings for model: {embed_model}.", exc_info=True
)


def check_embed_model(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> bool:
embeddings = get_Embeddings(embed_model=embed_model)
try:
embeddings.embed_query("this is a test")
return True
except Exception as e:
logger.error(
f"failed to access embed model '{embed_model}': {e}", exc_info=True
)
return False
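
# Startup-probe sketch: embed one test string before doing any KB work.
if not check_embed_model():
    raise SystemExit("embedding model unreachable; start the local embedding server first")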

+4 -0 src/mindpilot/main.py

@@ -120,6 +120,10 @@ def main():
print亮蓝(f"当前工作目录:{cwd}")
print亮蓝(f"OpenAPI 文档地址:http://{HOST}:{PORT}/docs")

from app.knowledge_base.migrate import create_tables

create_tables()

if sys.version_info < (3, 10):
loop = asyncio.get_event_loop()
else:

