|
|
|
@@ -12,6 +12,7 @@ DEFAULT_QUESTION = os.getenv( |
|
|
|
"DEFAULT_QUESTION", |
|
|
|
"Describe this image", |
|
|
|
) |
|
|
|
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "") |
|
|
|
|
|
|
|
# Check if flash_attn is installed |
|
|
|
try: |
|
|
|
@@ -23,7 +24,7 @@ try: |
|
|
|
device_map="auto", |
|
|
|
attn_implementation="flash_attention_2", |
|
|
|
) |
|
|
|
except ImportError: |
|
|
|
except (ImportError, ModuleNotFoundError): |
|
|
|
model = Qwen2VLForConditionalGeneration.from_pretrained( |
|
|
|
CUSTOM_MODEL_PATH, |
|
|
|
torch_dtype="auto", |
|
|
|
@@ -31,8 +32,12 @@ except ImportError: |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if ADAPTER_PATH != "": |
|
|
|
model.load_adapter(ADAPTER_PATH, "dora") |
|
|
|
|
|
|
|
|
|
|
|
# default processor |
|
|
|
processor = AutoProcessor.from_pretrained(DEFAULT_PATH) |
|
|
|
processor = AutoProcessor.from_pretrained(CUSTOM_MODEL_PATH) |
|
|
|
|
|
|
|
|
|
|
|
def generate(frames: dict, question): |
|
|
|
|