diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 22ca3ad7..b20243ef 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -167,6 +167,12 @@ of memory. Here are some tips to reduce memory use should you need to do so:
    you can do is break your examples into smaller sequences when making the
    `{train, valid, test}.jsonl` files.
 
+5. Gradient checkpointing lets you trade off memory use (less) for computation
+   (more) by recomputing instead of storing intermediate values needed by the
+   backward pass. You can use gradient checkpointing by passing the
+   `--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
+   larger batch sizes or sequence lengths with smaller or quantized models.
+
 For example, for a machine with 32 GB the following should run reasonably fast:
 
 ```
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 5c2d1f00..615fb417 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -61,19 +61,6 @@ def build_parser():
         "--model",
         help="The path to the local model directory or Hugging Face repo.",
     )
-    parser.add_argument(
-        "--max-tokens",
-        "-m",
-        type=int,
-        help="The maximum number of tokens to generate",
-    )
-    parser.add_argument("--temp", type=float, help="The sampling temperature")
-    parser.add_argument(
-        "--prompt",
-        "-p",
-        type=str,
-        help="The prompt for generation",
-    )
 
     # Training args
     parser.add_argument(
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index 027fcd7c..c2c9cb94 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -39,8 +39,6 @@ class MixtralAttention(nn.Module):
         self.num_key_value_heads = args.num_key_value_heads
         self.rope_theta = args.rope_theta
 
-        self.repeats = self.num_heads // self.num_key_value_heads
-
         self.scale = self.head_dim**-0.5
 
         self.q_proj = nn.Linear(
@@ -79,10 +77,6 @@ class MixtralAttention(nn.Module):
             0, 2, 1, 3
         )
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -93,11 +87,10 @@ class MixtralAttention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
 
         return self.o_proj(output), (keys, values)
 
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index d8ef54f4..3d5a659e 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -70,46 +70,41 @@ class PhiAttention(nn.Module):
 
         # Extract some shapes
         B, L, D = queries.shape
+        n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
 
         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-        keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
-            0, 2, 1, 3
-        )
-        values = values.reshape(
-            B, L, self.num_key_value_heads, self.head_dim
-        ).transpose(0, 2, 1, 3)
-
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
+        queries = queries.reshape(
+            B,
+            L,
+            n_kv_heads,
+            n_heads // n_kv_heads,
+            -1,
+        ).moveaxis(1, 3)
+        keys = keys.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
+        values = values.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
             key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=key_cache.shape[-2])
+            keys = self.rope(keys, offset=key_cache.shape[-2])
+            keys = mx.concatenate([key_cache, keys], axis=-2)
+            values = mx.concatenate([value_cache, values], axis=-2)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
 
         queries = queries.astype(mx.float32)
-        keys = keys.astype(mx.float32)
 
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+        scores = (queries * scale) @ keys.swapaxes(-1, -2)
         if mask is not None:
             scores = scores + mask
-
         scores = mx.softmax(scores, axis=-1).astype(values.dtype)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = (scores @ values).moveaxis(3, 1).reshape(B, L, -1)
 
-        return self.dense(values_hat), (keys, values)
+        return self.dense(output), (keys, values)
 
 
 class PhiMLP(nn.Module):
@@ -144,11 +139,16 @@ class PhiModel(nn.Module):
         self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
         self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-    def __call__(self, x, mask, cache):
+    def __call__(self, x, cache):
         x = self.embed_tokens(x)
         if cache is None:
             cache = [None] * len(self.layers)
 
+        mask = None
+        if x.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = mask.astype(x.dtype)
+
         for e, layer in enumerate(self.layers):
             x, cache[e] = layer(x, mask, cache[e])
         return self.final_layernorm(x), cache
@@ -164,15 +164,9 @@ class Model(nn.Module):
     def __call__(
         self,
         x: mx.array,
-        mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        mask = None
-        if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-
-        y, cache = self.model(x, mask, cache)
+        y, cache = self.model(x, cache)
         return self.lm_head(y), cache
 
     @property
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index 8537645a..0f2c8369 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -68,7 +68,6 @@ class RoPEAttention(nn.Module):
             keys = self.rope(keys)
 
         queries = queries.astype(mx.float32)
-        keys = keys.astype(mx.float32)
 
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 137c2ddd..739dafd9 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -81,6 +81,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
                     "*.py",
                     "tokenizer.model",
                     "*.tiktoken",
+                    "*.txt",
                 ],
             )
         )
@@ -396,7 +397,6 @@ def fetch_from_hub(
     model_path: Path, lazy: bool = False
 ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
     model = load_model(model_path, lazy)
-    config = AutoConfig.from_pretrained(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path)
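A quick sanity check, separate from the patch itself: the `mixtral.py` hunk swaps the hand-written attention (scaled matmul, additive mask, softmax, matmul) for `mx.fast.scaled_dot_product_attention`. The sketch below compares the two paths on random inputs; the function name, its `scale` and `mask` arguments, and the mask helper come from the diff, while the shapes, batch size, and tolerance are illustrative assumptions.

```python
# Minimal sketch, not part of the patch: compare the manual attention path
# removed from mixtral.py against the fused mx.fast.scaled_dot_product_attention
# call that replaces it. Shapes and tolerance are illustrative assumptions.
import mlx.core as mx
import mlx.nn as nn

B, n_heads, L, head_dim = 1, 8, 16, 64
scale = head_dim**-0.5

queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_heads, L, head_dim))
values = mx.random.normal((B, n_heads, L, head_dim))
mask = nn.MultiHeadAttention.create_additive_causal_mask(L)

# Manual path (what the diff removes).
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
scores = scores + mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
manual = scores @ values

# Fused path (what the diff adds).
fused = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=scale, mask=mask
)

print(mx.allclose(manual, fused, atol=1e-4))
```

The fused kernel also accepts keys and values with fewer heads than the queries, which is why the patch can drop the `self.repeats` / `mx.repeat` expansion from `mixtral.py` rather than re-adding it elsewhere.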