From b2d49e4253dcb697ba8f05c231fe1aa7ec11446f Mon Sep 17 00:00:00 2001
From: ShashwatPatil
Date: Sun, 16 Mar 2025 15:43:39 +0530
Subject: [PATCH] fixed again

---
 node-hub/dora-transformer/dora_transformer/main.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/node-hub/dora-transformer/dora_transformer/main.py b/node-hub/dora-transformer/dora_transformer/main.py
index 4c29b859..db2247ee 100644
--- a/node-hub/dora-transformer/dora_transformer/main.py
+++ b/node-hub/dora-transformer/dora_transformer/main.py
@@ -65,9 +65,11 @@ def load_model():
 
 def generate_response(model, tokenizer, text: str, history, max_retries: int = 3) -> tuple[str, list]:
     """Generate text using the transformer model with safe fallback mechanisms."""
+    global MAX_TOKENS  # Declare global at the start of function
     retry_count = 0
     current_device = DEVICE
-    original_history = history.copy()  # Keep original history safe
+    original_history = history.copy()
+    current_max_tokens = MAX_TOKENS  # Local copy for this generation attempt
 
     while retry_count < max_retries:
         try:
@@ -84,7 +86,7 @@ def generate_response(model, tokenizer, text: str, history, max_retries: int = 3
             with torch.inference_mode():
                 generated_ids = model.generate(
                     **model_inputs,
-                    max_new_tokens=MAX_TOKENS,
+                    max_new_tokens=current_max_tokens,  # Use local copy
                     pad_token_id=tokenizer.pad_token_id,
                     do_sample=True,
                     temperature=0.7,
@@ -129,7 +131,7 @@ def generate_response(model, tokenizer, text: str, history, max_retries: int = 3
             else:
                 # Final retry: Reduce token count
                 logging.info("Reducing token count for final attempt")
-                MAX_TOKENS = 24
+                current_max_tokens = max(24, current_max_tokens // 2)  # Reduce tokens but keep minimum
                 continue
         else:
             # For non-CUDA OOM errors, raise immediately
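
For reference, the pattern this patch adopts can be shown in isolation: on a CUDA out-of-memory error, the token budget is halved through a per-call local copy instead of mutating the module-level MAX_TOKENS. The following is a minimal, hedged sketch only; it assumes a Hugging Face-style model and tokenizer, and the names generate_with_fallback and the MAX_TOKENS default of 128 are illustrative, not part of the dora-transformer node.

    # Sketch of the OOM-retry pattern from the patch (assumed names, not the node's API).
    import logging

    import torch

    MAX_TOKENS = 128  # module-level default; the retry loop never mutates it


    def generate_with_fallback(model, tokenizer, text: str, max_retries: int = 3) -> str:
        """Retry generation on CUDA OOM, halving a local token budget each attempt."""
        current_max_tokens = MAX_TOKENS  # local copy for this call only
        retry_count = 0

        while retry_count < max_retries:
            try:
                model_inputs = tokenizer(text, return_tensors="pt").to(model.device)
                with torch.inference_mode():
                    generated_ids = model.generate(
                        **model_inputs,
                        max_new_tokens=current_max_tokens,
                        do_sample=True,
                        temperature=0.7,
                    )
                return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            except RuntimeError as e:
                if "CUDA out of memory" not in str(e):
                    raise  # non-OOM errors propagate immediately
                retry_count += 1
                torch.cuda.empty_cache()
                # Halve the budget but keep a usable minimum, mirroring the patch.
                current_max_tokens = max(24, current_max_tokens // 2)
                logging.info("CUDA OOM, retrying with max_new_tokens=%d", current_max_tokens)

        raise RuntimeError(f"generation failed after {max_retries} retries")

Keeping the reduced budget in a local variable means a single oversized request does not permanently shrink MAX_TOKENS for every later call, which is the bug the original MAX_TOKENS = 24 assignment introduced.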