
utils.py 17 kB

import importlib
import json
import logging
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, Generator, List, Tuple, Union

import chardet
import langchain_community.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import MarkdownHeaderTextSplitter, TextSplitter
from langchain_community.document_loaders import JSONLoader, TextLoader

from ..configs import (
    CHUNK_SIZE,
    KB_ROOT_PATH,
    OVERLAP_SIZE,
    TEXT_SPLITTER_NAME,
    ZH_TITLE_ENHANCE,
    text_splitter_dict,
)
from .file_rag.text_splitter import (
    zh_title_enhance as func_zh_title_enhance,
)
from ..utils.system_utils import run_in_thread_pool

logger = logging.getLogger()
def validate_kb_name(knowledge_base_id: str) -> bool:
    # Reject names containing unexpected characters or path-traversal patterns
    if "../" in knowledge_base_id:
        return False
    return True


def get_kb_path(knowledge_base_name: str):
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str, vector_name: str):
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)


def get_file_path(knowledge_base_name: str, doc_name: str):
    doc_path = Path(get_doc_path(knowledge_base_name)).resolve()
    file_path = (doc_path / doc_name).resolve()
    if str(file_path).startswith(str(doc_path)):
        return str(file_path)
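

# Usage sketch (hypothetical knowledge base name): get_file_path() resolves the
# target inside the knowledge base's content directory and returns None when
# the resolved path escapes it:
#
#     get_file_path("samples", "readme.md")         # <KB_ROOT_PATH>/samples/content/readme.md
#     get_file_path("samples", "../../etc/passwd")  # None (traversal rejected)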
def list_kbs_from_folder():
    return [
        f
        for f in os.listdir(KB_ROOT_PATH)
        if os.path.isdir(os.path.join(KB_ROOT_PATH, f))
    ]


def list_files_from_folder(kb_name: str):
    doc_path = get_doc_path(kb_name)
    result = []

    def is_skipped_path(path: str):
        tail = os.path.basename(path).lower()
        for x in ["temp", "tmp", ".", "~$"]:
            if tail.startswith(x):
                return True
        return False

    def process_entry(entry):
        if is_skipped_path(entry.path):
            return
        if entry.is_symlink():
            target_path = os.path.realpath(entry.path)
            with os.scandir(target_path) as target_it:
                for target_entry in target_it:
                    process_entry(target_entry)
        elif entry.is_file():
            file_path = Path(
                os.path.relpath(entry.path, doc_path)
            ).as_posix()  # normalize the path to POSIX format
            result.append(file_path)
        elif entry.is_dir():
            with os.scandir(entry.path) as it:
                for sub_entry in it:
                    process_entry(sub_entry)

    with os.scandir(doc_path) as it:
        for entry in it:
            process_entry(entry)
    return result
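

# Usage sketch (hypothetical layout): if samples/content holds "a.md" and
# "sub/b.pdf", list_files_from_folder("samples") returns ["a.md", "sub/b.pdf"]
# as POSIX-style relative paths, skipping temp and hidden entries such as
# ".DS_Store" or "~$draft.docx".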
LOADER_DICT = {
    "UnstructuredHTMLLoader": [".html", ".htm"],
    "MHTMLLoader": [".mhtml"],
    "TextLoader": [".md"],
    "UnstructuredMarkdownLoader": [".md"],
    "JSONLoader": [".json"],
    "JSONLinesLoader": [".jsonl"],
    "CSVLoader": [".csv"],
    # "FilteredCSVLoader": [".csv"],  # enable when CSV files need custom splitting
    "RapidOCRPDFLoader": [".pdf"],
    "RapidOCRDocLoader": [".docx", ".doc"],
    "RapidOCRPPTLoader": [
        ".ppt",
        ".pptx",
    ],
    "RapidOCRLoader": [".png", ".jpg", ".jpeg", ".bmp"],
    "UnstructuredFileLoader": [
        ".eml",
        ".msg",
        ".rst",
        ".rtf",
        ".txt",
        ".xml",
        ".epub",
        ".odt",
        ".tsv",
    ],
    "UnstructuredEmailLoader": [".eml", ".msg"],
    "UnstructuredEPubLoader": [".epub"],
    "UnstructuredExcelLoader": [".xlsx", ".xls", ".xlsd"],
    "NotebookLoader": [".ipynb"],
    "UnstructuredODTLoader": [".odt"],
    "PythonLoader": [".py"],
    "UnstructuredRSTLoader": [".rst"],
    "UnstructuredRTFLoader": [".rtf"],
    "SRTLoader": [".srt"],
    "TomlLoader": [".toml"],
    "UnstructuredTSVLoader": [".tsv"],
    "UnstructuredWordDocumentLoader": [".docx", ".doc"],
    "UnstructuredXMLLoader": [".xml"],
    "UnstructuredPowerPointLoader": [".ppt", ".pptx"],
    "EverNoteLoader": [".enex"],
}
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
# patch json.dumps to disable ensure_ascii
def _new_json_dumps(obj, **kwargs):
    kwargs["ensure_ascii"] = False
    return _origin_json_dumps(obj, **kwargs)


if json.dumps is not _new_json_dumps:
    _origin_json_dumps = json.dumps
    json.dumps = _new_json_dumps


class JSONLinesLoader(JSONLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._json_lines = True


langchain_community.document_loaders.JSONLinesLoader = JSONLinesLoader
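

# Usage sketch (hypothetical file path): with the registration above,
# JSONLinesLoader behaves like JSONLoader but parses one JSON object per line:
#
#     loader = JSONLinesLoader("/path/to/records.jsonl", jq_schema=".", text_content=False)
#     docs = loader.load()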
def get_LoaderClass(file_extension):
    for LoaderClass, extensions in LOADER_DICT.items():
        if file_extension in extensions:
            return LoaderClass
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
    """
    Return a document loader based on loader_name and the file path (or content).
    """
    loader_kwargs = loader_kwargs or {}
    try:
        if loader_name in [
            "RapidOCRPDFLoader",
            "RapidOCRLoader",
            "FilteredCSVLoader",
            "RapidOCRDocLoader",
            "RapidOCRPPTLoader",
        ]:
            document_loaders_module = importlib.import_module(
                "chatchat.server.file_rag.document_loaders"
            )
        else:
            document_loaders_module = importlib.import_module(
                "langchain_community.document_loaders"
            )
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        msg = f"Error finding loader {loader_name} for file {file_path}: {e}"
        logger.error(
            f"{e.__class__.__name__}: {msg}", exc_info=e
        )
        document_loaders_module = importlib.import_module(
            "langchain_community.document_loaders"
        )
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")
    if loader_name == "UnstructuredFileLoader":
        loader_kwargs.setdefault("autodetect_encoding", True)
    elif loader_name == "CSVLoader":
        if not loader_kwargs.get("encoding"):
            # If no encoding is specified, detect it automatically to avoid
            # encoding errors when the langchain loader reads the file.
            with open(file_path, "rb") as struct_file:
                encode_detect = chardet.detect(struct_file.read())
            if encode_detect is None:
                encode_detect = {"encoding": "utf-8"}
            loader_kwargs["encoding"] = encode_detect["encoding"]
    elif loader_name == "JSONLoader":
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)
    elif loader_name == "JSONLinesLoader":
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)
    loader = DocumentLoader(file_path, **loader_kwargs)
    return loader
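

# Usage sketch (hypothetical path): pick the loader class name by extension,
# then instantiate it for a concrete file:
#
#     loader_name = get_LoaderClass(".md")  # "TextLoader" (first match in LOADER_DICT)
#     loader = get_loader(loader_name, "/path/to/doc.md")
#     docs = loader.load()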
@lru_cache()
def make_text_splitter(splitter_name, chunk_size, chunk_overlap):
    """
    Build the requested text splitter from its name and chunking parameters.
    """
    splitter_name = splitter_name or "SpacyTextSplitter"
    try:
        if (
            splitter_name == "MarkdownHeaderTextSplitter"
        ):  # MarkdownHeaderTextSplitter is a special case
            headers_to_split_on = text_splitter_dict[splitter_name][
                "headers_to_split_on"
            ]
            text_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=headers_to_split_on, strip_headers=False
            )
        else:
            try:  # prefer a user-defined text_splitter
                text_splitter_module = importlib.import_module(
                    "chatchat.server.file_rag.text_splitter"
                )
                TextSplitter = getattr(text_splitter_module, splitter_name)
            except Exception:  # otherwise fall back to langchain's text_splitter
                text_splitter_module = importlib.import_module(
                    "langchain.text_splitter"
                )
                TextSplitter = getattr(text_splitter_module, splitter_name)
            if (
                text_splitter_dict[splitter_name]["source"] == "tiktoken"
            ):  # load from tiktoken
                try:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name][
                            "tokenizer_name_or_path"
                        ],
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
                except Exception:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name][
                            "tokenizer_name_or_path"
                        ],
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
            elif (
                text_splitter_dict[splitter_name]["source"] == "huggingface"
            ):  # load from huggingface
                if (
                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"]
                    == "gpt2"
                ):
                    from langchain.text_splitter import CharacterTextSplitter
                    from mindnlp.transformers import GPT2TokenizerFast

                    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
                else:  # split by character length
                    from mindnlp.transformers import AutoTokenizer

                    tokenizer = AutoTokenizer.from_pretrained(
                        text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        trust_remote_code=True,
                    )
                text_splitter = TextSplitter.from_huggingface_tokenizer(
                    tokenizer=tokenizer,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            else:
                try:
                    text_splitter = TextSplitter(
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap,
                    )
                except Exception:
                    text_splitter = TextSplitter(
                        chunk_size=chunk_size, chunk_overlap=chunk_overlap
                    )
    except Exception as e:
        logger.error(f"Failed to create text splitter {splitter_name}: {e}")
        text_splitter_module = importlib.import_module("langchain.text_splitter")
        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
        text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    # If you use SpacyTextSplitter, you can split on GPU, as in Issue #1287:
    # text_splitter._tokenizer.max_length = 37016792
    # text_splitter._tokenizer.prefer_gpu()
    return text_splitter
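

# Usage sketch (illustrative only): make_text_splitter is memoized by
# lru_cache, so repeated calls with the same arguments return the same
# splitter instance:
#
#     splitter = make_text_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE)
#     chunks = splitter.split_documents(docs)  # docs: List[Document]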
class KnowledgeFile:
    def __init__(
        self,
        filename: str,
        knowledge_base_name: str,
        loader_kwargs: Dict = None,
    ):
        """
        A file in the knowledge base directory. It must exist on disk before
        it can be vectorized or otherwise processed.
        """
        self.kb_name = knowledge_base_name
        self.filename = str(Path(filename).as_posix())
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"Unsupported file format {self.filename}")
        self.loader_kwargs = loader_kwargs or {}
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.splited_docs = None
        self.document_loader_name = get_LoaderClass(self.ext)
        self.text_splitter_name = TEXT_SPLITTER_NAME

    def file2docs(self, refresh: bool = False):
        if self.docs is None or refresh:
            logger.info(f"{self.document_loader_name} used for {self.filepath}")
            loader = get_loader(
                loader_name=self.document_loader_name,
                file_path=self.filepath,
                loader_kwargs=self.loader_kwargs,
            )
            if isinstance(loader, TextLoader):
                loader.encoding = "utf8"
            self.docs = loader.load()
        return self.docs

    def docs2texts(
        self,
        docs: List[Document] = None,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
        refresh: bool = False,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        text_splitter: TextSplitter = None,
    ):
        docs = docs or self.file2docs(refresh=refresh)
        if not docs:
            return []
        if self.ext not in [".csv"]:
            if text_splitter is None:
                text_splitter = make_text_splitter(
                    splitter_name=self.text_splitter_name,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                )
            if self.text_splitter_name == "MarkdownHeaderTextSplitter":
                docs = text_splitter.split_text(docs[0].page_content)
            else:
                docs = text_splitter.split_documents(docs)
        if not docs:
            return []
        print(f"Sample of split documents: {docs[0]}")
        if zh_title_enhance:
            docs = func_zh_title_enhance(docs)
        self.splited_docs = docs
        return self.splited_docs

    def file2text(
        self,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
        refresh: bool = False,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        text_splitter: TextSplitter = None,
    ):
        if self.splited_docs is None or refresh:
            docs = self.file2docs()
            self.splited_docs = self.docs2texts(
                docs=docs,
                zh_title_enhance=zh_title_enhance,
                refresh=refresh,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                text_splitter=text_splitter,
            )
        return self.splited_docs

    def file_exist(self):
        return os.path.isfile(self.filepath)

    def get_mtime(self):
        return os.path.getmtime(self.filepath)

    def get_size(self):
        return os.path.getsize(self.filepath)
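

# Usage sketch (hypothetical file; assumes it exists under the "samples"
# knowledge base's content directory): file2text() runs load + split in one
# step and caches the chunks on the instance:
#
#     kb_file = KnowledgeFile(filename="FAQ.md", knowledge_base_name="samples")
#     chunks = kb_file.file2text(chunk_size=250, chunk_overlap=50)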
def files2docs_in_thread_file2docs(
    *, file: KnowledgeFile, **kwargs
) -> Tuple[bool, Tuple[str, str, List[Document]]]:
    try:
        return True, (file.kb_name, file.filename, file.file2text(**kwargs))
    except Exception as e:
        msg = f"Error loading documents from file {file.kb_name}/{file.filename}: {e}"
        logger.error(
            f"{e.__class__.__name__}: {msg}", exc_info=e
        )
        return False, (file.kb_name, file.filename, msg)
def files2docs_in_thread(
    files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
    chunk_size: int = CHUNK_SIZE,
    chunk_overlap: int = OVERLAP_SIZE,
    zh_title_enhance: bool = ZH_TITLE_ENHANCE,
) -> Generator:
    """
    Convert files on disk into langchain Documents using a thread pool.
    If an item is passed as a Tuple, it takes the form (filename, kb_name).
    The generator yields status, (kb_name, file_name, docs | error).
    """
    kwargs_list = []
    for file in files:
        kwargs = {}
        filename = kb_name = None  # predefine so the except clause cannot raise NameError
        try:
            if isinstance(file, tuple) and len(file) >= 2:
                filename = file[0]
                kb_name = file[1]
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, dict):
                filename = file.pop("filename")
                kb_name = file.pop("kb_name")
                kwargs.update(file)
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            kwargs["file"] = file
            kwargs["chunk_size"] = chunk_size
            kwargs["chunk_overlap"] = chunk_overlap
            kwargs["zh_title_enhance"] = zh_title_enhance
            kwargs_list.append(kwargs)
        except Exception as e:
            yield False, (kb_name, filename, str(e))
    for result in run_in_thread_pool(
        func=files2docs_in_thread_file2docs, params=kwargs_list
    ):
        yield result
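

# Usage sketch (hypothetical file names): mixed input forms are accepted; each
# yielded item is (status, (kb_name, file_name, docs_or_error_message)):
#
#     for ok, (kb, name, result) in files2docs_in_thread(
#         [("a.md", "samples"), {"filename": "b.pdf", "kb_name": "samples"}]
#     ):
#         print(name, len(result) if ok else result)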
if __name__ == "__main__":
    from pprint import pprint

    kb_file = KnowledgeFile(
        filename="E:\\LLM\\Data\\Test.md", knowledge_base_name="samples"
    )
    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
    kb_file.text_splitter_name = "MarkdownHeaderTextSplitter"
    docs = kb_file.file2docs()
    # pprint(docs[-1])
    texts = kb_file.docs2texts(docs)
    for text in texts:
        print(text)

MindPilot is a cross-platform, multi-functional intelligent Agent desktop assistant designed to provide users with convenient and efficient intelligent solutions. With an advanced large language model as its core decision engine, MindPilot can precisely decompose, plan, execute, reflect on, and summarize user tasks, ensuring they are completed efficiently. It also offers highly customizable Agents: users can define Agents with different personas as needed, handling diverse task scenarios and delivering personalized intelligent services. On the foundation of MindSpore and MindNLP ...