Browse Source

Merge 067f338d29 into 77c277910b

pull/922/merge
Sumitb09 GitHub 5 months ago
parent
commit
45a75d69dc
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
7 changed files with 181 additions and 0 deletions
  1. +9
    -0
      node-hub/dora-mlx-lm/README.md
  2. +15
    -0
      node-hub/dora-mlx-lm/dora_mlx_lm/__init__.py
  3. +6
    -0
      node-hub/dora-mlx-lm/dora_mlx_lm/__main__.py
  4. +66
    -0
      node-hub/dora-mlx-lm/dora_mlx_lm/_main_.py
  5. +33
    -0
      node-hub/dora-mlx-lm/pyproject.toml
  6. +38
    -0
      node-hub/dora-mlx-lm/test.yml
  7. +14
    -0
      node-hub/dora-mlx-lm/tests/test_dora_mlx_lm.py

+ 9
- 0
node-hub/dora-mlx-lm/README.md View File

@@ -0,0 +1,9 @@
# dora-mlx-lm

A **Dora node** for running an **MLX-based language model**. This node enables text generation using **MLX**, an Apple Silicon-optimized deep learning framework.

## 🛠 Features
- Uses **MLX-LM** for efficient text generation on Apple Silicon (M1, M2, M3).
- Seamlessly integrates with **Dora**, a real-time event-based framework.
- Optimized for **macOS**.


+ 15
- 0
node-hub/dora-mlx-lm/dora_mlx_lm/__init__.py View File

@@ -0,0 +1,15 @@
# mlx_dora_lm/__init__.py
"""Package docstring is replaced by the project README at import time."""

import os

# README.md lives one directory above this package directory.
_readme = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Expose the README content as the package docstring; fall back to a
# short stub when the file is missing (e.g. in a stripped-down install).
try:
    with open(_readme, encoding="utf-8") as handle:
        __doc__ = handle.read()
except FileNotFoundError:
    __doc__ = "README file not found."


+ 6
- 0
node-hub/dora-mlx-lm/dora_mlx_lm/__main__.py View File

@@ -0,0 +1,6 @@
# mlx_dora_lm/__main__.py
# Package entry point so the node can be launched with `python -m <package>`.
# NOTE(review): this imports `.main`, but the module added alongside it is
# named `_main_.py` — confirm the file is renamed to `main.py`, otherwise
# this import fails with ModuleNotFoundError.
from .main import main

if __name__ == "__main__":
    main()


+ 66
- 0
node-hub/dora-mlx-lm/dora_mlx_lm/_main_.py View File

@@ -0,0 +1,66 @@
"""Dora Node: MLX-based """

import logging
import os

import mlx.core as mx
import mlx.nn as nn
from dora import Node
from mlx_lm import load, generate
import pyarrow as pa

# Configure logging
logging.basicConfig(level=logging.INFO)

# Environment variables for model configuration
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "")
MODEL_NAME = os.getenv("MODEL_NAME", "mlx-community/Qwen1.5-0.5B")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512"))
HISTORY_ENABLED = os.getenv("HISTORY", "False").lower() == "true"

# Words that trigger the model to respond
ACTIVATION_WORDS = os.getenv("ACTIVATION_WORDS", "").split()


def load_model():
    """Load the MLX model and tokenizer named by MODEL_NAME.

    Returns:
        tuple: ``(model, tokenizer)`` as produced by ``mlx_lm.load``.
    """
    logging.info("Loading MLX model: %s", MODEL_NAME)
    # mlx_lm.load() has no `dtype` keyword — the weights' dtype comes from
    # the converted checkpoint itself. The original call passed
    # dtype=mx.float16, which raises TypeError, so the argument is dropped.
    model, tokenizer = load(MODEL_NAME)
    logging.info("Model loaded successfully")
    return model, tokenizer


