
Resolved recursive fallback of dora-transformer node

tags/v0.3.11-rc1
ShashwatPatil committed 10 months ago
commit bbd8518760
2 changed files with 78 additions and 41 deletions
  1. node-hub/dora-transformer/README.md (+2, -2)
  2. node-hub/dora-transformer/dora_transformer/main.py (+76, -39)
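
This commit replaces the old recursive CPU fallback in `generate_response` (which re-entered the function on CUDA OOM with no recursion bound, and appended a fresh user turn to `history` on every attempt) with a bounded retry loop that escalates through three recovery steps: clear the CUDA cache, move the model to CPU, then shrink the token budget. As a minimal sketch of that ladder, assuming hypothetical zero-argument recovery callables in place of the node's real steps:

```python
import logging

def generate_with_fallback(generate, recoveries):
    """Call `generate()`; after each CUDA OOM, apply the next recovery step.

    `recoveries` is an ordered sequence of zero-argument callables, e.g.
    (clear_cuda_cache, move_model_to_cpu, halve_token_budget) -- hypothetical
    names, not part of the dora-transformer node itself.
    """
    for recover in recoveries:
        try:
            return generate()
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise  # non-OOM errors propagate immediately
            logging.warning("CUDA OOM, applying recovery: %s", recover.__name__)
            recover()
    return generate()  # final attempt after the last recovery step
```

Unlike the recursive version, this caps the work at `len(recoveries) + 1` attempts; the committed code additionally rebuilds the conversation state from a pristine copy before each attempt.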

node-hub/dora-transformer/README.md (+2, -2)

@@ -127,8 +127,8 @@ nodes:
### Running the Example

```bash
-dora build example.yml
-dora run example.yml
+dora build test.yml
+dora run test.yml
```

### Troubleshooting


node-hub/dora-transformer/dora_transformer/main.py (+76, -39)

@@ -63,47 +63,84 @@ def load_model():
        logging.error(f"Error loading model: {e}")
        raise

-def generate_response(model, tokenizer, text: str, history) -> tuple[str, list]:
-    """Generate text using the transformer model."""
-    try:
-        history += [{"role": "user", "content": text}]
-        prompt = tokenizer.apply_chat_template(
-            history, tokenize=False, add_generation_prompt=True
-        )
-        model_inputs = tokenizer([prompt], return_tensors="pt").to(DEVICE)
-        with torch.inference_mode():
-            generated_ids = model.generate(
-                **model_inputs,
-                max_new_tokens=MAX_TOKENS,
-                pad_token_id=tokenizer.pad_token_id,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-                repetition_penalty=1.2,  # Reduce repetition
-                length_penalty=0.5,
-            )
+def generate_response(model, tokenizer, text: str, history, max_retries: int = 3) -> tuple[str, list]:
+    """Generate text using the transformer model with safe fallback mechanisms."""
+    global MAX_TOKENS  # declared before first use so the final fallback may shrink it
+    retry_count = 0
+    current_device = DEVICE
+    original_history = history.copy()  # Keep original history safe
+    while retry_count <= max_retries:  # initial attempt plus max_retries retries
+        try:
+            # Reset history to original state on retries
+            history = original_history.copy()
+            history += [{"role": "user", "content": text}]
+            prompt = tokenizer.apply_chat_template(
+                history, tokenize=False, add_generation_prompt=True
+            )
-        generated_ids = [
-            output_ids[len(input_ids):]
-            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-        ]
-        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        history += [{"role": "assistant", "content": response}]
-        # Clear CUDA cache after generation if enabled
-        if ENABLE_MEMORY_EFFICIENT and DEVICE == "cuda":
-            torch.cuda.empty_cache()
-        return response, history
-    except RuntimeError as e:
-        if "CUDA out of memory" in str(e):
-            logging.error("CUDA out of memory during generation. Falling back to CPU")
-            model.to("cpu")
-            return generate_response(model, tokenizer, text, history)
-        raise
+            model_inputs = tokenizer([prompt], return_tensors="pt").to(current_device)
+            with torch.inference_mode():
+                generated_ids = model.generate(
+                    **model_inputs,
+                    max_new_tokens=MAX_TOKENS,
+                    pad_token_id=tokenizer.pad_token_id,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.9,
+                    repetition_penalty=1.2,
+                    length_penalty=0.5,
+                )
+            generated_ids = [
+                output_ids[len(input_ids):]
+                for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+            ]
+            response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+            history += [{"role": "assistant", "content": response}]
+            # Clear CUDA cache after successful generation
+            if ENABLE_MEMORY_EFFICIENT and current_device == "cuda":
+                torch.cuda.empty_cache()
+            return response, history
+        except RuntimeError as e:
+            if "CUDA out of memory" in str(e):
+                retry_count += 1
+                logging.warning(f"CUDA OOM error (attempt {retry_count}/{max_retries})")
+                # Clear CUDA cache before the next attempt
+                if current_device == "cuda":
+                    torch.cuda.empty_cache()
+                # Escalating strategy for each retry
+                if retry_count == 1:
+                    # First retry: clear cache and try again on CUDA
+                    continue
+                elif retry_count == 2:
+                    # Second retry: move model to CPU
+                    logging.info("Moving model to CPU for fallback")
+                    current_device = "cpu"
+                    model = model.to("cpu")
+                    continue
+                else:
+                    # Final retry: halve the token budget (MAX_TOKENS is module-global)
+                    logging.info("Reducing token count for final attempt")
+                    MAX_TOKENS = max(32, MAX_TOKENS // 2)
+                    continue
+            else:
+                # Non-OOM runtime errors are raised immediately
+                raise
+    # All retries exhausted
+    raise RuntimeError(
+        "Failed to generate response after multiple attempts. "
+        "Try reducing model size or using CPU inference."
+    )
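
For orientation, a hedged usage sketch of the updated function; it assumes `load_model()` returns the `(model, tokenizer)` pair consumed above, which this diff does not show:

```python
# Hypothetical driver loop (assumption: load_model() returns (model, tokenizer)).
model, tokenizer = load_model()
history = []  # conversation state is threaded through successive calls
for user_text in ["Hello!", "Summarize our chat so far."]:
    reply, history = generate_response(model, tokenizer, user_text, history)
    print(reply)
```

Note that recovery side effects outlive a single call: a halved `MAX_TOKENS` persists because it is module-global, and `model.to("cpu")` moves the model in place.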

def main():
    # Initialize model and conversation history

