@@ -58,11 +58,12 @@ def detect_language(
     logits = model.logits(x, mel)[:, 0]

     # collect detected languages; suppress all non-language tokens
-    mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32)
+    mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
     mask[list(tokenizer.all_language_tokens)] = 0.0
-    logits += mx.array(mask)
+    logits += mask
     language_tokens = mx.argmax(logits, axis=-1)
     language_token_probs = mx.softmax(logits, axis=-1)
+    language_token_probs = np.array(language_token_probs)
     language_probs = [
         {
             c: language_token_probs[i, j].item()
@@ -129,17 +130,12 @@ class DecodingResult:
 class Inference:
-    def __init__(self, model: "Whisper", initial_token_length: int):
+    def __init__(self, model: "Whisper"):
         self.model: "Whisper" = model
-        self.initial_token_length = initial_token_length
         self.kv_cache = None

     def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
         """Perform a forward pass on the decoder and return per-token logits"""
-        if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
-            tokens = tokens[:, -1:]
-
         logits, self.kv_cache, _ = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
@@ -251,6 +247,11 @@ class TokenDecoder:
         raise NotImplementedError


+@mx.compile
+def categorical(logits, temp):
+    return mx.random.categorical(logits / temp)
+
+
 class GreedyDecoder(TokenDecoder):
     def __init__(self, temperature: float, eot: int):
         self.temperature = temperature
@@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder):
         if self.temperature == 0:
             next_tokens = logits.argmax(axis=-1)
         else:
-            next_tokens = mx.random.categorical(logits=logits / self.temperature)
+            next_tokens = categorical(logits, self.temperature)

-        next_tokens = mx.argmax(logits, axis=-1)
-        logits = logits.astype(mx.float32)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
         current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
@@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder):
     def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
         # make sure each sequence has at least one EOT token at the end
         tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
-        return tokens, sum_logprobs.tolist()
+        return tokens, sum_logprobs


 class LogitFilter:
@@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter):
         if self.tokenizer.no_timestamps is not None:
             mask[:, self.tokenizer.no_timestamps] = -np.inf

-        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
-        for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
-            seq = sampled_tokens.tolist()
+        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+        tokens = tokens.tolist()
+        for k in range(len(tokens)):
+            seq = tokens[k][self.sample_begin :]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             )
@@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter):
                     last_timestamp += 1
                 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf

-        if tokens.shape[1] == self.sample_begin:
+        if len(tokens[0]) == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             mask[:, : self.tokenizer.timestamp_begin] = -np.inf
@@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter):
                 mask[:, last_allowed + 1 :] = -np.inf

         # if sum of probability over timestamps is above any other token, sample timestamp
+        mask = mx.array(mask)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
-        for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
-                axis=-1
-            )
-            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
-                mask[k, : self.tokenizer.timestamp_begin] = -np.inf
-
-        return logits + mx.array(mask, logits.dtype)
+        timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
+            axis=-1, keepdims=True
+        )
+        max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
+            axis=-1, keepdims=True
+        )
+        mask[:, : self.tokenizer.timestamp_begin] = mx.where(
+            timestamp_logprob > max_text_token_logprob,
+            -mx.inf,
+            mask[:, : self.tokenizer.timestamp_begin],
+        )
+        return logits + mask


 class DecodingTask:
@@ -424,7 +427,7 @@ class DecodingTask:
         self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

         # inference: implements the forward pass through the decoder, including kv caching
-        self.inference = Inference(model, len(self.initial_tokens))
+        self.inference = Inference(model)

         # sequence ranker: implements how to rank a group of sampled sequences
         self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
@@ -432,9 +435,6 @@ class DecodingTask:
         # decoder: implements how to select the next tokens, given the autoregressive distribution
         if options.beam_size is not None:
             raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
         else:
             self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
@@ -448,6 +448,7 @@ class DecodingTask:
             self.logit_filters.append(
                 SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
             )

         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
@@ -570,35 +571,47 @@ class DecodingTask:
     def _main_loop(self, audio_features: mx.array, tokens: mx.array):
         n_batch = tokens.shape[0]
-        sum_logprobs: mx.array = mx.zeros(n_batch)
-        no_speech_probs = [np.nan] * n_batch
+        sum_logprobs = mx.zeros(n_batch)
+
+        def _step(inputs, audio_features, tokens, sum_logprobs):
+            pre_logits = self.inference.logits(inputs, audio_features)
+
+            # consider the logits at the last token only
+            logits = pre_logits[:, -1]
+
+            # apply the logit filters, e.g. for suppressing or applying penalty to
+            for logit_filter in self.logit_filters:
+                logits = logit_filter.apply(logits, tokens)
+
+            # expand the tokens tensor with the selected next tokens
+            tokens, completed, sum_logprobs = self.decoder.update(
+                tokens, logits, sum_logprobs
+            )
+            return tokens, completed, sum_logprobs, pre_logits

         try:
-            for i in range(self.sample_len):
-                logits = self.inference.logits(tokens, audio_features)
-
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
-                    probs_at_sot = mx.softmax(
-                        logits[:, self.sot_index].astype(mx.float32), axis=-1
-                    )
-                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
-
-                # now we need to consider the logits at the last token only
-                logits = logits[:, -1]
-
-                # apply the logit filters, e.g. for suppressing or applying penalty to
-                for logit_filter in self.logit_filters:
-                    logits = logit_filter.apply(logits, tokens)
-
-                # expand the tokens tensor with the selected next tokens
-                tokens, completed, sum_logprobs = self.decoder.update(
-                    tokens, logits, sum_logprobs
-                )
-
+            tokens, completed, sum_logprobs, pre_logits = _step(
+                tokens, audio_features, tokens, sum_logprobs
+            )
+            if self.tokenizer.no_speech is not None:  # compute no_speech_probs
+                probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
+                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
+            else:
+                no_speech_probs = mx.full(n_batch, mx.nan)
+            mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
+
+            for i in range(1, self.sample_len):
+                inputs = tokens[:, -1:]
+                next_tokens, next_completed, next_sum_logprobs, _ = _step(
+                    inputs, audio_features, tokens, sum_logprobs
+                )
+                mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
+                tokens = next_tokens
+                completed = next_completed
+                sum_logprobs = next_sum_logprobs
+
         finally:
             self.inference.reset()
@@ -610,8 +623,8 @@ class DecodingTask:
         n_audio: int = mel.shape[0]

         audio_features: mx.array = self._get_audio_features(mel)  # encoder forward pass
-        tokens: np.array = np.array(self.initial_tokens)
-        tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
+        tokens: mx.array = mx.array(self.initial_tokens)
+        tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))

         # detect language if requested, overwriting the language token
         languages, language_probs = self._detect_language(audio_features, tokens)
@@ -626,7 +639,6 @@ class DecodingTask:
             ]

         # repeat tokens by the group size, for beam search or best-of-n sampling
-        tokens = mx.array(tokens)
         if self.n_group > 1:
             tokens = tokens[:, None, :]
             tokens = mx.broadcast_to(
@@ -649,7 +661,13 @@ class DecodingTask:
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens = tokens[..., self.sample_begin :].tolist()
+        tokens = tokens[..., self.sample_begin :]
+
+        # eval and convert to list
+        mx.eval(tokens, sum_logprobs, no_speech_probs)
+        tokens = tokens.tolist()
+        sum_logprobs = sum_logprobs.tolist()
+        no_speech_probs = no_speech_probs.tolist()
         tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]

         # select the top-ranked sample in each group