Browse Source

feat: 添加本地模型

main
gjl 1 year ago
parent
commit
a8b4d9288b
4 changed files with 135 additions and 80 deletions
  1. +23
    -1
      src/mindpilot/app/chat/chat.py
  2. +1
    -0
      src/mindpilot/app/configs/__init__.py
  3. +2
    -0
      src/mindpilot/app/configs/kb_config.py
  4. +109
    -79
      src/mindpilot/app/conversation/conversation_api.py

+ 23
- 1
src/mindpilot/app/chat/chat.py View File

@@ -16,10 +16,14 @@ from ..callback_handler.agent_callback_handler import (
AgentStatus,
)
from ..chat.utils import History
from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT, PROMPT_TEMPLATES
from ..configs import MODEL_CONFIG, TOOL_CONFIG, OPENAI_PROMPT, PROMPT_TEMPLATES, CACHE_DIR
from ..utils.system_utils import get_ChatOpenAI, get_tool, wrap_done, MsgType, get_mindpilot_db_connection
from ..agent.utils import get_agent_from_id

from mindnlp.transformers import AutoModelForCausalLM, AutoTokenizer
import mindspore
import time


def create_models_from_config(configs, callbacks, stream):
configs = configs
@@ -502,3 +506,21 @@ async def debug_chat_online(
ret.append(data)

return ret


# Process-wide cache of loaded local models, keyed by checkpoint path.
# Loading a multi-GB checkpoint from disk per request is prohibitively slow,
# so the (tokenizer, model) pair is created once and reused.
_LOCAL_MODEL_CACHE = {}


def _get_local_model(path):
    """Return a memoized ``(tokenizer, model)`` pair for the local checkpoint at *path*.

    The first call downloads/loads the weights into ``CACHE_DIR``; subsequent
    calls return the already-initialized objects from ``_LOCAL_MODEL_CACHE``.
    """
    if path not in _LOCAL_MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(path, cache_dir=CACHE_DIR)
        model = AutoModelForCausalLM.from_pretrained(
            path, ms_dtype=mindspore.float16, cache_dir=CACHE_DIR
        )
        _LOCAL_MODEL_CACHE[path] = (tokenizer, model)
    return _LOCAL_MODEL_CACHE[path]


async def chat_outline(
    content: str,
    history: List[History],
    chat_model_config: dict,
):
    """Generate a reply to *content* with a locally hosted MiniCPM model.

    Args:
        content: The user's message text.
        history: Prior conversation turns, forwarded verbatim to
            ``model.chat`` (assumes the model accepts this history format —
            TODO confirm against the mindnlp ``chat`` signature).
        chat_model_config: Config dict whose ``"llm_model"`` maps a model name
            to its sampling settings; only the first entry is used.

    Returns:
        The model's text response (the updated history is discarded).
    """
    model_name = next(iter(chat_model_config["llm_model"]))
    temperature = chat_model_config["llm_model"][model_name]["temperature"]
    # NOTE(review): max_tokens is read from the config but never forwarded to
    # model.chat() — confirm whether the local model should honour it.
    max_tokens = chat_model_config["llm_model"][model_name]["max_tokens"]

    path = 'openbmb/MiniCPM-2B-dpo-bf16'
    tokenizer, model = _get_local_model(path)

    response, history = model.chat(
        tokenizer,
        content,
        history=history,
        temperature=temperature,
        top_p=0.9,
        repetition_penalty=1.02,
    )
    return response

+ 1
- 0
src/mindpilot/app/configs/__init__.py View File

@@ -33,4 +33,5 @@ __all__ = [
"TEXT_SPLITTER_NAME",
"EMBEDDING_KEYWORD_FILE",
"DEFAULT_EMBEDDING_MODEL",
"CACHE_DIR",
]

+ 2
- 0
src/mindpilot/app/configs/kb_config.py View File

@@ -45,6 +45,8 @@ KB_INFO = {
}

CHATCHAT_ROOT = str(Path(__file__).absolute().parent.parent.parent)
CACHE_DIR = str(Path(__file__).absolute().parent.parent.parent.parent.parent)
CACHE_DIR = os.path.join(CACHE_DIR, "cache")

KB_ROOT_PATH = os.path.join(CHATCHAT_ROOT, "knowledge_base")



+ 109
- 79
src/mindpilot/app/conversation/conversation_api.py View File

@@ -12,7 +12,7 @@ from ..utils.system_utils import BaseResponse, ListResponse, get_mindpilot_db_co
from .message import init_messages_table, insert_message, split_message_content
from ..model_configs.utils import get_config_from_id
from ..agent.utils import get_agent_from_id
from ..chat.chat import chat_online, debug_chat_online
from ..chat.chat import chat_online, debug_chat_online, chat_outline


def init_conversations_table():
@@ -209,8 +209,6 @@ async def send_messages(
"content": row['content']
})

