You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

llm_op.py 9.9 kB

10 months ago
10 months ago
1 year ago
1 year ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. import json
  2. import os
  3. import re
  4. import time
  5. import pyarrow as pa
  6. import pylcs
  7. from dora import DoraStatus
  8. from transformers import AutoModelForCausalLM, AutoTokenizer
  9. MODEL_NAME_OR_PATH = "TheBloke/deepseek-coder-6.7B-instruct-GPTQ"
  10. # MODEL_NAME_OR_PATH = "hanspeterlyngsoeraaschoujensen/deepseek-math-7b-instruct-GPTQ"
  11. CODE_MODIFIER_TEMPLATE = """
  12. ### Instruction
  13. Respond with the small modified code only. No explanation.
  14. ```python
  15. {code}
  16. ```
  17. {user_message}
  18. ### Response:
  19. """
  20. MESSAGE_SENDER_TEMPLATE = """
  21. ### Instruction
  22. You're a json expert. Format your response as a json with a topic and a data field in a ```json block. No explanation needed. No code needed.
  23. The schema for those json are:
  24. - line: Int[4]
  25. The response should look like this:
  26. ```json
  27. {{ "topic": "line", "data": [10, 10, 90, 10] }}
  28. ```
  29. {user_message}
  30. ### Response:
  31. """
  32. ASSISTANT_TEMPLATE = """
  33. ### Instruction
  34. You're a helpuf assistant named dora.
  35. Reply with a short message. No code needed.
  36. User {user_message}
  37. ### Response:
  38. """
# Load the quantised causal-LM once at import time.
# NOTE(review): this downloads/loads a multi-GB checkpoint and requires a GPU
# (ask_llm calls .cuda()) — confirm this module is only imported on GPU nodes.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    device_map="auto",  # let accelerate decide device placement
    trust_remote_code=True,  # NOTE(review): runs repo-provided code from the hub
    revision="main",
)
# Fast (Rust-backed) tokenizer matching the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
  46. def extract_python_code_blocks(text):
  47. """Extracts Python code blocks from the given text that are enclosed in triple backticks with a python language identifier.
  48. Parameters
  49. ----------
  50. - text: A string that may contain one or more Python code blocks.
  51. Returns
  52. -------
  53. - A list of strings, where each string is a block of Python code extracted from the text.
  54. """
  55. pattern = r"```python\n(.*?)\n```"
  56. matches = re.findall(pattern, text, re.DOTALL)
  57. if len(matches) == 0:
  58. pattern = r"```python\n(.*?)(?:\n```|$)"
  59. matches = re.findall(pattern, text, re.DOTALL)
  60. if len(matches) == 0:
  61. return [text]
  62. matches = [remove_last_line(matches[0])]
  63. return matches
  64. def extract_json_code_blocks(text):
  65. """Extracts json code blocks from the given text that are enclosed in triple backticks with a json language identifier.
  66. Parameters
  67. ----------
  68. - text: A string that may contain one or more json code blocks.
  69. Returns
  70. -------
  71. - A list of strings, where each string is a block of json code extracted from the text.
  72. """
  73. pattern = r"```json\n(.*?)\n```"
  74. matches = re.findall(pattern, text, re.DOTALL)
  75. if len(matches) == 0:
  76. pattern = r"```json\n(.*?)(?:\n```|$)"
  77. matches = re.findall(pattern, text, re.DOTALL)
  78. if len(matches) == 0:
  79. return [text]
  80. return matches
  81. def remove_last_line(python_code):
  82. """Removes the last line from a given string of Python code.
  83. Parameters
  84. ----------
  85. - python_code: A string representing Python source code.
  86. Returns
  87. -------
  88. - A string with the last line removed.
  89. """
  90. lines = python_code.split("\n") # Split the string into lines
  91. if lines: # Check if there are any lines to remove
  92. lines.pop() # Remove the last line
  93. return "\n".join(lines) # Join the remaining lines back into a string
  94. def calculate_similarity(source, target):
  95. """Calculate a similarity score between the source and target strings.
  96. This uses the edit distance relative to the length of the strings.
  97. """
  98. edit_distance = pylcs.edit_distance(source, target)
  99. max_length = max(len(source), len(target))
  100. # Normalize the score by the maximum possible edit distance (the length of the longer string)
  101. similarity = 1 - (edit_distance / max_length)
  102. return similarity
