"""TODO: Add docstring.""" |
|
|
|
|
|
|
|
|
|
|
|
import logging
import os
from collections import OrderedDict
from pathlib import Path

import cv2
import numpy as np
import pyarrow as pa
import torch
from dora import Node
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

current_dir = Path(__file__).parent.absolute()
# Local checkout of the Magma repo, expected as a sibling "Magma" submodule.
magma_dir = current_dir.parent / "Magma" / "magma"


def load_magma_models():
    """Load the Magma model and processor.

    Resolution order: local checkpoint under the Magma submodule, then the
    submodule root, then the microsoft/Magma-8B weights on HuggingFace Hub.
    The location can be overridden via the MODEL_NAME_OR_PATH environment variable.
    """
    DEFAULT_PATH = str(magma_dir.parent / "checkpoints" / "Magma-8B")
    if not os.path.exists(DEFAULT_PATH):
        DEFAULT_PATH = str(magma_dir.parent)
        if not os.path.exists(DEFAULT_PATH):
            logger.warning("Magma submodule not found, falling back to HuggingFace version")
            DEFAULT_PATH = "microsoft/Magma-8B"

    MODEL_NAME_OR_PATH = os.getenv("MODEL_NAME_OR_PATH", DEFAULT_PATH)
    logger.info(f"Loading Magma model from: {MODEL_NAME_OR_PATH}")

    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME_OR_PATH,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    return model, processor, MODEL_NAME_OR_PATH


# Loaded once at import time so the model is ready before the first event arrives.
model, processor, MODEL_NAME_OR_PATH = load_magma_models()
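
# Example checkpoint override (hypothetical path and script name):
#   MODEL_NAME_OR_PATH=/models/Magma-8B python magma_node.py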


def generate(image, task_description, template=None, num_marks=10, speed=8, steps=8):
    """Ask Magma how to move the numeric marks in `image` to accomplish `task_description`.

    Returns the decoded model response, or an error string if generation fails.
    """
    if template is None:
        template = (
            "<image>\nThe image is split into 256x256 grids and is labeled with numeric marks {}.\n"
            "The robot is doing: {}. To finish the task, how to move the numerical marks in the image "
            "with speed {} for the next {} steps?\n"
        )

    mark_ids = [i + 1 for i in range(num_marks)]
    conv_user = template.format(mark_ids, task_description, speed, steps)

    # Some Magma configs expect the image token wrapped in explicit start/end markers.
    if hasattr(model.config, "mm_use_image_start_end") and model.config.mm_use_image_start_end:
        conv_user = conv_user.replace("<image>", "<image_start><image><image_end>")
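
    # With the defaults (num_marks=10, speed=8, steps=8) and the hypothetical task
    # "pick up the block", conv_user renders as:
    #   <image>
    #   The image is split into 256x256 grids and is labeled with numeric marks
    #   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].
    #   The robot is doing: pick up the block. To finish the task, how to move the
    #   numerical marks in the image with speed 8 for the next 8 steps?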

    convs = [
        {"role": "system", "content": "You are an agent that can see, talk, and act."},
        {"role": "user", "content": conv_user},
    ]

    prompt = processor.tokenizer.apply_chat_template(
        convs,
        tokenize=False,
        add_generation_prompt=True,
    )

    try:
        inputs = processor(images=image, texts=prompt, return_tensors="pt")
        # Add the batch dimension the model expects.
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
        inputs = inputs.to(model.device)

        with torch.inference_mode():
            output_ids = model.generate(
                **inputs,
                temperature=0.3,
                do_sample=True,
                num_beams=1,
                max_new_tokens=1024,
                use_cache=True,
            )
        response = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        return response
    except Exception as e:
        logger.error(f"Error in generate: {e}")
        return f"Error: {e}"


def main():
    """Run the dora event loop: cache incoming frames and answer text queries against them."""
    node = Node()
    # OrderedDict so the oldest cached frame can be evicted via popitem(last=False).
    frames = OrderedDict()

    for event in node:
        event_type = event["type"]

        if event_type == "INPUT":
            event_id = event["id"]

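            # Image events are expected to carry encoding, width, and height metadata;
            # both raw (bgr8/rgb8) and compressed (jpeg, png, ...) payloads are handled.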
if "image" in event_id: |
|
|
|
|
|
storage = event["value"] |
|
|
|
|
|
metadata = event["metadata"] |
|
|
|
|
|
encoding = metadata["encoding"] |
|
|
|
|
|
width = metadata["width"] |
|
|
|
|
|
height = metadata["height"] |
|
|
|
|
|
|
|
|
|
|
|
                try:
                    if encoding == "bgr8":
                        frame = storage.to_numpy().astype(np.uint8).reshape((height, width, 3))
                        frame = frame[:, :, ::-1]  # Convert BGR to RGB
                    elif encoding == "rgb8":
                        frame = storage.to_numpy().astype(np.uint8).reshape((height, width, 3))
                    elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
                        storage = storage.to_numpy()
                        frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
                        if frame is None:
                            raise ValueError(f"Failed to decode image with encoding {encoding}")
                        frame = frame[:, :, ::-1]  # cv2 decodes to BGR; convert to RGB
                    else:
                        raise ValueError(f"Unsupported image encoding: {encoding}")

                    image = Image.fromarray(frame)
                    frames[event_id] = image

                    # Cleanup: evict the oldest frame once more than 10 are cached.
                    if len(frames) > 10:
                        frames.popitem(last=False)
                except Exception as e:
                    logger.error(f"Error processing image {event_id}: {e}")

            # Handle text inputs: look up the referenced frame and query the model.
            elif "text" in event_id:
                if len(event["value"]) > 0:
                    task_description = event["value"][0].as_py()
                    image_id = event["metadata"].get("image_id", None)

                    if image_id is None or image_id not in frames:
                        logger.error(f"Image ID {image_id} not found in frames")
                        continue

                    image = frames[image_id]
                    response = generate(image, task_description)
                    node.send_output(
                        "text",
                        pa.array([response]),
                        {"image_id": image_id},
                    )
                else:
                    continue

        elif event_type == "ERROR":
            logger.error(f"Event Error: {event['error']}")


if __name__ == "__main__":
    main()
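
# A minimal dora dataflow entry for wiring this node might look like the sketch
# below (node ids, input sources, and the script filename are assumptions;
# adjust them to your graph):
#
#   - id: magma
#     path: ./magma_node.py
#     inputs:
#       image: camera/image
#       text: prompt/text
#     outputs:
#       - text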