From ac17d926e3ea2ad09848dc4cf77acebd962ef429 Mon Sep 17 00:00:00 2001
From: 7SOMAY
Date: Tue, 11 Mar 2025 17:44:34 +0530
Subject: [PATCH] Added device map for auto mapping - cpu & gpu

---
 node-hub/dora-phi4/dora_phi4/main.py | 59 +++++++++++++++++-----------
 node-hub/dora-phi4/pyproject.toml    |  3 +-
 2 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/node-hub/dora-phi4/dora_phi4/main.py b/node-hub/dora-phi4/dora_phi4/main.py
index 00f87e25..0944a08a 100644
--- a/node-hub/dora-phi4/dora_phi4/main.py
+++ b/node-hub/dora-phi4/dora_phi4/main.py
@@ -5,6 +5,7 @@ import pyarrow as pa
 import requests
 import soundfile as sf
 import torch
+from accelerate import infer_auto_device_map
 from dora import Node
 from PIL import Image
 from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
@@ -25,25 +26,37 @@ else:
 MODEL_PATH = "microsoft/Phi-4-multimodal-instruct"
 
 processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
-
-try:
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=torch_dtype,
-        trust_remote_code=True,
-        _attn_implementation="eager",
-        low_cpu_mem_usage=True,  # Reduce memory usage
-    ).to(device)
-except RuntimeError:
-    print(f"⚠️ {device.upper()} ran out of memory! Switching to CPU.")
-    device = "cpu"
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=torch.float32,  # Use float32 for CPU
-        trust_remote_code=True,
-        _attn_implementation="eager",
-        low_cpu_mem_usage=True,
-    ).to("cpu")
+# bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,
+    # quantization_config=bnb_config,
+    torch_dtype=torch.float16
+    if device == "cuda"
+    else torch.bfloat16,  # Use bfloat16 for CPU
+    trust_remote_code=True,
+    _attn_implementation="flash_attention_2"
+    if device == "cuda" and torch.cuda.get_device_properties(0).total_memory > 16e9
+    else "eager",
+    low_cpu_mem_usage=True,
+)
+
+# Infer and apply the device map before moving model
+device_map = infer_auto_device_map(model)
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,
+    # quantization_config=bnb_config,
+    torch_dtype=torch.float16
+    if device == "cuda"
+    else torch.bfloat16,  # Use bfloat16 for CPU
+    trust_remote_code=True,
+    _attn_implementation="flash_attention_2"
+    if device == "cuda" and torch.cuda.get_device_properties(0).total_memory > 16e9
+    else "eager",
+    low_cpu_mem_usage=True,
+    device_map=device_map,
+)
 
 generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
 
@@ -61,12 +74,12 @@ def process_image(image_url):
     image = Image.open(requests.get(image_url, stream=True).raw)
 
     # Process input
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
 
     # Generate response
     with torch.no_grad():
         generate_ids = model.generate(
-            **inputs, max_new_tokens=1000, generation_config=generation_config
+            **inputs, max_new_tokens=512, generation_config=generation_config
         )
 
     generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
@@ -87,12 +100,12 @@ def process_audio(audio_url):
     # Process input
     inputs = processor(
         text=prompt, audios=[(audio, samplerate)], return_tensors="pt"
-    ).to(device)
+    ).to(model.device)
 
     # Generate response
     with torch.no_grad():
         generate_ids = model.generate(
-            **inputs, max_new_tokens=1000, generation_config=generation_config
+            **inputs, max_new_tokens=512, generation_config=generation_config
         )
 
     generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
diff --git a/node-hub/dora-phi4/pyproject.toml b/node-hub/dora-phi4/pyproject.toml
index c84bbc63..b1c05cfb 100644
--- a/node-hub/dora-phi4/pyproject.toml
+++ b/node-hub/dora-phi4/pyproject.toml
@@ -8,7 +8,7 @@ readme = "README.md"
 requires-python = ">=3.10"
 
 dependencies = [
-    "dora-rs >=0.3.9",
+    "dora-rs>=0.3.9",
     "torch==2.6.0",
     "torchvision==0.21.0",
    "transformers==4.48.2",
@@ -18,6 +18,7 @@ dependencies = [
     "scipy==1.15.2",
     "backoff==2.2.1",
     "peft==0.13.2",
+    "bitsandbytes>=0.42.0",
     "requests"
 ]
 
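Note: as written, the patch materializes the full checkpoint twice, once to
infer the device map and once more to reload with the map applied. A common
way to avoid the double load is accelerate's meta-device pattern, sketched
below. This is a minimal illustration rather than part of the patch; it
assumes the same MODEL_PATH and the float16 dtype used in main.py. (The new
bitsandbytes dependency backs the commented-out bnb_config; enabling it would
mean uncommenting quantization_config=bnb_config in the load call.)

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_PATH = "microsoft/Phi-4-multimodal-instruct"

# Build the module tree on the "meta" device: no weight memory is allocated.
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# Infer placement from the empty skeleton instead of a fully loaded model.
device_map = infer_auto_device_map(empty_model, dtype=torch.float16)

# Single real load; weights land directly on the devices in the map.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map=device_map,
)

Passing device_map="auto" to from_pretrained performs essentially this
inference internally, so the explicit infer_auto_device_map step can often be
dropped entirely.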