# print(history)

# 存放用户输入
_, timestamp_user = insert_message(agent_status=0, role=role, content=text, files=json.dumps(files),
conversation_id=conversation_id)
@@ -224,77 +222,84 @@ async def send_messages(

# 获取模型配置
chat_model_config = get_config_from_id(config_id=config_id)
model_key = next(iter(chat_model_config["llm_model"]))
chat_model_config["llm_model"][model_key]["temperature"] = temperature
chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens

is_summery = False

if agent_id == -1:
cursor.execute('''SELECT agent_id FROM conversations WHERE conversation_id = ?''', (conversation_id,))
temp_agent_id = cursor.fetchone()[0]
if temp_agent_id != -1:
temp_agent = get_agent_from_id(temp_agent_id)
temp_agent_name = temp_agent["agent_name"]
temp_agent_abstract = temp_agent["agent_abstract"]
temp_agent_info = temp_agent["agent_info"]
agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "."
history.append({"role": "user", "content": agent_prompt})
is_summery = True

if len(history) == 0 or is_summery == True:
if len(history) == 0:
summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。'
if is_summery == True:
summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + agent_prompt + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。'
summery = await chat_online(content=summery_prompt, history=[], chat_model_config=chat_model_config,
agent_id=-1, tool_config=tool_config, conversation_id=conversation_id)
summery_content = summery[0]['choices'][0]['delta']['content']
try:
summery_content = json.loads(summery_content)["title"]
cursor.execute('''
UPDATE conversations
SET is_summarized = ?, title = ?
WHERE conversation_id = ?
''', (True, summery_content, conversation_id))
conn.commit()
except Exception as e:
print(e)

# 获取模型输出
ret = await chat_online(content=text, history=history, chat_model_config=chat_model_config,
tool_config=tool_config, agent_id=agent_id, conversation_id=conversation_id)

response_messages = []
for message in ret:
if message['status'] == 7:
message_role = message['choices'][0]['role']
message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output']
message_id, timestamp_message = insert_message(agent_status=7, role=message_role, content=message_content,
files=json.dumps({}), conversation_id=conversation_id)

cursor.execute('''
UPDATE conversations
SET updated_at = ?
WHERE conversation_id = ?
''', (timestamp_message, conversation_id))
conn.commit()
if chat_model_config["platform"] == "LOCAL":
model_key = next(iter(chat_model_config["llm_model"]))
chat_model_config["llm_model"][model_key]["temperature"] = temperature
chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens

ret = await chat_outline(content=text, history=history, chat_model_config=chat_model_config)

message_id, timestamp_message = insert_message(agent_status=3, role="assistant", content=ret,
files=json.dumps({}),
conversation_id=conversation_id)

cursor.execute('''UPDATE conversations
SET updated_at = ?
WHERE conversation_id = ?''', (timestamp_message, conversation_id))
conn.commit()

message_dict = {
"message_id": message_id,
"agent_status": 3,
"text": ret,
"files": [],
"timestamp": timestamp_message
}
response_messages = [message_dict]

message_dict = {
"message_id": message_id,
"agent_status": 7,
"text": message_content,
"files": [],
"timestamp": timestamp_message
}
response_messages.append(message_dict)
conn.close()

if message['status'] == 3:
message_role = message['choices'][0]['role']
message_content = message['choices'][0]['delta']['content']
message_list = split_message_content(message_content)
for m in message_list:
message_id, timestamp_message = insert_message(agent_status=3, role=message_role, content=m,
return BaseResponse(code=200, msg="success", data=response_messages)

else:
model_key = next(iter(chat_model_config["llm_model"]))
chat_model_config["llm_model"][model_key]["temperature"] = temperature
chat_model_config["llm_model"][model_key]["max_tokens"] = max_tokens

is_summery = False

if agent_id == -1:
cursor.execute('''SELECT agent_id FROM conversations WHERE conversation_id = ?''', (conversation_id,))
temp_agent_id = cursor.fetchone()[0]
if temp_agent_id != -1:
temp_agent = get_agent_from_id(temp_agent_id)
temp_agent_name = temp_agent["agent_name"]
temp_agent_abstract = temp_agent["agent_abstract"]
temp_agent_info = temp_agent["agent_info"]
agent_prompt = "Your name is " + temp_agent_name + "." + temp_agent_abstract + ". Below is your detailed information:" + temp_agent_info + "."
history.append({"role": "user", "content": agent_prompt})
is_summery = True

if len(history) == 0 or is_summery == True:
if len(history) == 0:
summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。'
if is_summery == True:
summery_prompt = "下面是用户的问题,请总结为不超过八个字的标题。\n" + "用户:" + agent_prompt + text + '输出格式为:{"title":"总结的标题"},除了这个json,不允许输出其他内容。'
summery = await chat_online(content=summery_prompt, history=[], chat_model_config=chat_model_config,
agent_id=-1, tool_config=tool_config, conversation_id=conversation_id)
summery_content = summery[0]['choices'][0]['delta']['content']
try:
summery_content = json.loads(summery_content)["title"]
cursor.execute('''
UPDATE conversations
SET is_summarized = ?, title = ?
WHERE conversation_id = ?
''', (True, summery_content, conversation_id))
conn.commit()
except Exception as e:
print(e)

# 获取模型输出
ret = await chat_online(content=text, history=history, chat_model_config=chat_model_config,
tool_config=tool_config, agent_id=agent_id, conversation_id=conversation_id)

response_messages = []
for message in ret:
if message['status'] == 7:
message_role = message['choices'][0]['role']
message_content = "Observation:\n" + message['choices'][0]['delta']['tool_calls'][0]['tool_output']
message_id, timestamp_message = insert_message(agent_status=7, role=message_role,
content=message_content,
files=json.dumps({}), conversation_id=conversation_id)

cursor.execute('''
@@ -306,19 +311,44 @@ async def send_messages(

message_dict = {
"message_id": message_id,
"agent_status": 3,
"text": m,
"agent_status": 7,
"text": message_content,
"files": [],
"timestamp": timestamp_message
}

response_messages.append(message_dict)

# TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容
if message['status'] == 3:
message_role = message['choices'][0]['role']
message_content = message['choices'][0]['delta']['content']
message_list = split_message_content(message_content)
for m in message_list:
message_id, timestamp_message = insert_message(agent_status=3, role=message_role, content=m,
files=json.dumps({}),
conversation_id=conversation_id)

cursor.execute('''
UPDATE conversations
SET updated_at = ?
WHERE conversation_id = ?
''', (timestamp_message, conversation_id))
conn.commit()

message_dict = {
"message_id": message_id,
"agent_status": 3,
"text": m,
"files": [],
"timestamp": timestamp_message
}

response_messages.append(message_dict)

# TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容

conn.close()
conn.close()

return BaseResponse(code=200, msg="success", data=response_messages)
return BaseResponse(code=200, msg="success", data=response_messages)


async def debug_messages(
@@ -398,4 +428,4 @@ async def debug_messages(

# TODO 这里考虑处理一下message['status']是4但之前一个message['status']不是3的,即agent无法解析的内容

return BaseResponse(code=200, msg="success", data=response_messages)
return BaseResponse(code=200, msg="success", data=response_messages)

Loading…
Cancel
Save