mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-05 08:11:13 +08:00
fix (#1082)
This commit is contained in:
parent
0f799947d0
commit
29c954f4cb
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.4.0"
|
__version__ = "0.4.1"
|
||||||
|
@ -589,7 +589,6 @@ class DecodingTask:
|
|||||||
)
|
)
|
||||||
return tokens, completed, sum_logprobs, pre_logits
|
return tokens, completed, sum_logprobs, pre_logits
|
||||||
|
|
||||||
try:
|
|
||||||
tokens, completed, sum_logprobs, pre_logits = _step(
|
tokens, completed, sum_logprobs, pre_logits = _step(
|
||||||
tokens, audio_features, tokens, sum_logprobs
|
tokens, audio_features, tokens, sum_logprobs
|
||||||
)
|
)
|
||||||
@ -602,22 +601,22 @@ class DecodingTask:
|
|||||||
|
|
||||||
for i in range(1, self.sample_len):
|
for i in range(1, self.sample_len):
|
||||||
inputs = tokens[:, -1:]
|
inputs = tokens[:, -1:]
|
||||||
|
if tokens.shape[-1] > self.n_ctx:
|
||||||
|
break
|
||||||
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
next_tokens, next_completed, next_sum_logprobs, _ = _step(
|
||||||
inputs, audio_features, tokens, sum_logprobs
|
inputs, audio_features, tokens, sum_logprobs
|
||||||
)
|
)
|
||||||
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
|
||||||
if completed or tokens.shape[-1] > self.n_ctx:
|
if completed:
|
||||||
break
|
break
|
||||||
tokens = next_tokens
|
tokens = next_tokens
|
||||||
completed = next_completed
|
completed = next_completed
|
||||||
sum_logprobs = next_sum_logprobs
|
sum_logprobs = next_sum_logprobs
|
||||||
|
|
||||||
finally:
|
|
||||||
self.inference.reset()
|
|
||||||
|
|
||||||
return tokens, sum_logprobs, no_speech_probs
|
return tokens, sum_logprobs, no_speech_probs
|
||||||
|
|
||||||
def run(self, mel: mx.array) -> List[DecodingResult]:
|
def run(self, mel: mx.array) -> List[DecodingResult]:
|
||||||
|
self.inference.reset()
|
||||||
self.decoder.reset()
|
self.decoder.reset()
|
||||||
tokenizer: Tokenizer = self.tokenizer
|
tokenizer: Tokenizer = self.tokenizer
|
||||||
n_audio: int = mel.shape[0]
|
n_audio: int = mel.shape[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user