
Add support for Qwenvl2 (#646)

Add support for qwenvl2 :)
tags/v0.3.7rc0
Haixuan Xavier Tao, 1 year ago
commit 701d0e1b75
20 changed files with 716 additions and 3 deletions
  1. examples/camera/.gitignore (+1 -0)
  2. examples/camera/README.md (+9 -0)
  3. examples/camera/dataflow.yml (+20 -0)
  4. examples/camera/run.rs (+98 -0)
  5. examples/vlm/.gitignore (+1 -0)
  6. examples/vlm/README.md (+1 -0)
  7. examples/vlm/dataflow.yml (+35 -0)
  8. examples/vlm/run.rs (+98 -0)
  9. node-hub/dora-keyboard/pyproject.toml (+2 -2)
  10. node-hub/dora-qwenvl/README.md (+3 -0)
  11. node-hub/dora-qwenvl/dora_qwenvl/__init__.py (+11 -0)
  12. node-hub/dora-qwenvl/dora_qwenvl/main.py (+158 -0)
  13. node-hub/dora-qwenvl/pyproject.toml (+30 -0)
  14. node-hub/dora-qwenvl/tests/test_dora_qwenvl.py (+9 -0)
  15. node-hub/dora-rerun/src/main.rs (+1 -1)
  16. node-hub/llama-factory-recorder/README.md (+3 -0)
  17. node-hub/llama-factory-recorder/llama_factory_recorder/__init__.py (+11 -0)
  18. node-hub/llama-factory-recorder/llama_factory_recorder/main.py (+193 -0)
  19. node-hub/llama-factory-recorder/pyproject.toml (+23 -0)
  20. node-hub/llama-factory-recorder/tests/test_llama_factory_recorder.py (+9 -0)

examples/camera/.gitignore (+1 -0)

@@ -0,0 +1 @@
*.pt

examples/camera/README.md (+9 -0)

@@ -0,0 +1,9 @@
# Quick example on how to use a camera

Make sure you have dora and pip installed.

```bash
dora up
dora build dataflow.yml
dora start dataflow.yml
```

examples/camera/dataflow.yml (+20 -0)

@@ -0,0 +1,20 @@
nodes:
  - id: camera
    build: pip install ../../node-hub/opencv-video-capture
    path: opencv-video-capture
    inputs:
      tick: dora/timer/millis/20
    outputs:
      - image
    env:
      CAPTURE_PATH: 0
      IMAGE_WIDTH: 640
      IMAGE_HEIGHT: 480

  - id: plot
    build: pip install ../../node-hub/opencv-plot
    path: opencv-plot
    inputs:
      image:
        source: camera/image
        queue_size: 1

examples/camera/run.rs (+98 -0)

@@ -0,0 +1,98 @@
use dora_core::{get_pip_path, get_python_path, run};
use dora_tracing::set_up_tracing;
use eyre::{bail, ContextCompat, WrapErr};
use std::path::Path;

#[tokio::main]
async fn main() -> eyre::Result<()> {
    set_up_tracing("python-dataflow-runner")?;

    let root = Path::new(env!("CARGO_MANIFEST_DIR"));
    std::env::set_current_dir(root.join(file!()).parent().unwrap())
        .wrap_err("failed to set working dir")?;

    run(
        get_python_path().context("Could not get python binary")?,
        &["-m", "venv", "../.env"],
        None,
    )
    .await
    .context("failed to create venv")?;
    let venv = &root.join("examples").join(".env");
    std::env::set_var(
        "VIRTUAL_ENV",
        venv.to_str().context("venv path not valid unicode")?,
    );
    let orig_path = std::env::var("PATH")?;
    // bin folder is named Scripts on windows.
    // 🤦‍♂️ See: https://github.com/pypa/virtualenv/commit/993ba1316a83b760370f5a3872b3f5ef4dd904c1
    let venv_bin = if cfg!(windows) {
        venv.join("Scripts")
    } else {
        venv.join("bin")
    };

    if cfg!(windows) {
        std::env::set_var(
            "PATH",
            format!(
                "{};{orig_path}",
                venv_bin.to_str().context("venv path not valid unicode")?
            ),
        );
    } else {
        std::env::set_var(
            "PATH",
            format!(
                "{}:{orig_path}",
                venv_bin.to_str().context("venv path not valid unicode")?
            ),
        );
    }

    run(
        get_pip_path().context("Could not get pip binary")?,
        &["install", "maturin"],
        Some(venv),
    )
    .await
    .context("pip install maturin failed")?;

    run(
        "maturin",
        &["develop"],
        Some(&root.join("apis").join("python").join("node")),
    )
    .await
    .context("maturin develop failed")?;

    let dataflow = Path::new("dataflow.yml");
    run_dataflow(dataflow).await?;

    Ok(())
}

async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> {
    let cargo = std::env::var("CARGO").unwrap();

    // First build the dataflow (install requirements)
    let mut cmd = tokio::process::Command::new(&cargo);
    cmd.arg("run");
    cmd.arg("--package").arg("dora-cli");
    cmd.arg("--").arg("build").arg(dataflow);
    if !cmd.status().await?.success() {
        bail!("failed to run dataflow");
    };

    let mut cmd = tokio::process::Command::new(&cargo);
    cmd.arg("run");
    cmd.arg("--package").arg("dora-cli");
    cmd.arg("--")
        .arg("daemon")
        .arg("--run-dataflow")
        .arg(dataflow);
    if !cmd.status().await?.success() {
        bail!("failed to run dataflow");
    };
    Ok(())
}

examples/vlm/.gitignore (+1 -0)

@@ -0,0 +1 @@
*.pt

examples/vlm/README.md (+1 -0)

@@ -0,0 +1 @@
# Quick example on using a VLM with dora-rs
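
To try it, the same workflow as the camera example should work (a minimal sketch, assuming dora and pip are installed and that you run the commands from this directory):

```bash
dora up
dora build dataflow.yml
dora start dataflow.yml
```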

examples/vlm/dataflow.yml (+35 -0)

@@ -0,0 +1,35 @@
nodes:
  - id: camera
    build: pip install ../../node-hub/opencv-video-capture
    path: opencv-video-capture
    inputs:
      tick: dora/timer/millis/20
    outputs:
      - image
    env:
      CAPTURE_PATH: 0
      IMAGE_WIDTH: 640
      IMAGE_HEIGHT: 480

  - id: dora-qwenvl
    build: pip install -e ../../node-hub/dora-qwenvl
    path: dora-qwenvl
    inputs:
      image:
        source: camera/image
        queue_size: 1
      tick: dora/timer/millis/300
    outputs:
      - text
      - tick
    env:
      DEFAULT_QUESTION: Describe the image.

  - id: plot
    build: pip install ../../node-hub/opencv-plot
    path: opencv-plot
    inputs:
      image:
        source: camera/image
        queue_size: 1
      text: dora-qwenvl/tick

examples/vlm/run.rs (+98 -0)

@@ -0,0 +1,98 @@
use dora_core::{get_pip_path, get_python_path, run};
use dora_tracing::set_up_tracing;
use eyre::{bail, ContextCompat, WrapErr};
use std::path::Path;

#[tokio::main]
async fn main() -> eyre::Result<()> {
    set_up_tracing("python-dataflow-runner")?;

    let root = Path::new(env!("CARGO_MANIFEST_DIR"));
    std::env::set_current_dir(root.join(file!()).parent().unwrap())
        .wrap_err("failed to set working dir")?;

    run(
        get_python_path().context("Could not get python binary")?,
        &["-m", "venv", "../.env"],
        None,
    )
    .await
    .context("failed to create venv")?;
    let venv = &root.join("examples").join(".env");
    std::env::set_var(
        "VIRTUAL_ENV",
        venv.to_str().context("venv path not valid unicode")?,
    );
    let orig_path = std::env::var("PATH")?;
    // bin folder is named Scripts on windows.
    // 🤦‍♂️ See: https://github.com/pypa/virtualenv/commit/993ba1316a83b760370f5a3872b3f5ef4dd904c1
    let venv_bin = if cfg!(windows) {
        venv.join("Scripts")
    } else {
        venv.join("bin")
    };

    if cfg!(windows) {
        std::env::set_var(
            "PATH",
            format!(
                "{};{orig_path}",
                venv_bin.to_str().context("venv path not valid unicode")?
            ),
        );
    } else {
        std::env::set_var(
            "PATH",
            format!(
                "{}:{orig_path}",
                venv_bin.to_str().context("venv path not valid unicode")?
            ),
        );
    }

    run(
        get_pip_path().context("Could not get pip binary")?,
        &["install", "maturin"],
        Some(venv),
    )
    .await
    .context("pip install maturin failed")?;

    run(
        "maturin",
        &["develop"],
        Some(&root.join("apis").join("python").join("node")),
    )
    .await
    .context("maturin develop failed")?;

    let dataflow = Path::new("dataflow.yml");
    run_dataflow(dataflow).await?;

    Ok(())
}

async fn run_dataflow(dataflow: &Path) -> eyre::Result<()> {
    let cargo = std::env::var("CARGO").unwrap();

    // First build the dataflow (install requirements)
    let mut cmd = tokio::process::Command::new(&cargo);
    cmd.arg("run");
    cmd.arg("--package").arg("dora-cli");
    cmd.arg("--").arg("build").arg(dataflow);
    if !cmd.status().await?.success() {
        bail!("failed to run dataflow");
    };

    let mut cmd = tokio::process::Command::new(&cargo);
    cmd.arg("run");
    cmd.arg("--package").arg("dora-cli");
    cmd.arg("--")
        .arg("daemon")
        .arg("--run-dataflow")
        .arg(dataflow);
    if !cmd.status().await?.success() {
        bail!("failed to run dataflow");
    };
    Ok(())
}

node-hub/dora-keyboard/pyproject.toml (+2 -2)

@@ -1,6 +1,6 @@
[tool.poetry]
name = "dora-keyboard"
version = "0.3.5"
version = "0.3.6"
authors = [
"Haixuan Xavier Tao <tao.xavier@outlook.com>",
"Enzo Le Van <dev@enzo-le-van.fr>",
@@ -13,7 +13,7 @@ readme = "README.md"
packages = [{ include = "dora_keyboard" }]

[tool.poetry.dependencies]
dora-rs = "0.3.5"
dora-rs = "^0.3.6"
numpy = "< 2.0.0"
pyarrow = ">= 5.0.0"
pynput = "^1.7.6"


node-hub/dora-qwenvl/README.md (+3 -0)

@@ -0,0 +1,3 @@
# Dora QwenVL2 node

Experimental node for using a VLM within dora.
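
For reference, the `examples/vlm/dataflow.yml` added in this change wires the node as follows (the camera source and timer period are example values you would adapt to your own dataflow):

```yaml
- id: dora-qwenvl
  build: pip install -e ../../node-hub/dora-qwenvl
  path: dora-qwenvl
  inputs:
    image:
      source: camera/image
      queue_size: 1
    tick: dora/timer/millis/300
  outputs:
    - text
    - tick
  env:
    DEFAULT_QUESTION: Describe the image.
```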

node-hub/dora-qwenvl/dora_qwenvl/__init__.py (+11 -0)

@@ -0,0 +1,11 @@
import os

# Define the path to the README file relative to the package directory
readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Read the content of the README file
try:
    with open(readme_path, "r", encoding="utf-8") as f:
        __doc__ = f.read()
except FileNotFoundError:
    __doc__ = "README file not found."

node-hub/dora-qwenvl/dora_qwenvl/main.py (+158 -0)

@@ -0,0 +1,158 @@
import os
from dora import Node
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import numpy as np
import pyarrow as pa
from PIL import Image

DEFAULT_PATH = "Qwen/Qwen2-VL-2B-Instruct"
CUSTOM_MODEL_PATH = os.getenv("CUSTOM_MODEL_PATH", DEFAULT_PATH)
DEFAULT_QUESTION = os.getenv(
    "DEFAULT_QUESTION",
    "Describe this image",
)

# Check if flash_attn is installed
try:
    import flash_attn as _

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        CUSTOM_MODEL_PATH,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
except ImportError:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        CUSTOM_MODEL_PATH,
        torch_dtype="auto",
        device_map="auto",
    )


# default processor
processor = AutoProcessor.from_pretrained(DEFAULT_PATH)


def generate(frames: dict, question):
    """
    Generate the response to the question given the image(s) using the Qwen2-VL model.
    """

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                }
                for image in frames.values()
            ]
            + [
                {"type": "text", "text": question},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]


def main():
    pa.array([])  # initialize pyarrow array
    node = Node()

    question = DEFAULT_QUESTION
    frames = {}

    for event in node:
        event_type = event["type"]

        if event_type == "INPUT":
            event_id = event["id"]

            if "image" in event_id:
                storage = event["value"]
                metadata = event["metadata"]
                encoding = metadata["encoding"]
                width = metadata["width"]
                height = metadata["height"]

                if encoding == "bgr8":
                    channels = 3
                    storage_type = np.uint8
                elif encoding == "rgb8":
                    channels = 3
                    storage_type = np.uint8
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")

                frame = (
                    storage.to_numpy()
                    .astype(storage_type)
                    .reshape((height, width, channels))
                )
                if encoding == "bgr8":
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                elif encoding == "rgb8":
                    pass
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")
                frames[event_id] = Image.fromarray(frame)

            elif event_id == "tick":
                if len(frames.keys()) == 0:
                    continue
                response = generate(frames, question)
                node.send_output(
                    "tick",
                    pa.array([response]),
                    {},
                )

            elif event_id == "text":
                text = event["value"][0].as_py()
                if text != "":
                    question = text
                if len(frames.keys()) == 0:
                    continue
                response = generate(frames, question)
                node.send_output(
                    "text",
                    pa.array([response]),
                    {},
                )

        elif event_type == "ERROR":
            raise RuntimeError(event["error"])


if __name__ == "__main__":
    main()

node-hub/dora-qwenvl/pyproject.toml (+30 -0)

@@ -0,0 +1,30 @@
[tool.poetry]
name = "dora-qwenvl"
version = "0.3.6-rc0"
authors = [
"Haixuan Xavier Tao <tao.xavier@outlook.com>",
"Enzo Le Van <dev@enzo-le-van.fr>",
]
description = "Dora Node for VLM"
readme = "README.md"

packages = [{ include = "dora_qwenvl" }]

[tool.poetry.dependencies]
python = "^3.7"
dora-rs = "^0.3.6"
numpy = "< 2.0.0"
torch = "^2.4.0"
torchvision = "^0.19"
transformers = { git = "https://github.com/huggingface/transformers" }
qwen-vl-utils = "^0.0.2"
accelerate = "^0.33"
# flash_attn = "^2.6.1" # Install using: pip install -U flash-attn --no-build-isolation


[tool.poetry.scripts]
dora-qwenvl = "dora_qwenvl.main:main"

[build-system]
requires = ["poetry-core>=1.8.0"]
build-backend = "poetry.core.masonry.api"

node-hub/dora-qwenvl/tests/test_dora_qwenvl.py (+9 -0)

@@ -0,0 +1,9 @@
import pytest


def test_import_main():
    from dora_qwenvl.main import main

    # Check that everything is working, and catch the dora runtime exception as we're not running in a dora dataflow.
    with pytest.raises(RuntimeError):
        main()

node-hub/dora-rerun/src/main.rs (+1 -1)

@@ -104,7 +104,7 @@ fn main() -> Result<()> {

rec.log(id.as_str(), &image)
.context("could not log image")?;
-        } else if id.as_str().contains("textlog") {
+        } else if id.as_str().contains("text") {
let buffer: StringArray = data.to_data().into();
buffer.iter().try_for_each(|string| -> Result<()> {
if let Some(str) = string {


node-hub/llama-factory-recorder/README.md (+3 -0)

@@ -0,0 +1,3 @@
# Dora Llama factory recorder

Experimental node for recording training data for llama-based models.
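
A minimal sketch of how the node might be wired in a dataflow (the node ids and input sources below are placeholders; the input names, output, and environment variables come from `main.py` in this package):

```yaml
- id: llama-factory-recorder
  build: pip install -e ../../node-hub/llama-factory-recorder
  path: llama-factory-recorder
  inputs:
    image: camera/image                     # frame to record
    text: some-node/text                    # optional prompt, overrides DEFAULT_QUESTION
    ground_truth: some-node/ground_truth    # expected answer; triggers one recorded entry
  outputs:
    - text
  env:
    LLAMA_FACTORY_ROOT_PATH: /path/to/LLaMA-Factory
    ENTRY_NAME: dora_demo
```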

node-hub/llama-factory-recorder/llama_factory_recorder/__init__.py (+11 -0)

@@ -0,0 +1,11 @@
import os

# Define the path to the README file relative to the package directory
readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Read the content of the README file
try:
    with open(readme_path, "r", encoding="utf-8") as f:
        __doc__ = f.read()
except FileNotFoundError:
    __doc__ = "README file not found."

node-hub/llama-factory-recorder/llama_factory_recorder/main.py (+193 -0)

@@ -0,0 +1,193 @@
import os
import json
from dora import Node
import numpy as np
import pyarrow as pa
from PIL import Image
from pathlib import Path

DEFAULT_QUESTION = os.getenv(
    "DEFAULT_QUESTION",
    "Describe this image",
)


def write_dict_to_json(file_path, key: str, new_data):
    """
    Writes a dictionary to a JSON file under the given key. If the file already
    contains a JSON object, the new entry is added to it; otherwise the file is
    created with a single entry.

    Parameters:
    - file_path: str, the path to the JSON file.
    - key: str, the key under which the new data is stored.
    - new_data: dict, the dictionary to add to the JSON file.
    """
    try:
        # Open the JSON file and load its content
        with open(file_path, "r+", encoding="utf-8") as file:
            try:
                data = json.load(file)
            except json.JSONDecodeError:
                data = {}

            data[key] = new_data
            # Write the updated data back to the file
            file.seek(0)
            json.dump(data, file, indent=4, ensure_ascii=False)
            file.truncate()

    except FileNotFoundError:
        # If the file doesn't exist, create it and write the new data
        with open(file_path, "w", encoding="utf-8") as file:
            json.dump({key: new_data}, file, indent=4, ensure_ascii=False)


def save_image_and_add_to_json(
    image_array, root_path, llama_root_path, jsonl_file, messages
):
    """
    Saves an image from a NumPy array and appends a new JSON object as a line to a JSONL file.
    The function generates a sequential numeric image filename starting from 0 and
    follows the provided template structure.

    Parameters:
    - image_array: numpy.ndarray, the image data as a NumPy array.
    - root_path: str, the directory (relative to llama_root_path) where the image will be saved.
    - llama_root_path: Path, the LLaMA-Factory data directory.
    - jsonl_file: str, the path to the JSONL file.
    - messages: list of dicts, each containing 'content' and 'role'.

    The image is saved as a PNG file, and the JSONL entry includes the 'messages' and 'images' keys.
    """

    # Create the root directory if it doesn't exist
    os.makedirs(llama_root_path / root_path, exist_ok=True)

    # Get the current image ID by counting existing files
    image_id = len(
        [
            name
            for name in os.listdir(llama_root_path / root_path)
            if os.path.isfile(os.path.join(llama_root_path / root_path, name))
        ]
    )

    # Define the image filename
    image_filename = f"{image_id}.png"
    image_path = os.path.join(root_path, image_filename)

    # Save the image
    image = Image.fromarray(image_array)
    image.save(llama_root_path / image_path)

    # Create the JSON entry with 'messages' and 'images'
    new_entry = {"messages": messages, "images": [image_path]}

    # Add the entry to the JSONL file with UTF-8 encoding
    with open(jsonl_file, "a", encoding="utf-8") as f:
        json_line = json.dumps(new_entry)
        f.write(json_line + "\n")


def main():
    pa.array([])  # initialize pyarrow array
    node = Node()

    assert os.getenv(
        "LLAMA_FACTORY_ROOT_PATH"
    ), "LLAMA_FACTORY_ROOT_PATH is not set. Either git clone the LLaMA-Factory repo or set the environment variable."
    llama_factory_root_path = Path(os.getenv("LLAMA_FACTORY_ROOT_PATH")) / "data"

    entry_name = os.getenv("ENTRY_NAME", "dora_demo")
    # If the entry already exists, append an incremental suffix to avoid overwriting
    if (llama_factory_root_path / entry_name).exists():
        i = 1
        while (llama_factory_root_path / f"{entry_name}_{i}.json").exists():
            i += 1
        entry_name = f"{entry_name}_{i}"

    default_record_json_path = llama_factory_root_path / (entry_name + ".json")

    write_dict_to_json(
        llama_factory_root_path / "dataset_info.json",
        entry_name,
        {
            "file_name": entry_name + ".json",
            "formatting": "sharegpt",
            "columns": {"messages": "messages", "images": "images"},
            "tags": {
                "role_tag": "role",
                "content_tag": "content",
                "user_tag": "user",
                "assistant_tag": "assistant",
            },
        },
    )

    question = DEFAULT_QUESTION
    frame = None

    for event in node:
        event_type = event["type"]

        if event_type == "INPUT":
            event_id = event["id"]

            if event_id == "image":
                storage = event["value"]
                metadata = event["metadata"]
                encoding = metadata["encoding"]
                width = metadata["width"]
                height = metadata["height"]

                if encoding == "bgr8":
                    channels = 3
                    storage_type = np.uint8
                elif encoding == "rgb8":
                    channels = 3
                    storage_type = np.uint8
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")

                frame = (
                    storage.to_numpy()
                    .astype(storage_type)
                    .reshape((height, width, channels))
                )
                if encoding == "bgr8":
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                elif encoding == "rgb8":
                    pass
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")

            elif event_id == "text":
                text = event["value"][0].as_py()
                if text != "":
                    question = text
            elif event_id == "ground_truth":
                if frame is None:
                    continue
                ground_truth = event["value"][0].as_py()

                messages = [
                    {"content": "<image>" + question, "role": "user"},
                    {
                        "content": ground_truth,
                        "role": "assistant",
                    },
                ]

                save_image_and_add_to_json(
                    image_array=frame,
                    root_path=entry_name,
                    llama_root_path=llama_factory_root_path,
                    jsonl_file=default_record_json_path,
                    messages=messages,
                )
                node.send_output(
                    "text",
                    pa.array([ground_truth]),
                    metadata,
                )

        elif event_type == "ERROR":
            raise RuntimeError(event["error"])

node-hub/llama-factory-recorder/pyproject.toml (+23 -0)

@@ -0,0 +1,23 @@
[tool.poetry]
name = "llama-factory-recorder"
version = "0.3.6-rc0"
authors = [
"Haixuan Xavier Tao <tao.xavier@outlook.com>",
"Enzo Le Van <dev@enzo-le-van.fr>",
]
description = "Dora Node for VLM"
readme = "README.md"

packages = [{ include = "llama_factory_recorder" }]

[tool.poetry.dependencies]
python = "^3.7"
dora-rs = "^0.3.6"
pillow = "^10.4.0"

[tool.poetry.scripts]
llama-factory-recorder = "llama_factory_recorder.main:main"

[build-system]
requires = ["poetry-core>=1.8.0"]
build-backend = "poetry.core.masonry.api"

node-hub/llama-factory-recorder/tests/test_llama_factory_recorder.py (+9 -0)

@@ -0,0 +1,9 @@
import pytest


def test_import_main():
    from llama_factory_recorder.main import main

    # Check that everything is working, and catch the dora runtime exception as we're not running in a dora dataflow.
    with pytest.raises(RuntimeError):
        main()
