You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

llm_op.py 10 kB

10 months ago
10 months ago
10 months ago
1 year ago
1 year ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
"""Dora operator that drives a local code LLM.

Loads a GPTQ-quantized DeepSeek-Coder model at import time and dispatches
three kinds of dora inputs: code modification requests, json command
generation, and plain assistant chat replies.
"""
import json
import os
import re
import time

import pyarrow as pa
import pylcs
from dora import DoraStatus
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face hub id of the quantized instruct model loaded below.
MODEL_NAME_OR_PATH = "TheBloke/deepseek-coder-6.7B-instruct-GPTQ"
# MODEL_NAME_OR_PATH = "hanspeterlyngsoeraaschoujensen/deepseek-math-7b-instruct-GPTQ"

# Prompt asking the model to rewrite {code} per {user_message};
# the reply is expected to be a fenced ```python block (see
# extract_python_code_blocks / replace_code_in_source).
CODE_MODIFIER_TEMPLATE = """
### Instruction
Respond with the small modified code only. No explanation.
```python
{code}
```
{user_message}
### Response:
"""

# Prompt asking the model for a ```json command. The doubled braces
# ({{ }}) escape literal braces for str.format; only the "line" topic
# is handled downstream in Operator.on_event.
MESSAGE_SENDER_TEMPLATE = """
### Instruction
You're a json expert. Format your response as a json with a topic and a data field in a ```json block. No explanation needed. No code needed.
The schema for those json are:
- line: Int[4]
The response should look like this:
```json
{{ "topic": "line", "data": [10, 10, 90, 10] }}
```
{user_message}
### Response:
"""
  33. ASSISTANT_TEMPLATE = """
  34. ### Instruction
  35. You're a helpuf assistant named dora.
  36. Reply with a short message. No code needed.
  37. User {user_message}
  38. ### Response:
  39. """
# Load the quantized model at import time; device_map="auto" lets
# accelerate place/shard the weights across available devices.
# NOTE(review): GPTQ checkpoints need the auto-gptq/optimum stack
# installed — TODO confirm the runtime environment provides it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    device_map="auto",
    trust_remote_code=True,
    revision="main",
)
# Fast (Rust-backed) tokenizer matching the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
  47. def extract_python_code_blocks(text):
  48. """Extract Python code blocks from the given text that are enclosed in triple backticks with a python language identifier.
  49. Parameters
  50. ----------
  51. text : str
  52. A string that may contain one or more Python code blocks.
  53. Returns
  54. -------
  55. - A list of strings, where each string is a block of Python code extracted from the text.
  56. """
  57. pattern = r"```python\n(.*?)\n```"
  58. matches = re.findall(pattern, text, re.DOTALL)
  59. if len(matches) == 0:
  60. pattern = r"```python\n(.*?)(?:\n```|$)"
  61. matches = re.findall(pattern, text, re.DOTALL)
  62. if len(matches) == 0:
  63. return [text]
  64. matches = [remove_last_line(matches[0])]
  65. return matches
  66. def extract_json_code_blocks(text):
  67. """Extract json code blocks from the given text that are enclosed in triple backticks with a json language identifier.
  68. Parameters
  69. ----------
  70. text : str
  71. A string that may contain one or more json code blocks.
  72. Returns
  73. -------
  74. list of str
  75. A list of strings, where each string is a block of json code extracted from the text.
  76. """
  77. pattern = r"```json\n(.*?)\n```"
  78. matches = re.findall(pattern, text, re.DOTALL)
  79. if len(matches) == 0:
  80. pattern = r"```json\n(.*?)(?:\n```|$)"
  81. matches = re.findall(pattern, text, re.DOTALL)
  82. if len(matches) == 0:
  83. return [text]
  84. return matches
  85. def remove_last_line(python_code):
  86. """Remove the last line from a given string of Python code.
  87. Parameters
  88. ----------
  89. python_code : str
  90. A string representing Python source code.
  91. Returns
  92. -------
  93. str
  94. A string with the last line removed.
  95. """
  96. lines = python_code.split("\n") # Split the string into lines
  97. if lines: # Check if there are any lines to remove
  98. lines.pop() # Remove the last line
  99. return "\n".join(lines) # Join the remaining lines back into a string
  100. def calculate_similarity(source, target):
  101. """Calculate a similarity score between the source and target strings.
  102. This uses the edit distance relative to the length of the strings.
  103. """
  104. edit_distance = pylcs.edit_distance(source, target)
  105. max_length = max(len(source), len(target))
  106. # Normalize the score by the maximum possible edit distance (the length of the longer string)
  107. return 1 - (edit_distance / max_length)
