
main.py

  1. """TODO: Add docstring."""
  2. import os
  3. import sys
  4. import pyarrow as pa
  5. from dora import Node
  6. from transformers import AutoModelForCausalLM, AutoTokenizer
  7. SYSTEM_PROMPT = os.getenv(
  8. "SYSTEM_PROMPT",
  9. "You're a very succinct AI assistant with short answers.",
  10. )


def get_model_gguf():
    """Load the GGUF build of Qwen2.5-0.5B-Instruct through llama.cpp."""
    from llama_cpp import Llama

    llm = Llama.from_pretrained(
        repo_id="Qwen/Qwen2.5-0.5B-Instruct-GGUF",
        filename="*fp16.gguf",
        verbose=False,
    )
    return llm


def get_model_darwin():
    """Load Qwen2.5-0.5B-Instruct through MLX.

    Despite the name, main() uses this as the fallback on platforms other
    than macOS and Linux.
    """
    from mlx_lm import load

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-8bit")
    return model, tokenizer


def get_model_huggingface():
    """Load Qwen2.5-0.5B-Instruct with Hugging Face transformers."""
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


# Only inputs containing at least one of these words get a reply.
ACTIVATION_WORDS = os.getenv("ACTIVATION_WORDS", "what how who where you").split()


def generate_hf(model, tokenizer, prompt: str, history) -> tuple[str, list]:
    """Generate one assistant reply with transformers, extending the chat history."""
    history += [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    # Drop the echoed prompt tokens so only the new completion is decoded.
    generated_ids = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    history += [{"role": "assistant", "content": response}]
    return response, history
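
# Example of the history contract (hypothetical, for illustration only):
#
#     model, tokenizer = get_model_huggingface()
#     reply, history = generate_hf(model, tokenizer, "Who are you?", [])
#
# After the call, history holds both the user turn and the assistant turn,
# so passing it to the next call continues the same conversation.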


def main():
    """Run the dora event loop and answer inputs containing an activation word."""
    history = []
    # Pick a backend per platform: llama.cpp (GGUF) on macOS, transformers
    # on Linux, and MLX on anything else.
    if sys.platform == "darwin":
        model = get_model_gguf()
    elif sys.platform == "linux":
        model, tokenizer = get_model_huggingface()
    else:
        model, tokenizer = get_model_darwin()

    node = Node()
    for event in node:
        if event["type"] == "INPUT":
            # Warning: the input and output ids used here must be declared
            # in the dataflow definition.
            text = event["value"][0].as_py()
            words = text.lower().split()
            if any(word in ACTIVATION_WORDS for word in words):
                if sys.platform == "darwin":
                    # llama.cpp completion API: plain Q/A prompt, stopping
                    # just before the model would start a new question.
                    response = model(
                        f"Q: {text} A: ",
                        max_tokens=24,
                        stop=["Q:", "\n"],
                    )["choices"][0]["text"]
                elif sys.platform == "linux":
                    response, history = generate_hf(model, tokenizer, text, history)
                else:
                    from mlx_lm import generate

                    response = generate(
                        model,
                        tokenizer,
                        prompt=text,
                        verbose=False,
                        max_tokens=50,
                    )
                node.send_output(
                    output_id="text",
                    data=pa.array([response]),
                    metadata={},
                )


if __name__ == "__main__":
    main()
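
Downstream nodes receive the reply as a one-element Arrow array on this node's `text` output. A minimal sketch of the consuming side, assuming a peer dora node whose input is wired to that output in the dataflow (the ids are illustrative, not part of this file):

from dora import Node

node = Node()
for event in node:
    if event["type"] == "INPUT" and event["id"] == "text":
        reply = event["value"][0].as_py()  # back to a plain Python string
        print(reply)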