From a8b4d9288bc117fd540122da37d6161f2047fcee Mon Sep 17 00:00:00 2001 From: gjl <2802427218@qq.com> Date: Wed, 11 Sep 2024 20:48:58 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mindpilot/app/chat/chat.py | 24 ++- src/mindpilot/app/configs/__init__.py | 1 + src/mindpilot/app/configs/kb_config.py | 2 + .../app/conversation/conversation_api.py | 188 ++++++++++-------- 4 files changed, 135 insertions(+), 80 deletions(-) diff --git a/src/mindpilot/app/chat/chat.py b/src/mindpilot/app/chat/chat.py index 80f894a..dc3bfd8 100644 --- a/src/mindpilot/app/chat/chat.py +++ b/src/mindpilot/app/chat/chat.py @@ -16,10 +16,14 @@ from ..callback_handler.agent_callback_handler import ( AgentStatus, ) from ..chat.utils import History -from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT, PROMPT_TEMPLATES +from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT, PROMPT_TEMPLATES, CACHE_DIR from ..utils.system_utils import get_ChatOpenAI, get_tool, wrap_done, MsgType, get_mindpilot_db_connection from ..agent.utils import get_agent_from_id +from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer +import mindspore +import time + def create_models_from_config(configs, callbacks, stream): configs = configs @@ -502,3 +506,21 @@ async def debug_chat_online( ret.append(data) return ret + + +async def chat_outline( + content: str, + history: List[History], + chat_model_config: dict, +): + model_name = next(iter(chat_model_config["llm_model"])) + temperature = chat_model_config["llm_model"][model_name]["temperature"] + max_tokens = chat_model_config["llm_model"][model_name]["max_tokens"] + + path = 'openbmb/MiniCPM-2B-dpo-bf16' + tokenizer = AutoTokenizer.from_pretrained(path, cache_dir=CACHE_DIR) + model = AutoModelForCausalLM.from_pretrained(path, ms_dtype=mindspore.float16, cache_dir=CACHE_DIR) + + response, history = model.chat(tokenizer, content, history=history, temperature=temperature, top_p=0.9, + repetition_penalty=1.02) + return response diff --git a/src/mindpilot/app/configs/__init__.py b/src/mindpilot/app/configs/__init__.py index a5e2369..242e6dd 100644 --- a/src/mindpilot/app/configs/__init__.py +++ b/src/mindpilot/app/configs/__init__.py @@ -33,4 +33,5 @@ __all__ = [ "TEXT_SPLITTER_NAME", "EMBEDDING_KEYWORD_FILE", "DEFAULT_EMBEDDING_MODEL", + "CACHE_DIR", ] \ No newline at end of file diff --git a/src/mindpilot/app/configs/kb_config.py b/src/mindpilot/app/configs/kb_config.py index 9c7da75..bd1d652 100644 --- a/src/mindpilot/app/configs/kb_config.py +++ b/src/mindpilot/app/configs/kb_config.py @@ -45,6 +45,8 @@ KB_INFO = { } CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent.parent) +CACHE_DIR = str(Path(__file__).absolute().parent.parent.parent.parent.parent) +CACHE_DIR = os.path.join(CACHE_DIR, "cache") KB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "knowledge_base") diff --git a/src/mindpilot/app/conversation/conversation_api.py b/src/mindpilot/app/conversation/conversation_api.py index f65798c..1f9a08e 100644 --- a/src/mindpilot/app/conversation/conversation_api.py +++ b/src/mindpilot/app/conversation/conversation_api.py @@ -12,7 +12,7 @@ from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_co from .message import init_messages_table, insert_message, split_message_content from ..model_configs.utils import get_config_from_id from ..agent.utils import get_agent_from_id -from ..chat.chat import chat_online, debug_chat_online +from ..chat.chat import chat_online, debug_chat_online, chat_outline def init_conversations_table(): @@ -209,8 +209,6 @@ async def send_messages( "content": row['content'] }) - # print(history) - # 存放用户输入 _, timestamp_user = insert_message(agent_status=0, role=role, content=text, files=json.dumps(files), conversation_id=conversation_id) @@ -224,77 +222,84 @@ async def send_messages( # 获取模型配置 chat_model_config = get_config_from_id(config_id=config_id) - model_key = next(iter(chat_model_config["llm_model"])) - chat_model_config["llm_model"][model_key]["temperature"] = temperature - chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens - - is_summery = False - - if agent_id == -1: - cursor.execute('''SELECT agent_id FROM conversations WHERE conversation_id = ?''', (conversation_id,)) - temp_agent_id = cursor.fetchone()[0] - if temp_agent_id != -1: - temp_agent = get_agent_from_id(temp_agent_id) - temp_agent_name = temp_agent["agent_name"] - temp_agent_abstract = temp_agent["agent_abstract"] - temp_agent_info = temp_agent["agent_info"] - agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "." - history.append({"role": "user", "content": agent_prompt}) - is_summery = True - - if len(history) == 0 or is_summery == True: - if len(history) == 0: - summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' - if is_summery == True: - summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + agent_prompt + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' - summery = await chat_online(content=summery_prompt, history=[], chat_model_config=chat_model_config, - agent_id=-1, tool_config=tool_config, conversation_id=conversation_id) - summery_content = summery[0]['choices'][0]['delta']['content'] - try: - summery_content = json.loads(summery_content)["title"] - cursor.execute(''' - UPDATE conversations - SET is_summarized = ?, title = ? - WHERE conversation_id = ? - ''', (True, summery_content, conversation_id)) - conn.commit() - except Exception as e: - print(e) - - # 获取模型输出 - ret = await chat_online(content=text, history=history, chat_model_config=chat_model_config, - tool_config=tool_config, agent_id=agent_id, conversation_id=conversation_id) - - response_messages = [] - for message in ret: - if message['status'] == 7: - message_role = message['choices'][0]['role'] - message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output'] - message_id, timestamp_message = insert_message(agent_status=7, role=message_role, content=message_content, - files=json.dumps({}), conversation_id=conversation_id) - - cursor.execute(''' - UPDATE conversations - SET updated_at = ? - WHERE conversation_id = ? - ''', (timestamp_message, conversation_id)) - conn.commit() + if chat_model_config["platform"] == "LOCAL": + model_key = next(iter(chat_model_config["llm_model"])) + chat_model_config["llm_model"][model_key]["temperature"] = temperature + chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens + + ret = await chat_outline(content=text, history=history, chat_model_config=chat_model_config) + + message_id, timestamp_message = insert_message(agent_status=3, role="assistant", content=ret, + files=json.dumps({}), + conversation_id=conversation_id) + + cursor.execute('''UPDATE conversations + SET updated_at = ? + WHERE conversation_id = ?''', (timestamp_message, conversation_id)) + conn.commit() + + message_dict = { + "message_id": message_id, + "agent_status": 3, + "text": ret, + "files": [], + "timestamp": timestamp_message + } + response_messages = [message_dict] - message_dict = { - "message_id": message_id, - "agent_status": 7, - "text": message_content, - "files": [], - "timestamp": timestamp_message - } - response_messages.append(message_dict) + conn.close() - if message['status'] == 3: - message_role = message['choices'][0]['role'] - message_content = message['choices'][0]['delta']['content'] - message_list = split_message_content(message_content) - for m in message_list: - message_id, timestamp_message = insert_message(agent_status=3, role=message_role, content=m, + return BaseResponse(code=200, msg="success", data=response_messages) + + else: + model_key = next(iter(chat_model_config["llm_model"])) + chat_model_config["llm_model"][model_key]["temperature"] = temperature + chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens + + is_summery = False + + if agent_id == -1: + cursor.execute('''SELECT agent_id FROM conversations WHERE conversation_id = ?''', (conversation_id,)) + temp_agent_id = cursor.fetchone()[0] + if temp_agent_id != -1: + temp_agent = get_agent_from_id(temp_agent_id) + temp_agent_name = temp_agent["agent_name"] + temp_agent_abstract = temp_agent["agent_abstract"] + temp_agent_info = temp_agent["agent_info"] + agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "." + history.append({"role": "user", "content": agent_prompt}) + is_summery = True + + if len(history) == 0 or is_summery == True: + if len(history) == 0: + summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' + if is_summery == True: + summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + agent_prompt + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。' + summery = await chat_online(content=summery_prompt, history=[], chat_model_config=chat_model_config, + agent_id=-1, tool_config=tool_config, conversation_id=conversation_id) + summery_content = summery[0]['choices'][0]['delta']['content'] + try: + summery_content = json.loads(summery_content)["title"] + cursor.execute(''' + UPDATE conversations + SET is_summarized = ?, title = ? + WHERE conversation_id = ? + ''', (True, summery_content, conversation_id)) + conn.commit() + except Exception as e: + print(e) + + # 获取模型输出 + ret = await chat_online(content=text, history=history, chat_model_config=chat_model_config, + tool_config=tool_config, agent_id=agent_id, conversation_id=conversation_id) + + response_messages = [] + for message in ret: + if message['status'] == 7: + message_role = message['choices'][0]['role'] + message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output'] + message_id, timestamp_message = insert_message(agent_status=7, role=message_role, + content=message_content, files=json.dumps({}), conversation_id=conversation_id) cursor.execute(''' @@ -306,19 +311,44 @@ async def send_messages( message_dict = { "message_id": message_id, - "agent_status": 3, - "text": m, + "agent_status": 7, + "text": message_content, "files": [], "timestamp": timestamp_message } - response_messages.append(message_dict) - # TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容 + if message['status'] == 3: + message_role = message['choices'][0]['role'] + message_content = message['choices'][0]['delta']['content'] + message_list = split_message_content(message_content) + for m in message_list: + message_id, timestamp_message = insert_message(agent_status=3, role=message_role, content=m, + files=json.dumps({}), + conversation_id=conversation_id) + + cursor.execute(''' + UPDATE conversations + SET updated_at = ? + WHERE conversation_id = ? + ''', (timestamp_message, conversation_id)) + conn.commit() + + message_dict = { + "message_id": message_id, + "agent_status": 3, + "text": m, + "files": [], + "timestamp": timestamp_message + } + + response_messages.append(message_dict) + + # TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容 - conn.close() + conn.close() - return BaseResponse(code=200, msg="success", data=response_messages) + return BaseResponse(code=200, msg="success", data=response_messages) async def debug_messages( @@ -398,4 +428,4 @@ async def debug_messages( # TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容 - return BaseResponse(code=200, msg="success", data=response_messages) \ No newline at end of file + return BaseResponse(code=200, msg="success", data=response_messages)