|
|
|
# NOTE(review): this region contained raw unified-diff hunk headers
# ("@@ -M,N +M,N @@ <context line>") left over from a botched patch apply.
# The context text carried on each header has been restored as the real
# source line it originally was.

import pyarrow as pa
import torch
from dora import Node
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Local (vendored submodule) Magma implementation, used in preference to the
# remote-code classes pulled from the Hugging Face Hub.
from dora_magma.Magma.magma.modeling_magma import MagmaForCausalLM
from dora_magma.Magma.magma.processing_magma import MagmaProcessor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Path to the vendored Magma submodule. `current_dir` is defined earlier in
# this file, outside the visible chunk — presumably Path(__file__).parent;
# TODO confirm against the full file.
magma_dir = current_dir.parent / "Magma" / "magma"
|
def load_magma_models(): |
|
|
|
"""TODO: Add docstring.""" |
|
|
|
default_path = str(magma_dir.parent / "checkpoints" / "Magma-8B") |
|
|
|
if not os.path.exists(default_path): |
|
|
|
default_path = str(magma_dir.parent) |
|
|
|
if not os.path.exists(default_path): |
|
|
|
logger.warning( |
|
|
|
"Warning: Magma submodule not found, falling back to HuggingFace version", |
|
|
|
) |
|
|
|
default_path = "microsoft/Magma-8B" |
|
|
|
# default_path = str(magma_dir.parent / "checkpoints" / "Magma-8B") |
|
|
|
# if not os.path.exists(default_path): |
|
|
|
# default_path = str(magma_dir.parent) |
|
|
|
# if not os.path.exists(default_path): |
|
|
|
# logger.warning( |
|
|
|
# "Warning: Magma submodule not found, falling back to HuggingFace version", |
|
|
|
# ) |
|
|
|
default_path = "microsoft/Magma-8B" |
|
|
|
|
|
|
|
model_name_or_path = os.getenv("MODEL_NAME_OR_PATH", default_path) |
|
|
|
logger.info(f"Loading Magma model from: {model_name_or_path}") |
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
|
model = MagmaForCausalLM.from_pretrained( |
|
|
|
model_name_or_path, |
|
|
|
trust_remote_code=True, |
|
|
|
torch_dtype=torch.bfloat16, |
|
|
|
device_map="auto", |
|
|
|
) |
|
|
|
processor = AutoProcessor.from_pretrained( |
|
|
|
processor = MagmaProcessor.from_pretrained( |
|
|
|
model_name_or_path, |
|
|
|
trust_remote_code=True, |
|
|
|
torch_dtype=torch.bfloat16, |
|
|
|
|