| @@ -1,6 +1,7 @@ | |||
| """TODO: Add docstring.""" | |||
| import os | |||
| import ast | |||
| from pathlib import Path | |||
| import cv2 | |||
| import numpy as np | |||
| @@ -46,7 +47,7 @@ def load_magma_models(): | |||
| model, processor, MODEL_NAME_OR_PATH = load_magma_models() | |||
| def generate(image, task_description, template=None, num_marks=10, speed=8, steps=8): | |||
| """TODO: Add docstring.""" | |||
| """TODO: Add docstring.""" | |||
| if template is None: | |||
| template = ( | |||
| "<image>\nThe image is split into 256x256 grids and is labeled with numeric marks {}.\n" | |||
| @@ -87,10 +88,27 @@ def generate(image, task_description, template=None, num_marks=10, speed=8, step | |||
| use_cache=True, | |||
| ) | |||
| response = processor.batch_decode(output_ids, skip_special_tokens=True)[0] | |||
| return response | |||
| # Parse trajectories from response | |||
| trajectories = {} | |||
| try: | |||
| if "and their future positions are:" in response: | |||
| _, traces_str = response.split("and their future positions are:\n") | |||
| else: | |||
| _, traces_str = None, response | |||
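| # Parse the trajectories using the same approach as in `https://github.com/microsoft/Magma/blob/main/agents/robot_traj/app.py`: join the blank-line-separated entries into a single dict literal, then evaluate each entry's value with ast.literal_eval | |||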
| traces_dict = ast.literal_eval('{' + traces_str.strip().replace('\n\n',',') + '}') | |||
| for mark_id, trace in traces_dict.items(): | |||
| trajectories[mark_id] = ast.literal_eval(trace) | |||
| except Exception as e: | |||
| logger.warning(f"Failed to parse trajectories: {e}") | |||
| return response, trajectories | |||
| except Exception as e: | |||
| logger.error(f"Error in generate: {e}") | |||
| return f"Error: {e}" | |||
| return f"Error: {e}", {} | |||
| def main(): | |||
| """TODO: Add docstring.""" | |||
| @@ -145,12 +163,21 @@ def main(): | |||
| continue | |||
| image = frames[image_id] | |||
| response = generate(image, task_description) | |||
| response, trajectories = generate(image, task_description) | |||
| node.send_output( | |||
| "text", | |||
| pa.array([response]), | |||
| {"image_id": image_id} | |||
| ) | |||
| # Send the parsed trajectories as a JSON string on the "trajectories" output; skipped when the trajectories dict is empty | |||
| if trajectories: | |||
| import json | |||
| node.send_output( | |||
| "trajectories", | |||
| pa.array([json.dumps(trajectories)]), | |||
| {"image_id": image_id} | |||
| ) | |||
| else: | |||
| continue | |||