@@ -63,47 +63,84 @@ def load_model():
         logging.error(f"Error loading model: {e}")
         raise
 
-def generate_response(model, tokenizer, text: str, history) -> tuple[str, list]:
-    """Generate text using the transformer model."""
-    try:
-        history += [{"role": "user", "content": text}]
-        prompt = tokenizer.apply_chat_template(
-            history, tokenize=False, add_generation_prompt=True
-        )
-
-        model_inputs = tokenizer([prompt], return_tensors="pt").to(DEVICE)
-
-        with torch.inference_mode():
-            generated_ids = model.generate(
-                **model_inputs,
-                max_new_tokens=MAX_TOKENS,
-                pad_token_id=tokenizer.pad_token_id,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-                repetition_penalty=1.2,  # Reduce repetition
-                length_penalty=0.5,
+def generate_response(model, tokenizer, text: str, history, max_retries: int = 3) -> tuple[str, list]:
+    """Generate text using the transformer model with safe fallback mechanisms."""
+    global MAX_TOKENS  # must be declared before MAX_TOKENS is read; the final OOM retry below reduces it
+    retry_count = 0
+    current_device = DEVICE
+    original_history = history.copy()  # Keep original history safe
+
+    while retry_count < max_retries:
+        try:
+            # Reset history to original state on retries
+            history = original_history.copy()
+            history += [{"role": "user", "content": text}]
+
+            prompt = tokenizer.apply_chat_template(
+                history, tokenize=False, add_generation_prompt=True
             )
 
-        generated_ids = [
-            output_ids[len(input_ids):]
-            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-        ]
-
-        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        history += [{"role": "assistant", "content": response}]
-
-        # Clear CUDA cache after generation if enabled
-        if ENABLE_MEMORY_EFFICIENT and DEVICE == "cuda":
-            torch.cuda.empty_cache()
-
-        return response, history
-    except RuntimeError as e:
-        if "CUDA out of memory" in str(e):
-            logging.error("CUDA out of memory during generation. Falling back to CPU")
-            model.to("cpu")
-            return generate_response(model, tokenizer, text, history)
-        raise
+            model_inputs = tokenizer([prompt], return_tensors="pt").to(current_device)
+
+            with torch.inference_mode():
+                generated_ids = model.generate(
+                    **model_inputs,
+                    max_new_tokens=MAX_TOKENS,
+                    pad_token_id=tokenizer.pad_token_id,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.9,
+                    repetition_penalty=1.2,
+                    length_penalty=0.5,  # only used by beam search, so inert with pure sampling
+                )
+
+            generated_ids = [  # drop the prompt tokens from each returned sequence
+                output_ids[len(input_ids):]
+                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+            ]
+
+            response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+            history += [{"role": "assistant", "content": response}]
+
+            # Clear CUDA cache after successful generation
+            if ENABLE_MEMORY_EFFICIENT and current_device == "cuda":
+                torch.cuda.empty_cache()
+
+            return response, history
+
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                retry_count += 1
+                logging.warning(f"CUDA OOM error (attempt {retry_count}/{max_retries})")
+
+                # Clear CUDA cache
+                if current_device == "cuda":
+                    torch.cuda.empty_cache()
+
+                # Strategy for each retry
+                if retry_count == 1:
+                    # First retry: Clear cache and try again on CUDA
+                    continue
+                elif retry_count == 2:
+                    # Second retry: Move model to CPU
+                    logging.info("Moving model to CPU for fallback")
+                    current_device = "cpu"
+                    model = model.to("cpu")
+                    continue
+                else:
+                    # Final retry: Reduce token count
+                    logging.info("Reducing token count for final attempt")
+                    MAX_TOKENS = max(32, MAX_TOKENS // 2)
+                    continue
+            else:
+                # For non-CUDA OOM errors, raise immediately
+                raise
+
+    # If we've exhausted all retries
+    raise RuntimeError(
+        "Failed to generate response after multiple attempts. "
+        "Try reducing model size or using CPU inference."
+    )
 
 def main():
     # Initialize model and conversation history