def find_best_match_location(source_code, target_block):
    """Find the best match for the target_block within the source_code by searching line by line,
    considering blocks of varying lengths.

    Parameters
    ----------
    - source_code: the full source text to search in.
    - target_block: the replacement snippet to locate.

    Returns
    -------
    - (char_start_index, char_end_index): character offsets into source_code
      delimiting the best-matching window. NOTE(review): if target_block has
      more lines than source_code the loops never run and char_end_index is
      computed from best_end_index == -1 — callers should treat -1 as "no match".
    """
    source_lines = source_code.split("\n")
    target_lines = target_block.split("\n")
    best_similarity = 0
    best_start_index = 0
    best_end_index = -1
    # Iterate over the source lines to find the best matching range for all lines in target_block
    # (O(n^2) windows, each scored with an edit-distance pass — fine for small files).
    for start_index in range(len(source_lines) - len(target_lines) + 1):
        for end_index in range(start_index + len(target_lines), len(source_lines) + 1):
            current_window = "\n".join(source_lines[start_index:end_index])
            current_similarity = calculate_similarity(current_window, target_block)
            if current_similarity > best_similarity:
                best_similarity = current_similarity
                best_start_index = start_index
                best_end_index = end_index
    # Convert line indices back to character indices for replacement.
    char_start_index = len("\n".join(source_lines[:best_start_index])) + (
        # +1 skips the newline that precedes the window (absent for line 0).
        1 if best_start_index > 0 else 0
    )
    char_end_index = len("\n".join(source_lines[:best_end_index]))
    return char_start_index, char_end_index
  127. def replace_code_in_source(source_code, replacement_block: str):
  128. """Replace the best matching block in the source_code with the replacement_block, considering variable block lengths.
  129. """
  130. replacement_block = extract_python_code_blocks(replacement_block)[0]
  131. start_index, end_index = find_best_match_location(source_code, replacement_block)
  132. if start_index != -1 and end_index != -1:
  133. # Replace the best matching part with the replacement block
  134. new_source = (
  135. source_code[:start_index] + replacement_block + source_code[end_index:]
  136. )
  137. return new_source
  138. return source_code
  139. class Operator:
  140. def on_event(
  141. self,
  142. dora_event,
  143. send_output,
  144. ) -> DoraStatus:
  145. if dora_event["type"] == "INPUT" and dora_event["id"] == "code_modifier":
  146. input = dora_event["value"][0].as_py()
  147. with open(input["path"], encoding="utf8") as f:
  148. code = f.read()
  149. user_message = input["user_message"]
  150. start_llm = time.time()
  151. output = self.ask_llm(
  152. CODE_MODIFIER_TEMPLATE.format(code=code, user_message=user_message),
  153. )
  154. source_code = replace_code_in_source(code, output)
  155. print("response time:", time.time() - start_llm, flush=True)
  156. send_output(
  157. "modified_file",
  158. pa.array(
  159. [
  160. {
  161. "raw": source_code,
  162. "path": input["path"],
  163. "response": output,
  164. "prompt": input["user_message"],
  165. },
  166. ],
  167. ),
  168. dora_event["metadata"],
  169. )
  170. print("response: ", output, flush=True)
  171. send_output(
  172. "assistant_message",
  173. pa.array([output]),
  174. dora_event["metadata"],
  175. )
  176. elif dora_event["type"] == "INPUT" and dora_event["id"] == "message_sender":
  177. user_message = dora_event["value"][0].as_py()
  178. output = self.ask_llm(
  179. MESSAGE_SENDER_TEMPLATE.format(user_message=user_message),
  180. )
  181. outputs = extract_json_code_blocks(output)[0]
  182. try:
  183. output = json.loads(outputs)
  184. if not isinstance(output["data"], list):
  185. output["data"] = [output["data"]]
  186. if output["topic"] in [
  187. "line",
  188. ]:
  189. send_output(
  190. output["topic"],
  191. pa.array(output["data"]),
  192. dora_event["metadata"],
  193. )
  194. else:
  195. print("Could not find the topic: {}".format(output["topic"]))
  196. except:
  197. print("Could not parse json")
  198. # if data is not iterable, put data in a list
  199. elif dora_event["type"] == "INPUT" and dora_event["id"] == "assistant":
  200. user_message = dora_event["value"][0].as_py()
  201. output = self.ask_llm(ASSISTANT_TEMPLATE.format(user_message=user_message))
  202. send_output(
  203. "assistant_message",
  204. pa.array([output]),
  205. dora_event["metadata"],
  206. )
  207. return DoraStatus.CONTINUE
  208. def ask_llm(self, prompt):
  209. # Generate output
  210. # prompt = PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt))
  211. input = tokenizer(prompt, return_tensors="pt")
  212. input_ids = input.input_ids.cuda()
  213. # add attention mask here
  214. attention_mask = input["attention_mask"].cuda()
  215. output = model.generate(
  216. inputs=input_ids,
  217. temperature=0.7,
  218. do_sample=True,
  219. top_p=0.95,
  220. top_k=40,
  221. max_new_tokens=512,
  222. attention_mask=attention_mask,
  223. eos_token_id=tokenizer.eos_token_id,
  224. )
  225. # Get the tokens from the output, decode them, print them
  226. # Get text between im_start and im_end
  227. return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
  228. if __name__ == "__main__":
  229. op = Operator()
  230. # Path to the current file
  231. current_file_path = __file__
  232. # Directory of the current file
  233. current_directory = os.path.dirname(current_file_path)
  234. path = current_directory + "object_detection.py"
  235. with open(path, encoding="utf8") as f:
  236. raw = f.read()
  237. op.on_event(
  238. {
  239. "type": "INPUT",
  240. "id": "message_sender",
  241. "value": pa.array(
  242. [
  243. {
  244. "path": path,
  245. "user_message": "send a star ",
  246. },
  247. ],
  248. ),
  249. "metadata": [],
  250. },
  251. print,
  252. )