
Add dora-qwen2-5-vl

tags/v0.3.9-rc1
haixuantao · 11 months ago
commit 65135134c7
7 changed files with 1599 additions and 0 deletions
1. examples/vlm/qwen2-5-vl-vision-only.yml (+28, -0)
2. node-hub/dora-qwen2-5-vl/README.md (+32, -0)
3. node-hub/dora-qwen2-5-vl/dora_qwenvl/__init__.py (+11, -0)
4. node-hub/dora-qwen2-5-vl/dora_qwenvl/main.py (+220, -0)
5. node-hub/dora-qwen2-5-vl/pyproject.toml (+37, -0)
6. node-hub/dora-qwen2-5-vl/tests/test_dora_qwenvl.py (+9, -0)
7. node-hub/dora-qwen2-5-vl/uv.lock (+1262, -0)

examples/vlm/qwen2-5-vl-vision-only.yml (+28, -0)

@@ -0,0 +1,28 @@
nodes:
  - id: camera
    build: pip install opencv-video-capture
    path: opencv-video-capture
    inputs:
      tick: dora/timer/millis/100
    outputs:
      - image
    env:
      CAPTURE_PATH: 1

  - id: dora-qwenvl
    build: pip install -e ../../node-hub/dora-qwen2-5-vl
    path: dora-qwenvl
    inputs:
      image: camera/image
      text: dora/timer/millis/1000
    outputs:
      - text
    env:
      DEFAULT_QUESTION: Describe the image in three words.

  - id: plot
    build: pip install dora-rerun
    path: dora-rerun
    inputs:
      image: camera/image
      text_qwenvl: dora-qwenvl/text
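Note that the camera node reads from device index 1 (`CAPTURE_PATH: 1`). If no image appears, a quick probe like the following sketch (plain OpenCV, not part of this commit) can help pick a working index before editing the YAML:

```python
# Hypothetical helper, not part of this commit: probe which camera
# indices OpenCV can open, to choose a value for CAPTURE_PATH.
import cv2

for index in range(4):
    cap = cv2.VideoCapture(index)
    ok, _frame = cap.read()
    cap.release()
    print(f"camera index {index}: {'available' if ok else 'not available'}")
```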

node-hub/dora-qwen2-5-vl/README.md (+32, -0)

@@ -0,0 +1,32 @@
# Dora Qwen2.5-VL node

Experimental node for using a VLM within dora.

## YAML Specification

This node is intended to be used as follows:

```yaml
- id: dora-qwenvl
  build: pip install dora-qwenvl
  path: dora-qwenvl
  inputs:
    image:
      source: camera/image
      queue_size: 1
    text: dora-distil-whisper/text
  outputs:
    - text
  env:
    DEFAULT_QUESTION: Describe the image in a very short sentence.
```
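Any node that outputs a UTF-8 string array can drive the `text` input in place of `dora-distil-whisper/text`. Below is a minimal sketch, assuming the dora Python API shown elsewhere in this commit; the node id `question-sender` and its 1 s timer input are illustrative only:

```python
# Hypothetical question sender; the "tick" input wiring is an assumption.
import pyarrow as pa
from dora import Node


def main():
    node = Node()
    for event in node:
        if event["type"] == "INPUT" and event["id"] == "tick":
            # Forward a question string; dora-qwenvl reads event["value"][0].
            node.send_output("text", pa.array(["What do you see?"]), {})


if __name__ == "__main__":
    main()
```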

## Additional documentation

- Qwen-VL: https://github.com/QwenLM/Qwen-VL

## Examples

- Vision Language Model
  - Github: https://github.com/dora-rs/dora/blob/main/examples/vlm
  - Website: https://dora-rs.ai/docs/examples/vlm

node-hub/dora-qwen2-5-vl/dora_qwenvl/__init__.py (+11, -0)

@@ -0,0 +1,11 @@
import os

# Define the path to the README file relative to the package directory
readme_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "README.md")

# Read the content of the README file
try:
    with open(readme_path, encoding="utf-8") as f:
        __doc__ = f.read()
except FileNotFoundError:
    __doc__ = "README file not found."

node-hub/dora-qwen2-5-vl/dora_qwenvl/main.py (+220, -0)

@@ -0,0 +1,220 @@
import os
from pathlib import Path

import cv2
import numpy as np
import pyarrow as pa
from dora import Node
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

DEFAULT_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"

MODEL_NAME_OR_PATH = os.getenv("MODEL_NAME_OR_PATH", DEFAULT_PATH)

if bool(os.getenv("USE_MODELSCOPE_HUB") in ["True", "true"]):
    from modelscope import snapshot_download

    if not Path(MODEL_NAME_OR_PATH).exists():
        MODEL_NAME_OR_PATH = snapshot_download(MODEL_NAME_OR_PATH)

SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You're a very succinct AI assistant, that describes image with a very short sentence.",
)
DEFAULT_QUESTION = os.getenv(
    "DEFAULT_QUESTION",
    "Describe this image",
)
IMAGE_WIDTH = int(
    os.getenv(
        "IMAGE_WIDTH",
        "320",
    )
)
IMAGE_HEIGHT = int(
    os.getenv(
        "IMAGE_HEIGHT",
        "225",
    )
)
HISTORY = os.getenv("HISTORY", "False") in ["True", "true"]
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "")

# Check if flash_attn is installed
try:
    import flash_attn as _  # noqa

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME_OR_PATH,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
except (ImportError, ModuleNotFoundError):
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME_OR_PATH, torch_dtype="auto", device_map="auto"
    )


if ADAPTER_PATH != "":
    model.load_adapter(ADAPTER_PATH, "dora")


# default processor
processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH)


def generate(frames: dict, question, history):
    """Generate the response to the question given the image using the Qwen2.5-VL model."""
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                }
                for image in frames.values()
            ]
            + [
                {"type": "text", "text": question},
            ],
        },
    ]
    tmp_history = history + messages
    # Preparation for inference
    text = processor.apply_chat_template(
        tmp_history,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    inputs = inputs.to(model.device)

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    if HISTORY:
        history += [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": output_text[0]},
                ],
            },
        ]

    return output_text[0], history


def main():
    pa.array([])  # initialize pyarrow array
    node = Node()

    frames = {}
    history = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": SYSTEM_PROMPT},
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": DEFAULT_QUESTION},
            ],
        },
    ]

    for event in node:
        event_type = event["type"]

        if event_type == "INPUT":
            event_id = event["id"]

            if "image" in event_id:
                storage = event["value"]
                metadata = event["metadata"]
                encoding = metadata["encoding"]
                width = metadata["width"]
                height = metadata["height"]

                if (
                    encoding == "bgr8"
                    or encoding == "rgb8"
                    or encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]
                ):
                    channels = 3
                    storage_type = np.uint8
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")

                if encoding == "bgr8":
                    frame = (
                        storage.to_numpy()
                        .astype(storage_type)
                        .reshape((height, width, channels))
                    )
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                elif encoding == "rgb8":
                    frame = (
                        storage.to_numpy()
                        .astype(storage_type)
                        .reshape((height, width, channels))
                    )
                elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
                    storage = storage.to_numpy()
                    frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")
                image = Image.fromarray(frame)
                # PIL's resize expects (width, height)
                frames[event_id] = image.resize((IMAGE_WIDTH, IMAGE_HEIGHT))

            elif event_id == "text":
                if len(event["value"]) > 0:
                    text = event["value"][0].as_py()
                else:
                    text = ""

                if len(frames.keys()) == 0:
                    continue
                response, history = generate(frames, text, history)
                node.send_output(
                    "text",
                    pa.array([response]),
                    {},
                )

        elif event_type == "ERROR":
            print("Event Error: " + event["error"])


if __name__ == "__main__":
    main()
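For reference, the image handler above expects a flattened `uint8` buffer plus `encoding`, `width`, and `height` metadata. Here is a minimal sketch of a compatible sender, assuming the dora Python API and OpenCV; the "tick" input wiring and camera index are assumptions, not part of this commit:

```python
# Hypothetical image sender matching what dora-qwenvl's main loop expects:
# a flattened uint8 buffer with "encoding", "width" and "height" metadata.
import cv2
import pyarrow as pa
from dora import Node


def main():
    node = Node()
    cap = cv2.VideoCapture(0)  # camera index is an assumption
    for event in node:
        if event["type"] == "INPUT" and event["id"] == "tick":
            ok, frame = cap.read()  # BGR frame from OpenCV
            if not ok:
                continue
            node.send_output(
                "image",
                pa.array(frame.ravel()),
                {"encoding": "bgr8", "width": frame.shape[1], "height": frame.shape[0]},
            )


if __name__ == "__main__":
    main()
```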

node-hub/dora-qwen2-5-vl/pyproject.toml (+37, -0)

@@ -0,0 +1,37 @@
[project]
name = "dora-qwenvl"
version = "0.3.8"
authors = [
    { name = "Haixuan Xavier Tao", email = "tao.xavier@outlook.com" },
    { name = "Enzo Le Van", email = "dev@enzo-le-van.fr" },
]
description = "Dora Node for VLM"
license = { text = "MIT" }
readme = "README.md"
requires-python = ">=3.9"

dependencies = [
    "dora-rs >= 0.3.6",
    "numpy < 2.0.0",
    "torch == 2.4.0",
    "torchvision >= 0.19",
    "torchaudio >= 2.1.0",
    "qwen-vl-utils >= 0.0.5",
    "opencv-python >= 4.1.1",
    "modelscope >= 1.18.1",
    "peft == 0.13.2",
    "accelerate >= 1.3.0",
    "transformers",
]

# flash_attn = "^2.6.1" # Install using: pip install -U flash-attn --no-build-isolation


[dependency-groups]
dev = ["pytest >= 8.1.1", "ruff >= 0.9.1"]

[tool.uv.sources]
transformers = { git = "https://github.com/huggingface/transformers" }

[project.scripts]
dora-qwenvl = "dora_qwenvl.main:main"

node-hub/dora-qwen2-5-vl/tests/test_dora_qwenvl.py (+9, -0)

@@ -0,0 +1,9 @@
import pytest


def test_import_main():
    from dora_qwenvl.main import main

    # Check that the module imports correctly and catch the dora runtime
    # error raised because we're not running inside a dora dataflow.
    with pytest.raises(RuntimeError):
        main()

node-hub/dora-qwen2-5-vl/uv.lock (+1262, -0)
File diff suppressed because it is too large

