You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

llm_op.py 10 kB

10 months ago
10 months ago
10 months ago
1 year ago
1 year ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. """TODO: Add docstring."""
  2. import json
  3. import os
  4. import re
  5. import time
  6. import pyarrow as pa
  7. import pylcs
  8. from dora import DoraStatus
  9. from transformers import AutoModelForCausalLM, AutoTokenizer
  10. MODEL_NAME_OR_PATH = "TheBloke/deepseek-coder-6.7B-instruct-GPTQ"
  11. # MODEL_NAME_OR_PATH = "hanspeterlyngsoeraaschoujensen/deepseek-math-7b-instruct-GPTQ"
  12. CODE_MODIFIER_TEMPLATE = """
  13. ### Instruction
  14. Respond with the small modified code only. No explanation.
  15. ```python
  16. {code}
  17. ```
  18. {user_message}
  19. ### Response:
  20. """
  21. MESSAGE_SENDER_TEMPLATE = """
  22. ### Instruction
  23. You're a json expert. Format your response as a json with a topic and a data field in a ```json block. No explanation needed. No code needed.
  24. The schema for those json are:
  25. - line: Int[4]
  26. The response should look like this:
  27. ```json
  28. {{ "topic": "line", "data": [10, 10, 90, 10] }}
  29. ```
  30. {user_message}
  31. ### Response:
  32. """
  33. ASSISTANT_TEMPLATE = """
  34. ### Instruction
  35. You're a helpuf assistant named dora.
  36. Reply with a short message. No code needed.
  37. User {user_message}
  38. ### Response:
  39. """
# Load the quantized causal LM and its tokenizer once at import time so every
# Operator instance shares the same weights.
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository — assumes the hub repo is trusted; confirm this is intended.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    device_map="auto",  # let transformers place layers on available devices
    trust_remote_code=True,
    revision="main",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
  47. def extract_python_code_blocks(text):
  48. """Extracts Python code blocks from the given text that are enclosed in triple backticks with a python language identifier.
  49. Parameters
  50. ----------
  51. - text: A string that may contain one or more Python code blocks.
  52. Returns
  53. -------
  54. - A list of strings, where each string is a block of Python code extracted from the text.
  55. """
  56. pattern = r"```python\n(.*?)\n```"
  57. matches = re.findall(pattern, text, re.DOTALL)
  58. if len(matches) == 0:
  59. pattern = r"```python\n(.*?)(?:\n```|$)"
  60. matches = re.findall(pattern, text, re.DOTALL)
  61. if len(matches) == 0:
  62. return [text]
  63. matches = [remove_last_line(matches[0])]
  64. return matches
  65. def extract_json_code_blocks(text):
  66. """Extracts json code blocks from the given text that are enclosed in triple backticks with a json language identifier.
  67. Parameters
  68. ----------
  69. - text: A string that may contain one or more json code blocks.
  70. Returns
  71. -------
  72. - A list of strings, where each string is a block of json code extracted from the text.
  73. """
  74. pattern = r"```json\n(.*?)\n```"
  75. matches = re.findall(pattern, text, re.DOTALL)
  76. if len(matches) == 0:
  77. pattern = r"```json\n(.*?)(?:\n```|$)"
  78. matches = re.findall(pattern, text, re.DOTALL)
  79. if len(matches) == 0:
  80. return [text]
  81. return matches
  82. def remove_last_line(python_code):
  83. """Removes the last line from a given string of Python code.
  84. Parameters
  85. ----------
  86. - python_code: A string representing Python source code.
  87. Returns
  88. -------
  89. - A string with the last line removed.
  90. """
  91. lines = python_code.split("\n") # Split the string into lines
  92. if lines: # Check if there are any lines to remove
  93. lines.pop() # Remove the last line
  94. return "\n".join(lines) # Join the remaining lines back into a string
  95. def calculate_similarity(source, target):
  96. """Calculate a similarity score between the source and target strings.
  97. This uses the edit distance relative to the length of the strings.
  98. """
  99. edit_distance = pylcs.edit_distance(source, target)
  100. max_length = max(len(source), len(target))
  101. # Normalize the score by the maximum possible edit distance (the length of the longer string)
  102. similarity = 1 - (edit_distance / max_length)
  103. return similarity
