Move multiple ops from np to mlx, clean comments

This commit is contained in:
bofenghuang
2024-01-06 17:27:06 +01:00
parent 5512f8e6f0
commit 441494b11a

View File

@@ -33,10 +33,7 @@ def median_filter(x: np.ndarray, filter_width: int):
x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect")
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
# result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
# todo: more efficient version ? mlx.unfold ?
# todo: more efficient version in mlx
result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[..., pad_width: -pad_width]
if ndim <= 2:
@@ -97,6 +94,7 @@ def dtw_cpu(x: np.ndarray):
def dtw(x: np.ndarray) -> np.ndarray:
# todo: more efficient version in mlx
return dtw_cpu(x)
@@ -131,28 +129,21 @@ def find_alignment(
]
)
logits, _, cross_qk = model(mel[None, :], tokens[None, :])
sampled_logits = logits[0][len(tokenizer.sot_sequence) :, : tokenizer.eot]
# token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(sampled_logits.dtype)
# todo: np float32 ?
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1)
# todo: mlx ?
text_token_probs = np.array(token_probs)[np.arange(len(text_tokens)), text_tokens]
logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
# consider only the logits associated with predicting text
sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(sampled_logits.dtype)
text_token_probs = mx.take_along_axis(token_probs, mx.array(text_tokens)[:, None], axis=1).squeeze(1)
text_token_probs = np.array(text_token_probs)
# heads * tokens * frames
# todo: save alignment_heads as list ?
weights = mx.stack([cross_qk[_l.tolist()][0, _h.tolist()] for _l, _h in model.alignment_heads])
weights = mx.stack([cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads])
weights = weights[:, :, : num_frames // 2]
# weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1).astype(weights.dtype)
weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1)
# todo: mlx.std ?
weights = np.array(weights)
std = np.std(weights, axis=-2, ddof=0, keepdims=True)
mean = np.mean(weights, axis=-2, keepdims=True)
# std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = mx.softmax((weights * qk_scale).astype(mx.float32), axis=-1).astype(weights.dtype)
mean = mx.mean(weights, axis=-2, keepdims=True)
std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
weights = (weights - mean) / std
# todo: mlx ?
weights = median_filter(weights, medfilt_width)
weights = median_filter(np.array(weights), medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]