You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

llm_op.py 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. from dora import DoraStatus
  2. import pylcs
  3. import os
  4. import pyarrow as pa
  5. from transformers import AutoModelForCausalLM, AutoTokenizer
  6. import json
  7. import re
  8. import time
  9. MODEL_NAME_OR_PATH = "TheBloke/deepseek-coder-6.7B-instruct-GPTQ"
  10. # MODEL_NAME_OR_PATH = "hanspeterlyngsoeraaschoujensen/deepseek-math-7b-instruct-GPTQ"
  11. CODE_MODIFIER_TEMPLATE = """
  12. ### Instruction
  13. Respond with the small modified code only. No explaination.
  14. ```python
  15. {code}
  16. ```
  17. {user_message}
  18. ### Response:
  19. """
  20. MESSAGE_SENDER_TEMPLATE = """
  21. ### Instruction
  22. You're a json expert. Format your response as a json with a topic and a data field in a ```json block. No explaination needed. No code needed.
  23. The schema for those json are:
  24. - line: Int[4]
  25. The response should look like this:
  26. ```json
  27. {{ "topic": "line", "data": [10, 10, 90, 10] }}
  28. ```
  29. {user_message}
  30. ### Response:
  31. """
  32. ASSISTANT_TEMPLATE = """
  33. ### Instruction
  34. You're a helpuf assistant named dora.
  35. Reply with a short message. No code needed.
  36. User {user_message}
  37. ### Response:
  38. """
# Load the model once at import time; `device_map="auto"` lets transformers
# place the weights on whatever accelerator is available.
# NOTE(review): `trust_remote_code=True` executes code shipped with the
# checkpoint — acceptable only because the repo is explicitly pinned above.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    device_map="auto",
    trust_remote_code=True,
    revision="main",
)
# Fast (Rust-backed) tokenizer matching the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
  46. def extract_python_code_blocks(text):
  47. """
  48. Extracts Python code blocks from the given text that are enclosed in triple backticks with a python language identifier.
  49. Parameters:
  50. - text: A string that may contain one or more Python code blocks.
  51. Returns:
  52. - A list of strings, where each string is a block of Python code extracted from the text.
  53. """
  54. pattern = r"```python\n(.*?)\n```"
  55. matches = re.findall(pattern, text, re.DOTALL)
  56. if len(matches) == 0:
  57. pattern = r"```python\n(.*?)(?:\n```|$)"
  58. matches = re.findall(pattern, text, re.DOTALL)
  59. if len(matches) == 0:
  60. return [text]
  61. else:
  62. matches = [remove_last_line(matches[0])]
  63. return matches
  64. def extract_json_code_blocks(text):
  65. """
  66. Extracts json code blocks from the given text that are enclosed in triple backticks with a json language identifier.
  67. Parameters:
  68. - text: A string that may contain one or more json code blocks.
  69. Returns:
  70. - A list of strings, where each string is a block of json code extracted from the text.
  71. """
  72. pattern = r"```json\n(.*?)\n```"
  73. matches = re.findall(pattern, text, re.DOTALL)
  74. if len(matches) == 0:
  75. pattern = r"```json\n(.*?)(?:\n```|$)"
  76. matches = re.findall(pattern, text, re.DOTALL)
  77. if len(matches) == 0:
  78. return [text]
  79. return matches
  80. def remove_last_line(python_code):
  81. """
  82. Removes the last line from a given string of Python code.
  83. Parameters:
  84. - python_code: A string representing Python source code.
  85. Returns:
  86. - A string with the last line removed.
  87. """
  88. lines = python_code.split("\n") # Split the string into lines
  89. if lines: # Check if there are any lines to remove
  90. lines.pop() # Remove the last line
  91. return "\n".join(lines) # Join the remaining lines back into a string
  92. def calculate_similarity(source, target):
  93. """
  94. Calculate a similarity score between the source and target strings.
  95. This uses the edit distance relative to the length of the strings.
  96. """
  97. edit_distance = pylcs.edit_distance(source, target)
  98. max_length = max(len(source), len(target))
  99. # Normalize the score by the maximum possible edit distance (the length of the longer string)
  100. similarity = 1 - (edit_distance / max_length)
  101. return similarity