def find_best_match_location(source_code, target_block):
    """Locate the span of *source_code* most similar to *target_block*.

    Scans every contiguous window of source lines whose length is at
    least ``len(target_block.split("\\n"))`` lines, scores each window
    with ``calculate_similarity``, and returns the best window as
    CHARACTER offsets into *source_code* (suitable for slicing).

    Parameters
    ----------
    source_code : str
        Full source text to search within.
    target_block : str
        Replacement block to match against.

    Returns
    -------
    tuple of (int, int)
        ``(char_start, char_end)`` of the best-scoring window.
        NOTE(review): when the source has fewer lines than the target,
        the loops never run and ``best_end_index`` stays -1, producing
        an offset ending at the last line boundary — confirm callers
        tolerate that degenerate case.
    """
    source_lines = source_code.split("\n")
    target_lines = target_block.split("\n")
    best_similarity = 0
    best_start_index = 0
    best_end_index = -1
    # Exhaustive O(n^2)-windows scan: every [start, end) slice at least
    # as long as the target block is a candidate match.
    for start_index in range(len(source_lines) - len(target_lines) + 1):
        for end_index in range(start_index + len(target_lines), len(source_lines) + 1):
            current_window = "\n".join(source_lines[start_index:end_index])
            current_similarity = calculate_similarity(current_window, target_block)
            if current_similarity > best_similarity:
                best_similarity = current_similarity
                best_start_index = start_index
                best_end_index = end_index
    # Convert line indices back to character offsets; the "+1" skips the
    # newline preceding the first matched line (absent when matching at
    # the very start of the text).
    char_start_index = len("\n".join(source_lines[:best_start_index])) + (
        1 if best_start_index > 0 else 0
    )
    char_end_index = len("\n".join(source_lines[:best_end_index]))
    return char_start_index, char_end_index
  130. def replace_code_in_source(source_code, replacement_block: str):
  131. """Replace the best matching block in the source_code with the replacement_block, considering variable block lengths."""
  132. replacement_block = extract_python_code_blocks(replacement_block)[0]
  133. start_index, end_index = find_best_match_location(source_code, replacement_block)
  134. if start_index != -1 and end_index != -1:
  135. # Replace the best matching part with the replacement block
  136. return (
  137. source_code[:start_index] + replacement_block + source_code[end_index:]
  138. )
  139. return source_code
  140. class Operator:
  141. """TODO: Add docstring."""
  142. def on_event(
  143. self,
  144. dora_event,
  145. send_output,
  146. ) -> DoraStatus:
  147. """TODO: Add docstring."""
  148. if dora_event["type"] == "INPUT" and dora_event["id"] == "code_modifier":
  149. input = dora_event["value"][0].as_py()
  150. with open(input["path"], encoding="utf8") as f:
  151. code = f.read()
  152. user_message = input["user_message"]
  153. start_llm = time.time()
  154. output = self.ask_llm(
  155. CODE_MODIFIER_TEMPLATE.format(code=code, user_message=user_message),
  156. )
  157. source_code = replace_code_in_source(code, output)
  158. print("response time:", time.time() - start_llm, flush=True)
  159. send_output(
  160. "modified_file",
  161. pa.array(
  162. [
  163. {
  164. "raw": source_code,
  165. "path": input["path"],
  166. "response": output,
  167. "prompt": input["user_message"],
  168. },
  169. ],
  170. ),
  171. dora_event["metadata"],
  172. )
  173. print("response: ", output, flush=True)
  174. send_output(
  175. "assistant_message",
  176. pa.array([output]),
  177. dora_event["metadata"],
  178. )
  179. elif dora_event["type"] == "INPUT" and dora_event["id"] == "message_sender":
  180. user_message = dora_event["value"][0].as_py()
  181. output = self.ask_llm(
  182. MESSAGE_SENDER_TEMPLATE.format(user_message=user_message),
  183. )
  184. outputs = extract_json_code_blocks(output)[0]
  185. try:
  186. output = json.loads(outputs)
  187. if not isinstance(output["data"], list):
  188. output["data"] = [output["data"]]
  189. if output["topic"] in [
  190. "line",
  191. ]:
  192. send_output(
  193. output["topic"],
  194. pa.array(output["data"]),
  195. dora_event["metadata"],
  196. )
  197. else:
  198. print("Could not find the topic: {}".format(output["topic"]))
  199. except:
  200. print("Could not parse json")
  201. # if data is not iterable, put data in a list
  202. elif dora_event["type"] == "INPUT" and dora_event["id"] == "assistant":
  203. user_message = dora_event["value"][0].as_py()
  204. output = self.ask_llm(ASSISTANT_TEMPLATE.format(user_message=user_message))
  205. send_output(
  206. "assistant_message",
  207. pa.array([output]),
  208. dora_event["metadata"],
  209. )
  210. return DoraStatus.CONTINUE
  211. def ask_llm(self, prompt):
  212. # Generate output
  213. # prompt = PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt))
  214. """TODO: Add docstring."""
  215. input = tokenizer(prompt, return_tensors="pt")
  216. input_ids = input.input_ids.cuda()
  217. # add attention mask here
  218. attention_mask = input["attention_mask"].cuda()
  219. output = model.generate(
  220. inputs=input_ids,
  221. temperature=0.7,
  222. do_sample=True,
  223. top_p=0.95,
  224. top_k=40,
  225. max_new_tokens=512,
  226. attention_mask=attention_mask,
  227. eos_token_id=tokenizer.eos_token_id,
  228. )
  229. # Get the tokens from the output, decode them, print them
  230. # Get text between im_start and im_end
  231. return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
  232. if __name__ == "__main__":
  233. op = Operator()
  234. # Path to the current file
  235. current_file_path = __file__
  236. # Directory of the current file
  237. current_directory = os.path.dirname(current_file_path)
  238. path = current_directory + "object_detection.py"
  239. with open(path, encoding="utf8") as f:
  240. raw = f.read()
  241. op.on_event(
  242. {
  243. "type": "INPUT",
  244. "id": "message_sender",
  245. "value": pa.array(
  246. [
  247. {
  248. "path": path,
  249. "user_message": "send a star ",
  250. },
  251. ],
  252. ),
  253. "metadata": [],
  254. },
  255. print,
  256. )