Browse Source

feat: update generate func to parse trajectories using the same approach as app.py, and improve input and output handling

tags/v0.3.11-rc1
Munish Mummadi 10 months ago
parent
commit
b6d4ea2152
1 changed file with 31 additions and 4 deletions
  1. +31
    -4
      node-hub/dora-magma/dora_magma/main.py

+ 31
- 4
node-hub/dora-magma/dora_magma/main.py View File

@@ -1,6 +1,7 @@
"""TODO: Add docstring."""

import os
import ast
from pathlib import Path
import cv2
import numpy as np
@@ -46,7 +47,7 @@ def load_magma_models():
model, processor, MODEL_NAME_OR_PATH = load_magma_models()

def generate(image, task_description, template=None, num_marks=10, speed=8, steps=8):
"""TODO: Add docstring."""
"""TODO: Add docstring."""
if template is None:
template = (
"<image>\nThe image is split into 256x256 grids and is labeled with numeric marks {}.\n"
@@ -87,10 +88,27 @@ def generate(image, task_description, template=None, num_marks=10, speed=8, step
use_cache=True,
)
response = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
return response
# Parse trajectories from response
trajectories = {}
try:
if "and their future positions are:" in response:
_, traces_str = response.split("and their future positions are:\n")
else:
_, traces_str = None, response
# Parse the trajectories using the same approach as in `https://github.com/microsoft/Magma/blob/main/agents/robot_traj/app.py`
traces_dict = ast.literal_eval('{' + traces_str.strip().replace('\n\n',',') + '}')
for mark_id, trace in traces_dict.items():
trajectories[mark_id] = ast.literal_eval(trace)
except Exception as e:
logger.warning(f"Failed to parse trajectories: {e}")
return response, trajectories
except Exception as e:
logger.error(f"Error in generate: {e}")
return f"Error: {e}"
return f"Error: {e}", {}

def main():
"""TODO: Add docstring."""
@@ -145,12 +163,21 @@ def main():
continue
image = frames[image_id]
response = generate(image, task_description)
response, trajectories = generate(image, task_description)
node.send_output(
"text",
pa.array([response]),
{"image_id": image_id}
)
# Send trajectory data if available
if trajectories:
import json
node.send_output(
"trajectories",
pa.array([json.dumps(trajectories)]),
{"image_id": image_id}
)
else:
continue


Loading…
Cancel
Save