You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

agent_callback_handler.py 6.1 kB

1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from __future__ import annotations
  2. import asyncio
  3. import json
  4. from typing import Any, Dict, List, Optional
  5. from uuid import UUID
  6. from langchain.callbacks import AsyncIteratorCallbackHandler
  7. from langchain.schema import AgentAction, AgentFinish
  8. from langchain_core.outputs import LLMResult
  9. def dumps(obj: Dict) -> str:
  10. return json.dumps(obj, ensure_ascii=False)
  11. class AgentStatus:
  12. llm_start: int = 1
  13. llm_new_token: int = 2
  14. llm_end: int = 3
  15. agent_action: int = 4
  16. agent_finish: int = 5
  17. tool_start: int = 6
  18. tool_end: int = 7
  19. error: int = 8
  20. class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
  21. def __init__(self):
  22. super().__init__()
  23. self.queue = asyncio.Queue()
  24. self.done = asyncio.Event()
  25. self.out = True
  26. async def on_llm_start(
  27. self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
  28. ) -> None:
  29. data = {
  30. "status": AgentStatus.llm_start,
  31. "text": "",
  32. }
  33. self.done.clear()
  34. self.queue.put_nowait(dumps(data))
  35. async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
  36. special_tokens = ["\n\nAction:", "\n\nObservation:", "<|observation|>", "\n\nThought:"]
  37. for stoken in special_tokens:
  38. if stoken in token:
  39. before_action = token.split(stoken)[0]
  40. data = {
  41. "status": AgentStatus.llm_new_token,
  42. "text": before_action + "\n",
  43. }
  44. self.queue.put_nowait(dumps(data))
  45. self.out = False
  46. break
  47. if token is not None and token != "" and self.out:
  48. data = {
  49. "status": AgentStatus.llm_new_token,
  50. "text": token,
  51. }
  52. self.queue.put_nowait(dumps(data))
  53. async def on_chat_model_start(
  54. self,
  55. serialized: Dict[str, Any],
  56. messages: List[List],
  57. *,
  58. run_id: UUID,
  59. parent_run_id: Optional[UUID] = None,
  60. tags: Optional[List[str]] = None,
  61. metadata: Optional[Dict[str, Any]] = None,
  62. **kwargs: Any,
  63. ) -> None:
  64. data = {
  65. "status": AgentStatus.llm_start,
  66. "text": "",
  67. }
  68. self.done.clear()
  69. self.queue.put_nowait(dumps(data))
  70. async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
  71. pass
  72. # data = {
  73. # "status": AgentStatus.llm_end,
  74. # "text": response.generations[0][0].message.content,
  75. # }
  76. # self.queue.put_nowait(dumps(data))
  77. async def on_llm_error(
  78. self, error: Exception | KeyboardInterrupt, **kwargs: Any
  79. ) -> None:
  80. data = {
  81. "status": AgentStatus.error,
  82. "text": str(error),
  83. }
  84. self.queue.put_nowait(dumps(data))
  85. async def on_tool_start(
  86. self,
  87. serialized: Dict[str, Any],
  88. input_str: str,
  89. *,
  90. run_id: UUID,
  91. parent_run_id: Optional[UUID] = None,
  92. tags: Optional[List[str]] = None,
  93. metadata: Optional[Dict[str, Any]] = None,
  94. **kwargs: Any,
  95. ) -> None:
  96. data = {
  97. "run_id": str(run_id),
  98. "status": AgentStatus.tool_start,
  99. "tool": serialized["name"],
  100. "tool_input": input_str,
  101. }
  102. self.queue.put_nowait(dumps(data))
  103. async def on_tool_end(
  104. self,
  105. output: str,
  106. *,
  107. run_id: UUID,
  108. parent_run_id: Optional[UUID] = None,
  109. tags: Optional[List[str]] = None,
  110. **kwargs: Any,
  111. ) -> None:
  112. """Run when tool ends running."""
  113. data = {
  114. "run_id": str(run_id),
  115. "status": AgentStatus.tool_end,
  116. "tool_output": output.to_serializable_data(),
  117. }
  118. # self.done.clear()
  119. self.queue.put_nowait(dumps(data))
  120. async def on_tool_error(
  121. self,
  122. error: BaseException,
  123. *,
  124. run_id: UUID,
  125. parent_run_id: Optional[UUID] = None,
  126. tags: Optional[List[str]] = None,
  127. **kwargs: Any,
  128. ) -> None:
  129. """Run when tool errors."""
  130. data = {
  131. "run_id": str(run_id),
  132. "status": AgentStatus.tool_end,
  133. "tool_output": str(error),
  134. "is_error": True,
  135. }
  136. # self.done.clear()
  137. self.queue.put_nowait(dumps(data))
  138. # async def on_agent_action(
  139. # self,
  140. # action: AgentAction,
  141. # *,
  142. # run_id: UUID,
  143. # parent_run_id: Optional[UUID] = None,
  144. # tags: Optional[List[str]] = None,
  145. # **kwargs: Any,
  146. # ) -> None:
  147. # data = {
  148. # "status": AgentStatus.agent_action,
  149. # "tool_name": action.tool,
  150. # "tool_input": action.tool_input,
  151. # "text": action.log,
  152. # }
  153. # self.queue.put_nowait(dumps(data))
  154. # async def on_agent_finish(
  155. # self,
  156. # finish: AgentFinish,
  157. # *,
  158. # run_id: UUID,
  159. # parent_run_id: Optional[UUID] = None,
  160. # tags: Optional[List[str]] = None,
  161. # **kwargs: Any,
  162. # ) -> None:
  163. # if "Thought:" in finish.return_values["output"]:
  164. # finish.return_values["output"] = finish.return_values["output"].replace(
  165. # "Thought:", ""
  166. # )
  167. #
  168. # data = {
  169. # "status": AgentStatus.agent_finish,
  170. # "text": finish.return_values["output"],
  171. # }
  172. # self.queue.put_nowait(dumps(data))
  173. async def on_chain_end(
  174. self,
  175. outputs: Dict[str, Any],
  176. *,
  177. run_id: UUID,
  178. parent_run_id: UUID | None = None,
  179. tags: List[str] | None = None,
  180. **kwargs: Any,
  181. ) -> None:
  182. self.done.set()
  183. self.out = True

MindPilot是一个跨平台的多功能智能Agent桌面助手,旨在为用户提供便捷、高效的智能解决方案。通过集成先进的大语言模型作为核心决策引擎,MindPilot能够对用户的任务进行精准分解、规划、执行、反思和总结,确保任务的高效完成。同时提供了高度自定义化的Agent,用户可以根据需求自定义不同身份的Agent,以应对多样化的任务场景,实现个性化的智能服务。MindPilot基于MindSpore和MindNLP构建。