def find_best_match_location(source_code, target_block):
    """
    Find the best match for the target_block within the source_code by searching line by line,
    considering blocks of varying lengths.

    Returns a (char_start_index, char_end_index) pair of character offsets
    delimiting the best-matching span, suitable for slicing source_code.
    """
    source_lines = source_code.split("\n")
    target_lines = target_block.split("\n")
    best_similarity = 0
    best_start_index = 0
    best_end_index = -1
    # Iterate over the source lines to find the best matching range for all lines in target_block
    # (every window at least as long as the target is scored — O(n^2) windows,
    # each scored by edit distance, so this is expensive on large files).
    for start_index in range(len(source_lines) - len(target_lines) + 1):
        for end_index in range(start_index + len(target_lines), len(source_lines) + 1):
            current_window = "\n".join(source_lines[start_index:end_index])
            current_similarity = calculate_similarity(current_window, target_block)
            if current_similarity > best_similarity:
                best_similarity = current_similarity
                best_start_index = start_index
                best_end_index = end_index
    # Convert line indices back to character indices for replacement
    # (the "+1" skips the newline that precedes the first matched line).
    char_start_index = len("\n".join(source_lines[:best_start_index])) + (
        1 if best_start_index > 0 else 0
    )
    char_end_index = len("\n".join(source_lines[:best_end_index]))
    # NOTE(review): if no window scores above 0, best_end_index stays -1 and
    # source_lines[:-1] silently drops the last source line here — the caller's
    # `!= -1` guard never actually sees -1. Confirm this edge case is intended.
    return char_start_index, char_end_index
  127. def replace_code_in_source(source_code, replacement_block: str):
  128. """
  129. Replace the best matching block in the source_code with the replacement_block, considering variable block lengths.
  130. """
  131. replacement_block = extract_python_code_blocks(replacement_block)[0]
  132. start_index, end_index = find_best_match_location(source_code, replacement_block)
  133. if start_index != -1 and end_index != -1:
  134. # Replace the best matching part with the replacement block
  135. new_source = (
  136. source_code[:start_index] + replacement_block + source_code[end_index:]
  137. )
  138. return new_source
  139. else:
  140. return source_code
  141. class Operator:
  142. def on_event(
  143. self,
  144. dora_event,
  145. send_output,
  146. ) -> DoraStatus:
  147. if dora_event["type"] == "INPUT" and dora_event["id"] == "code_modifier":
  148. input = dora_event["value"][0].as_py()
  149. with open(input["path"], "r", encoding="utf8") as f:
  150. code = f.read()
  151. user_message = input["user_message"]
  152. start_llm = time.time()
  153. output = self.ask_llm(
  154. CODE_MODIFIER_TEMPLATE.format(code=code, user_message=user_message)
  155. )
  156. source_code = replace_code_in_source(code, output)
  157. print("response time:", time.time() - start_llm, flush=True)
  158. send_output(
  159. "modified_file",
  160. pa.array(
  161. [
  162. {
  163. "raw": source_code,
  164. "path": input["path"],
  165. "response": output,
  166. "prompt": input["user_message"],
  167. }
  168. ]
  169. ),
  170. dora_event["metadata"],
  171. )
  172. print("response: ", output, flush=True)
  173. send_output(
  174. "assistant_message",
  175. pa.array([output]),
  176. dora_event["metadata"],
  177. )
  178. elif dora_event["type"] == "INPUT" and dora_event["id"] == "message_sender":
  179. user_message = dora_event["value"][0].as_py()
  180. output = self.ask_llm(
  181. MESSAGE_SENDER_TEMPLATE.format(user_message=user_message)
  182. )
  183. outputs = extract_json_code_blocks(output)[0]
  184. try:
  185. output = json.loads(outputs)
  186. if not isinstance(output["data"], list):
  187. output["data"] = [output["data"]]
  188. if output["topic"] in [
  189. "line",
  190. ]:
  191. send_output(
  192. output["topic"],
  193. pa.array(output["data"]),
  194. dora_event["metadata"],
  195. )
  196. else:
  197. print("Could not find the topic: {}".format(output["topic"]))
  198. except:
  199. print("Could not parse json")
  200. # if data is not iterable, put data in a list
  201. elif dora_event["type"] == "INPUT" and dora_event["id"] == "assistant":
  202. user_message = dora_event["value"][0].as_py()
  203. output = self.ask_llm(ASSISTANT_TEMPLATE.format(user_message=user_message))
  204. send_output(
  205. "assistant_message",
  206. pa.array([output]),
  207. dora_event["metadata"],
  208. )
  209. return DoraStatus.CONTINUE
  210. def ask_llm(self, prompt):
  211. # Generate output
  212. # prompt = PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt))
  213. input = tokenizer(prompt, return_tensors="pt")
  214. input_ids = input.input_ids.cuda()
  215. # add attention mask here
  216. attention_mask = input["attention_mask"].cuda()
  217. output = model.generate(
  218. inputs=input_ids,
  219. temperature=0.7,
  220. do_sample=True,
  221. top_p=0.95,
  222. top_k=40,
  223. max_new_tokens=512,
  224. attention_mask=attention_mask,
  225. eos_token_id=tokenizer.eos_token_id,
  226. )
  227. # Get the tokens from the output, decode them, print them
  228. # Get text between im_start and im_end
  229. return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
  230. if __name__ == "__main__":
  231. op = Operator()
  232. # Path to the current file
  233. current_file_path = __file__
  234. # Directory of the current file
  235. current_directory = os.path.dirname(current_file_path)
  236. path = current_directory + "object_detection.py"
  237. with open(path, "r", encoding="utf8") as f:
  238. raw = f.read()
  239. op.on_event(
  240. {
  241. "type": "INPUT",
  242. "id": "message_sender",
  243. "value": pa.array(
  244. [
  245. {
  246. "path": path,
  247. "user_message": "send a star ",
  248. },
  249. ]
  250. ),
  251. "metadata": [],
  252. },
  253. print,
  254. )