Browse Source

Fix magma by using repo version instead of using transformers version

tags/v0.3.11-rc1
haixuanTao 10 months ago
parent
commit
3eb3abd005
2 changed files with 27 additions and 29 deletions
  1. +14
    -11
      node-hub/dora-magma/dora_magma/main.py
  2. +13
    -18
      node-hub/dora-magma/pyproject.toml

+ 14
- 11
node-hub/dora-magma/dora_magma/main.py View File

@@ -11,7 +11,10 @@ import pyarrow as pa
import torch
from dora import Node
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# from transformers import AutoModelForCausalLM, AutoProcessor
from dora_magma.Magma.magma.modeling_magma import MagmaForCausalLM
from dora_magma.Magma.magma.processing_magma import MagmaProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -22,25 +25,25 @@ magma_dir = current_dir.parent / "Magma" / "magma"

def load_magma_models():
"""TODO: Add docstring."""
default_path = str(magma_dir.parent / "checkpoints" / "Magma-8B")
if not os.path.exists(default_path):
default_path = str(magma_dir.parent)
if not os.path.exists(default_path):
logger.warning(
"Warning: Magma submodule not found, falling back to HuggingFace version",
)
default_path = "microsoft/Magma-8B"
# default_path = str(magma_dir.parent / "checkpoints" / "Magma-8B")
# if not os.path.exists(default_path):
# default_path = str(magma_dir.parent)
# if not os.path.exists(default_path):
# logger.warning(
# "Warning: Magma submodule not found, falling back to HuggingFace version",
# )
default_path = "microsoft/Magma-8B"

model_name_or_path = os.getenv("MODEL_NAME_OR_PATH", default_path)
logger.info(f"Loading Magma model from: {model_name_or_path}")

model = AutoModelForCausalLM.from_pretrained(
model = MagmaForCausalLM.from_pretrained(
model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
processor = AutoProcessor.from_pretrained(
processor = MagmaProcessor.from_pretrained(
model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,


+ 13
- 18
node-hub/dora-magma/pyproject.toml View File

@@ -7,29 +7,24 @@ name = "dora-magma"
version = "0.1.0"
description = "Dora node for Microsoft Magma model"
requires-python = ">=3.10"
license = {text = "MIT"}
license = { text = "MIT" }
readme = "README.md"
authors = [
{name = "Munish Mummadi", email = "moneymindedmunish1@gmail.com"}
]
authors = [{ name = "Munish Mummadi", email = "moneymindedmunish1@gmail.com" }]
dependencies = [
"dora-rs >= 0.3.9",
"numpy >= 2.2.3",
"torch >= 2.4.0",
"torchvision >= 0.19",
"transformers >= 4.45",
"opencv-python >= 4.1.1",
"accelerate>=1.5.1",
"psutil>=7.0.0",
"open-clip-torch>=2.31.0",
"dora-rs >= 0.3.9",
"numpy < 2",
"torch >= 2.4.0",
"torchvision >= 0.19",
"transformers >= 4.45",
"opencv-python >= 4.1.1",
"accelerate>=1.5.1",
"psutil>=7.0.0",
"open-clip-torch>=2.31.0",
"wandb",
]

[dependency-groups]
dev = [
"pytest>=8.1.1",
"ruff>=0.9.1",
"pytest-cov>=4.0.0",
]
dev = ["pytest>=8.1.1", "ruff>=0.9.1", "pytest-cov>=4.0.0"]

[project.scripts]
dora-magma = "dora_magma.main:main"


Loading…
Cancel
Save