mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Move multiple ops from np to mlx, clean comments
This commit is contained in:
@@ -33,10 +33,7 @@ def median_filter(x: np.ndarray, filter_width: int):
|
||||
|
||||
x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
|
||||
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
# result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
|
||||
# todo: more efficient version ? mlx.unfold ?
|
||||
# todo: more efficient version in mlx
|
||||
result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[..., pad_width: -pad_width]
|
||||
|
||||
if ndim <= 2:
|
||||
@@ -97,6 +94,7 @@ def dtw_cpu(x: np.ndarray):
|
||||
|
||||
|
||||
def dtw(x: np.ndarray) -> np.ndarray:
|
||||
# todo: more efficient version in mlx
|
||||
return dtw_cpu(x)
|
||||
|
||||
|
||||
@@ -131,28 +129,21 @@ def find_alignment(
|
||||
]
|
||||
)
|
||||
|
||||
logits, _, cross_qk = model(mel[None, :], tokens[None, :])
|
||||
sampled_logits = logits[0][len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
# token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(sampled_logits.dtype)
|
||||
# todo: np float32 ?
|
||||
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1)
|
||||
# todo: mlx ?
|
||||
text_token_probs = np.array(token_probs)[np.arange(len(text_tokens)), text_tokens]
|
||||
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
|
||||
# consider only the logits associated with predicting text
|
||||
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
|
||||
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(sampled_logits.dtype)
|
||||
text_token_probs = mx.take_along_axis(token_probs, mx.array(text_tokens)[:, None], axis=1).squeeze(1)
|
||||
text_token_probs = np.array(text_token_probs)
|
||||
|
||||
# heads * tokens * frames
|
||||
# todo: save alignment_heads as list ?
|
||||
weights = mx.stack([cross_qk[_l.tolist()][0, _h.tolist()] for _l, _h in model.alignment_heads])
|
||||
weights = mx.stack([cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
# weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1).astype(weights.dtype)
|
||||
weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1)
|
||||
# todo: mlx.std ?
|
||||
weights = np.array(weights)
|
||||
std = np.std(weights, axis=-2, ddof=0, keepdims=True)
|
||||
mean = np.mean(weights, axis=-2, keepdims=True)
|
||||
# std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1).astype(weights.dtype)
|
||||
mean = mx.mean(weights, axis=-2, keepdims=True)
|
||||
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
|
||||
weights = (weights - mean) / std
|
||||
# todo: mlx ?
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
weights = median_filter(np.array(weights), medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
|
Reference in New Issue
Block a user