Formatted code

This commit is contained in:
Vincent Amato
2025-08-15 23:54:53 -04:00
parent eccabdd227
commit 98d800866e
10 changed files with 200 additions and 127 deletions

View File

@@ -58,7 +58,9 @@ class RotaryEmbedding(nn.Module):
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x: mx.array, seq_dimension: int = 1) -> Tuple[mx.array, mx.array]:
def _update_cos_sin_tables(
self, x: mx.array, seq_dimension: int = 1
) -> Tuple[mx.array, mx.array]:
"""
Compute and cache cos/sin tables for the given sequence length.
@@ -109,4 +111,4 @@ class RotaryEmbedding(nn.Module):
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)
)