def find_best_match_location(source_code, target_block):
    """Find the best match for the target_block within the source_code by searching line by line,
    considering blocks of varying lengths.

    Every contiguous window of source lines at least as long as target_block
    is scored with calculate_similarity; the character span of the highest-
    scoring window is returned for slice-replacement.

    Parameters
    ----------
    - source_code: Full source text to search in.
    - target_block: Text whose best-matching span should be located.

    Returns
    -------
    - (char_start_index, char_end_index): character offsets into source_code.
      NOTE(review): if every window scores 0, best_end_index stays -1 and the
      returned end offset silently points before the last line — callers that
      check for -1 will not catch this; confirm intended.
    """
    source_lines = source_code.split("\n")
    target_lines = target_block.split("\n")
    best_similarity = 0
    best_start_index = 0
    best_end_index = -1
    # Iterate over the source lines to find the best matching range for all lines in target_block
    # Quadratic number of windows, each scored with an edit distance — this
    # is expensive on large files.
    for start_index in range(len(source_lines) - len(target_lines) + 1):
        for end_index in range(start_index + len(target_lines), len(source_lines) + 1):
            current_window = "\n".join(source_lines[start_index:end_index])
            current_similarity = calculate_similarity(current_window, target_block)
            if current_similarity > best_similarity:
                best_similarity = current_similarity
                best_start_index = start_index
                best_end_index = end_index
    # Convert line indices back to character indices for replacement
    # (+1 skips the newline preceding the window when it is not the first line)
    char_start_index = len("\n".join(source_lines[:best_start_index])) + (
        1 if best_start_index > 0 else 0
    )
    char_end_index = len("\n".join(source_lines[:best_end_index]))
    return char_start_index, char_end_index
  128. def replace_code_in_source(source_code, replacement_block: str):
  129. """Replace the best matching block in the source_code with the replacement_block, considering variable block lengths."""
  130. replacement_block = extract_python_code_blocks(replacement_block)[0]
  131. start_index, end_index = find_best_match_location(source_code, replacement_block)
  132. if start_index != -1 and end_index != -1:
  133. # Replace the best matching part with the replacement block
  134. new_source = (
  135. source_code[:start_index] + replacement_block + source_code[end_index:]
  136. )
  137. return new_source
  138. return source_code
  139. class Operator:
  140. """TODO: Add docstring."""
  141. def on_event(
  142. self,
  143. dora_event,
  144. send_output,
  145. ) -> DoraStatus:
  146. """TODO: Add docstring."""
  147. if dora_event["type"] == "INPUT" and dora_event["id"] == "code_modifier":
  148. input = dora_event["value"][0].as_py()
  149. with open(input["path"], encoding="utf8") as f:
  150. code = f.read()
  151. user_message = input["user_message"]
  152. start_llm = time.time()
  153. output = self.ask_llm(
  154. CODE_MODIFIER_TEMPLATE.format(code=code, user_message=user_message),
  155. )
  156. source_code = replace_code_in_source(code, output)
  157. print("response time:", time.time() - start_llm, flush=True)
  158. send_output(
  159. "modified_file",
  160. pa.array(
  161. [
  162. {
  163. "raw": source_code,
  164. "path": input["path"],
  165. "response": output,
  166. "prompt": input["user_message"],
  167. },
  168. ],
  169. ),
  170. dora_event["metadata"],
  171. )
  172. print("response: ", output, flush=True)
  173. send_output(
  174. "assistant_message",
  175. pa.array([output]),
  176. dora_event["metadata"],
  177. )
  178. elif dora_event["type"] == "INPUT" and dora_event["id"] == "message_sender":
  179. user_message = dora_event["value"][0].as_py()
  180. output = self.ask_llm(
  181. MESSAGE_SENDER_TEMPLATE.format(user_message=user_message),
  182. )
  183. outputs = extract_json_code_blocks(output)[0]
  184. try:
  185. output = json.loads(outputs)
  186. if not isinstance(output["data"], list):
  187. output["data"] = [output["data"]]
  188. if output["topic"] in [
  189. "line",
  190. ]:
  191. send_output(
  192. output["topic"],
  193. pa.array(output["data"]),
  194. dora_event["metadata"],
  195. )
  196. else:
  197. print("Could not find the topic: {}".format(output["topic"]))
  198. except:
  199. print("Could not parse json")
  200. # if data is not iterable, put data in a list
  201. elif dora_event["type"] == "INPUT" and dora_event["id"] == "assistant":
  202. user_message = dora_event["value"][0].as_py()
  203. output = self.ask_llm(ASSISTANT_TEMPLATE.format(user_message=user_message))
  204. send_output(
  205. "assistant_message",
  206. pa.array([output]),
  207. dora_event["metadata"],
  208. )
  209. return DoraStatus.CONTINUE
  210. def ask_llm(self, prompt):
  211. # Generate output
  212. # prompt = PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt))
  213. """TODO: Add docstring."""
  214. input = tokenizer(prompt, return_tensors="pt")
  215. input_ids = input.input_ids.cuda()
  216. # add attention mask here
  217. attention_mask = input["attention_mask"].cuda()
  218. output = model.generate(
  219. inputs=input_ids,
  220. temperature=0.7,
  221. do_sample=True,
  222. top_p=0.95,
  223. top_k=40,
  224. max_new_tokens=512,
  225. attention_mask=attention_mask,
  226. eos_token_id=tokenizer.eos_token_id,
  227. )
  228. # Get the tokens from the output, decode them, print them
  229. # Get text between im_start and im_end
  230. return tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
  231. if __name__ == "__main__":
  232. op = Operator()
  233. # Path to the current file
  234. current_file_path = __file__
  235. # Directory of the current file
  236. current_directory = os.path.dirname(current_file_path)
  237. path = current_directory + "object_detection.py"
  238. with open(path, encoding="utf8") as f:
  239. raw = f.read()
  240. op.on_event(
  241. {
  242. "type": "INPUT",
  243. "id": "message_sender",
  244. "value": pa.array(
  245. [
  246. {
  247. "path": path,
  248. "user_message": "send a star ",
  249. },
  250. ],
  251. ),
  252. "metadata": [],
  253. },
  254. print,
  255. )