@@ -65,9 +65,11 @@ def load_model():
 def generate_response(model, tokenizer, text: str, history, max_retries: int = 3) -> tuple[str, list]:
     """Generate text using the transformer model with safe fallback mechanisms."""
-    global MAX_TOKENS  # Declare global at the start of function
     retry_count = 0
     current_device = DEVICE
-    original_history = history.copy()  # Keep original history safe
+    original_history = history.copy()
+    current_max_tokens = MAX_TOKENS  # Local copy for this generation attempt
 
     while retry_count < max_retries:
         try:
@@ -84,7 +86,7 @@ def generate_response(model, tokenizer, text: str, history, max_retries: int = 3
             with torch.inference_mode():
                 generated_ids = model.generate(
                     **model_inputs,
-                    max_new_tokens=MAX_TOKENS,
+                    max_new_tokens=current_max_tokens,  # Use local copy
                     pad_token_id=tokenizer.pad_token_id,
                     do_sample=True,
                     temperature=0.7,
@@ -129,7 +131,7 @@ def generate_response(model, tokenizer, text: str, history, max_retries: int = 3
                 else:
                     # Final retry: Reduce token count
                     logging.info("Reducing token count for final attempt")
-                    MAX_TOKENS = 24
+                    current_max_tokens = max(24, current_max_tokens // 2)  # Reduce tokens but keep minimum
                     continue
             else:
                 # For non-CUDA OOM errors, raise immediately