
fixed again

tags/v0.3.11-rc1
ShashwatPatil 10 months ago
commit b2d49e4253
1 changed file with 5 additions and 3 deletions

node-hub/dora-transformer/dora_transformer/main.py (+5, -3)

@@ -65,9 +65,11 @@ def load_model():

def generate_response(model, tokenizer, text: str, history, max_retries: int = 3) -> tuple[str, list]:
    """Generate text using the transformer model with safe fallback mechanisms."""
    global MAX_TOKENS # Declare global at the start of function
    retry_count = 0
    current_device = DEVICE
    original_history = history.copy() # Keep original history safe
    original_history = history.copy()
    current_max_tokens = MAX_TOKENS # Local copy for this generation attempt
    while retry_count < max_retries:
        try:
@@ -84,7 +86,7 @@ def generate_response(model, tokenizer, text: str, history, max_retries: int = 3
            with torch.inference_mode():
                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=MAX_TOKENS,
                    max_new_tokens=current_max_tokens, # Use local copy
                    pad_token_id=tokenizer.pad_token_id,
                    do_sample=True,
                    temperature=0.7,
@@ -129,7 +131,7 @@ def generate_response(model, tokenizer, text: str, history, max_retries: int = 3
                else:
                    # Final retry: Reduce token count
                    logging.info("Reducing token count for final attempt")
                    MAX_TOKENS = 24
                    current_max_tokens = max(24, current_max_tokens // 2) # Reduce tokens but keep minimum
                    continue
            else:
                # For non-CUDA OOM errors, raise immediately
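The gist of the change: the earlier code declared `global MAX_TOKENS` and set `MAX_TOKENS = 24` when it hit CUDA out-of-memory, which also lowered the token budget for later calls; this commit keeps a per-call `current_max_tokens` instead, halving it on each OOM retry down to a floor of 24 while leaving the module-level value alone. A minimal sketch of that pattern is below; `generate_once` and `generate_with_fallback` are illustrative names, not the repo's actual API.

import logging

MAX_TOKENS = 512  # module-level default; the retry loop never mutates it


def generate_with_fallback(generate_once, text: str, max_retries: int = 3) -> str:
    """Retry generation with a shrinking per-call token budget on CUDA OOM."""
    current_max_tokens = MAX_TOKENS  # local copy for this call only
    retry_count = 0
    while retry_count < max_retries:
        try:
            # generate_once is any callable taking (text, max_new_tokens)
            return generate_once(text, current_max_tokens)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise  # non-OOM errors propagate immediately, as in the diff
            retry_count += 1
            current_max_tokens = max(24, current_max_tokens // 2)  # reduce tokens but keep a minimum
            logging.info("CUDA OOM; retry %d with max_new_tokens=%d", retry_count, current_max_tokens)
    raise RuntimeError(f"generation failed after {max_retries} retries")

Because the module-level MAX_TOKENS is never reassigned, one OOM-affected request no longer shrinks the budget for every request that follows it.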

