
main.py 5.6 kB

  1. """TODO: Add docstring."""
  2. import os
  3. from pathlib import Path
  4. import cv2
  5. import numpy as np
  6. import pyarrow as pa
  7. import torch
  8. from dora import Node
  9. from PIL import Image
  10. from qwen_vl_utils import process_vision_info
  11. from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
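
# Configuration is read from environment variables: MODEL_NAME_OR_PATH selects the checkpoint,
# USE_MODELSCOPE_HUB switches downloads to ModelScope, DEFAULT_QUESTION seeds the prompt, and
# ADAPTER_PATH optionally names a fine-tuned adapter to load on top of the base model.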
DEFAULT_PATH = "Qwen/Qwen2-VL-2B-Instruct"
MODEL_NAME_OR_PATH = os.getenv("MODEL_NAME_OR_PATH", DEFAULT_PATH)

if os.getenv("USE_MODELSCOPE_HUB") in ["True", "true"]:
    from modelscope import snapshot_download

    if not Path(MODEL_NAME_OR_PATH).exists():
        MODEL_NAME_OR_PATH = snapshot_download(MODEL_NAME_OR_PATH)

DEFAULT_QUESTION = os.getenv(
    "DEFAULT_QUESTION",
    "Describe this image",
)
ADAPTER_PATH = os.getenv("ADAPTER_PATH", "")

# Check if flash_attn is installed
try:
    import flash_attn as _  # noqa

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME_OR_PATH,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
except (ImportError, ModuleNotFoundError):
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME_OR_PATH,
        torch_dtype="auto",
        device_map="auto",
    )

if ADAPTER_PATH != "":
    model.load_adapter(ADAPTER_PATH, "dora")

# default processor
processor = AutoProcessor.from_pretrained(MODEL_NAME_OR_PATH)
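

# generate() builds a single-turn chat message containing every buffered frame plus the
# question, runs one generation pass of the Qwen2-VL model, and returns the decoded answer.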
def generate(frames: dict, question):
    """Generate the response to the question given the image using Qwen2 model."""
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                }
                for image in frames.values()
            ]
            + [
                {"type": "text", "text": question},
            ],
        },
    ]
    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    inputs = inputs.to(device)
    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]
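

# Main event loop of the dora node: buffer the latest frame per image input, and answer
# with the current question whenever a `tick` arrives or a new `text` prompt is received.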
def main():
    """Run the dora event loop and answer questions about the buffered frames."""
    pa.array([])  # initialize pyarrow array
    node = Node()
    question = DEFAULT_QUESTION
    frames = {}
    for event in node:
        event_type = event["type"]
        if event_type == "INPUT":
            event_id = event["id"]
            if "image" in event_id:
                storage = event["value"]
                metadata = event["metadata"]
                encoding = metadata["encoding"]
                width = metadata["width"]
                height = metadata["height"]
                if encoding in ["bgr8", "rgb8", "jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
                    channels = 3
                    storage_type = np.uint8
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")
                if encoding == "bgr8":
                    frame = (
                        storage.to_numpy()
                        .astype(storage_type)
                        .reshape((height, width, channels))
                    )
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                elif encoding == "rgb8":
                    frame = (
                        storage.to_numpy()
                        .astype(storage_type)
                        .reshape((height, width, channels))
                    )
                elif encoding in ["jpeg", "jpg", "jpe", "bmp", "webp", "png"]:
                    storage = storage.to_numpy()
                    frame = cv2.imdecode(storage, cv2.IMREAD_COLOR)
                    frame = frame[:, :, ::-1]  # OpenCV image (BGR to RGB)
                else:
                    raise RuntimeError(f"Unsupported image encoding: {encoding}")
                frames[event_id] = Image.fromarray(frame)
            elif event_id == "tick":
                if len(frames) == 0:
                    continue
                response = generate(frames, question)
                node.send_output(
                    "tick",
                    pa.array([response]),
                    {},
                )
            elif event_id == "text":
                text = event["value"][0].as_py()
                if text != "":
                    question = text
                if len(frames) == 0:
                    continue
                response = generate(frames, question)
                node.send_output(
                    "text",
                    pa.array([response]),
                    {},
                )
        elif event_type == "ERROR":
            print("Event Error: " + event["error"])


if __name__ == "__main__":
    main()