def generate_response(model, tokenizer, text: str, history) -> tuple[str, list]:
    """Generate an assistant reply to `text` given the chat `history`.

    The incoming `history` list is NOT mutated: a copy is extended and
    returned, so callers may discard the updated history (e.g. when the
    HISTORY env var is off) without the shared list growing anyway.

    Args:
        model: loaded MLX model.
        tokenizer: matching tokenizer (must support ``apply_chat_template``).
        text: the user's message.
        history: list of ``{"role": ..., "content": ...}`` chat messages.

    Returns:
        tuple[str, list]: the generated reply and the extended history.
    """
    # Copy first — the original code appended in place, silently mutating
    # the caller's list even when history tracking was disabled in main().
    updated = list(history)
    updated.append({"role": "user", "content": text})

    prompt = tokenizer.apply_chat_template(
        updated, tokenize=False, add_generation_prompt=True
    )
    response = generate(model, tokenizer, prompt, max_tokens=MAX_TOKENS)

    updated.append({"role": "assistant", "content": response})
    return response, updated


def main():
    """Run the Dora node: consume text inputs, emit generated replies.

    Responds when ACTIVATION_WORDS is empty (always respond) or the input
    contains at least one activation word. Chat history is carried across
    turns only when the HISTORY env var is "true".
    """
    model, tokenizer = load_model()
    # Seed the conversation with the system prompt, if one was configured.
    history = [{"role": "system", "content": SYSTEM_PROMPT}] if SYSTEM_PROMPT else []
    node = Node()

    for event in node:
        if event["type"] != "INPUT":
            continue

        text = event["value"][0].as_py()
        words = text.lower().split()

        if not ACTIVATION_WORDS or any(word in ACTIVATION_WORDS for word in words):
            # Hand generate_response a copy so the shared history list is
            # not grown in place while history tracking is disabled.
            response, updated_history = generate_response(
                model, tokenizer, text, list(history)
            )
            if HISTORY_ENABLED:
                history = updated_history
            node.send_output(output_id="text", data=pa.array([response]), metadata={})


if __name__ == "__main__":
    main()

+ 33
- 0
node-hub/dora-mlx-lm/pyproject.toml View File

@@ -0,0 +1,33 @@
[project]
# Name matches the node-hub directory (node-hub/dora-mlx-lm) and the
# dora_mlx_lm package it ships.
name = "dora-mlx-lm"
version = "1.0.0"
authors = [{ name = "Sumit bharti", email = "email@email.com" }]
description = "Dora node for text generation with MLX-LM on Apple Silicon"
license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.9"

dependencies = [
  "dora-rs >= 0.3.9",
  "mlx >= 0.8.0",
  "mlx-lm >= 0.1.0",
  "pyarrow >= 14.0.1",
]

[dependency-groups]
dev = ["pytest >=8.1.1", "ruff >=0.9.1"]

[project.scripts]
# The package directory is dora_mlx_lm, not dora_mlx; the previous entry
# ("dora_mlx.main:main") pointed at a module that does not exist.
dora-mlx-lm = "dora_mlx_lm.main:main"

[tool.ruff.lint]
extend-select = [
  "D",    # pydocstyle
  "UP",   # pyupgrade
  "PERF", # Perflint
  "RET",  # flake8-return
  "RSE",  # flake8-raise
  "NPY",  # NumPy-specific rules
  "N",    # pep8-naming
  "I",    # isort
]

+ 38
- 0
node-hub/dora-mlx-lm/test.yml View File

@@ -0,0 +1,38 @@
name: MLX Dora Model CI

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  test:
    # MLX is Apple Silicon-only, so the job must run on macOS.
    runs-on: macos-latest

    steps:
      # The repository must be checked out before `pip install .` can find
      # the project — this step was missing, so every later step would fail.
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"

      - name: Install Homebrew dependencies
        run: |
          brew update
          brew install cmake # MLX may require this for certain builds

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .

      - name: Install development dependencies
        run: |
          pip install .[dev]

      - name: Run tests
        run: pytest tests

+ 14
- 0
node-hub/dora-mlx-lm/tests/test_dora_mlx_lm.py View File

@@ -0,0 +1,14 @@
"""Test for Dora MLX """

import pytest


def test_import_main():
"""Test that the MLX Dora chatbot script can be imported and run."""
from dora_mlx.main import main # Adjust the import path based on your directory structure

# Expect RuntimeError since we're not running inside a Dora dataflow
with pytest.raises(RuntimeError):
main()



Loading…
Cancel
Save