diff --git a/nodes/audio.py b/nodes/audio.py index cb5c637..6fb7bf5 100644 --- a/nodes/audio.py +++ b/nodes/audio.py @@ -277,14 +277,14 @@ def transcribe( f"Processing chunk {chunk_offset:.1f}s - {chunk_end / sample_rate:.1f}s" ) - max_length = model.config.max_length or 448 + max_length = getattr(model.config, "max_length", None) or 448 attention_mask = torch.ones((1, max_length)) input_features = processor( chunk_waveform, sampling_rate=sample_rate, return_tensors="pt", - ).input_features.to(device) + ).input_features.to(device=device, dtype=model.dtype) with torch.no_grad(): predicted_ids = model.generate(