import inspect
from dataclasses import dataclass

import mlx.core as mx


def create_additive_causal_mask(N: int, offset: int = 0):
    """Build an additive causal attention mask.

    Rows correspond to positions ``offset .. offset + N - 1`` and columns to
    positions ``0 .. offset + N - 1``. An entry is ``-1e9`` where the column
    position is strictly greater than the row position and ``0`` elsewhere.

    Args:
        N: Number of row positions.
        offset: Number of earlier positions that precede the rows.

    Returns:
        An array of shape ``(N, offset + N)``.
    """
    rinds = mx.arange(offset + N)
    # With no offset the row and column positions coincide, so reuse the array.
    linds = mx.arange(offset, offset + N) if offset else rinds
    mask = linds[:, None] < rinds[None]
    return mask * -1e9


class KVCache:
    """Growable key/value buffer that is extended in blocks of ``step`` slots.

    The buffers are 4-D arrays laid out as
    ``(batch, n_kv_heads, sequence, head_dim)``; ``offset`` is the number of
    valid entries currently stored along the sequence axis (axis 2).
    """

    def __init__(self, head_dim, n_kv_heads):
        """Create an empty cache.

        Args:
            head_dim: Either a single int used for both keys and values, or a
                ``(k_head_dim, v_head_dim)`` tuple.
            n_kv_heads: Size of axis 1 of the allocated buffers.

        Raises:
            ValueError: If ``head_dim`` is neither an int nor a 2-tuple.
        """
        self.n_kv_heads = n_kv_heads
        if isinstance(head_dim, int):
            self.k_head_dim = self.v_head_dim = head_dim
        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
            self.k_head_dim, self.v_head_dim = head_dim
        else:
            raise ValueError("head_dim must be an int or a tuple of two ints")
        # Buffers are allocated lazily on the first update.
        self.keys = None
        self.values = None
        self.offset = 0
        # Allocation granularity along the sequence axis.
        self.step = 256

    def update_and_fetch(self, keys, values):
        """Append ``keys``/``values`` to the cache and return the valid prefix.

        Args:
            keys: Array whose axis 0 is the batch and axis 2 the new positions.
            values: Array with the same leading three axes as ``keys``.

        Returns:
            A ``(keys, values)`` tuple holding every entry stored so far, i.e.
            the buffers sliced to ``offset`` along the sequence axis.
        """
        prev = self.offset
        n_new = keys.shape[2]
        if self.keys is None or (prev + n_new) > self.keys.shape[2]:
            # Allocate enough whole steps to hold the incoming entries.
            n_steps = (self.step + n_new - 1) // self.step
            # Size the batch axis from the input rather than assuming 1, so
            # batched inputs can be written into the buffer.
            batch = keys.shape[0]
            k_shape = (batch, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
            v_shape = (batch, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    # Drop the unused tail beyond ``prev`` before extending.
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += n_new
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]


@dataclass
class BaseModelArgs:
    """Base dataclass providing a tolerant ``from_dict`` constructor."""

    @classmethod
    def from_dict(cls, params):
        """Instantiate ``cls`` from ``params``, ignoring unknown keys.

        Only the entries of ``params`` whose key names a parameter of the
        class constructor are forwarded; every other entry is dropped.

        Args:
            params: Mapping of argument names to values.

        Returns:
            A new instance of ``cls``.
        """
        # Resolve the constructor signature once instead of once per key.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})