This commit is contained in:
Awni Hannun 2024-11-02 13:51:38 -07:00 committed by GitHub
parent 0f799947d0
commit 29c954f4cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 25 deletions

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.4.0"
__version__ = "0.4.1"

View File

@ -589,35 +589,34 @@ class DecodingTask:
)
return tokens, completed, sum_logprobs, pre_logits
try:
tokens, completed, sum_logprobs, pre_logits = _step(
tokens, audio_features, tokens, sum_logprobs
tokens, completed, sum_logprobs, pre_logits = _step(
tokens, audio_features, tokens, sum_logprobs
)
if self.tokenizer.no_speech is not None: # compute no_speech_probs
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
else:
no_speech_probs = mx.full(n_batch, mx.nan)
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
for i in range(1, self.sample_len):
inputs = tokens[:, -1:]
if tokens.shape[-1] > self.n_ctx:
break
next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs
)
if self.tokenizer.no_speech is not None: # compute no_speech_probs
probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
else:
no_speech_probs = mx.full(n_batch, mx.nan)
mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
for i in range(1, self.sample_len):
inputs = tokens[:, -1:]
next_tokens, next_completed, next_sum_logprobs, _ = _step(
inputs, audio_features, tokens, sum_logprobs
)
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
tokens = next_tokens
completed = next_completed
sum_logprobs = next_sum_logprobs
finally:
self.inference.reset()
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
if completed:
break
tokens = next_tokens
completed = next_completed
sum_logprobs = next_sum_logprobs
return tokens, sum_logprobs, no_speech_probs
def run(self, mel: mx.array) -> List[DecodingResult]:
self.inference.reset()
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]