mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Make attention faster for a some models (#574)
* make attention faster for a couple models * remove unused generation flags * add comment on lora * include text files as well
This commit is contained in:
parent
3f3741d229
commit
e4b19bb9e1
@ -167,6 +167,12 @@ of memory. Here are some tips to reduce memory use should you need to do so:
|
|||||||
you can do is break your examples into smaller
|
you can do is break your examples into smaller
|
||||||
sequences when making the `{train, valid, test}.jsonl` files.
|
sequences when making the `{train, valid, test}.jsonl` files.
|
||||||
|
|
||||||
|
5. Gradient checkpointing lets you trade-off memory use (less) for computation
|
||||||
|
(more) by recomputing instead of storing intermediate values needed by the
|
||||||
|
backward pass. You can use gradient checkpointing by passing the
|
||||||
|
`--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
|
||||||
|
larger batch sizes or sequence lengths with smaller or quantized models.
|
||||||
|
|
||||||
For example, for a machine with 32 GB the following should run reasonably fast:
|
For example, for a machine with 32 GB the following should run reasonably fast:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -61,19 +61,6 @@ def build_parser():
|
|||||||
"--model",
|
"--model",
|
||||||
help="The path to the local model directory or Hugging Face repo.",
|
help="The path to the local model directory or Hugging Face repo.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--max-tokens",
|
|
||||||
"-m",
|
|
||||||
type=int,
|
|
||||||
help="The maximum number of tokens to generate",
|
|
||||||
)
|
|
||||||
parser.add_argument("--temp", type=float, help="The sampling temperature")
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
|
||||||
"-p",
|
|
||||||
type=str,
|
|
||||||
help="The prompt for generation",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training args
|
# Training args
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -39,8 +39,6 @@ class MixtralAttention(nn.Module):
|
|||||||
self.num_key_value_heads = args.num_key_value_heads
|
self.num_key_value_heads = args.num_key_value_heads
|
||||||
self.rope_theta = args.rope_theta
|
self.rope_theta = args.rope_theta
|
||||||
|
|
||||||
self.repeats = self.num_heads // self.num_key_value_heads
|
|
||||||
|
|
||||||
self.scale = self.head_dim**-0.5
|
self.scale = self.head_dim**-0.5
|
||||||
|
|
||||||
self.q_proj = nn.Linear(
|
self.q_proj = nn.Linear(
|
||||||
@ -79,10 +77,6 @@ class MixtralAttention(nn.Module):
|
|||||||
0, 2, 1, 3
|
0, 2, 1, 3
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.repeats > 1:
|
|
||||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
|
||||||
values = mx.repeat(values, self.repeats, axis=1)
|
|
||||||
|
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
key_cache, value_cache = cache
|
key_cache, value_cache = cache
|
||||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||||
@ -93,11 +87,10 @@ class MixtralAttention(nn.Module):
|
|||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
|
output = mx.fast.scaled_dot_product_attention(
|
||||||
if mask is not None:
|
queries, keys, values, scale=self.scale, mask=mask
|
||||||
scores += mask
|
)
|
||||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
|
||||||
return self.o_proj(output), (keys, values)
|
return self.o_proj(output), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,46 +70,41 @@ class PhiAttention(nn.Module):
|
|||||||
|
|
||||||
# Extract some shapes
|
# Extract some shapes
|
||||||
B, L, D = queries.shape
|
B, L, D = queries.shape
|
||||||
|
n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
|
||||||
|
|
||||||
# Prepare the queries, keys and values for the attention computation
|
# Prepare the queries, keys and values for the attention computation
|
||||||
queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
|
queries = queries.reshape(
|
||||||
0, 2, 1, 3
|
B,
|
||||||
)
|
L,
|
||||||
keys = keys.reshape(B, L, self.num_key_value_heads, self.head_dim).transpose(
|
n_kv_heads,
|
||||||
0, 2, 1, 3
|
n_heads // n_kv_heads,
|
||||||
)
|
-1,
|
||||||
values = values.reshape(
|
).moveaxis(1, 3)
|
||||||
B, L, self.num_key_value_heads, self.head_dim
|
keys = keys.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
|
||||||
).transpose(0, 2, 1, 3)
|
values = values.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
|
||||||
|
|
||||||
if self.repeats > 1:
|
|
||||||
keys = mx.repeat(keys, self.repeats, axis=1)
|
|
||||||
values = mx.repeat(values, self.repeats, axis=1)
|
|
||||||
|
|
||||||
# Add RoPE to the queries and keys and combine them with the cache
|
# Add RoPE to the queries and keys and combine them with the cache
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
key_cache, value_cache = cache
|
key_cache, value_cache = cache
|
||||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
queries = self.rope(queries, offset=key_cache.shape[-2])
|
||||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
keys = self.rope(keys, offset=key_cache.shape[-2])
|
||||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
keys = mx.concatenate([key_cache, keys], axis=-2)
|
||||||
values = mx.concatenate([value_cache, values], axis=2)
|
values = mx.concatenate([value_cache, values], axis=-2)
|
||||||
else:
|
else:
|
||||||
queries = self.rope(queries)
|
queries = self.rope(queries)
|
||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
queries = queries.astype(mx.float32)
|
queries = queries.astype(mx.float32)
|
||||||
keys = keys.astype(mx.float32)
|
|
||||||
|
|
||||||
# Finally perform the attention computation
|
# Finally perform the attention computation
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
|
scores = (queries * scale) @ keys.swapaxes(-1, -2)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask
|
scores = scores + mask
|
||||||
|
|
||||||
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = (scores @ values).moveaxis(3, 1).reshape(B, L, -1)
|
||||||
|
|
||||||
return self.dense(values_hat), (keys, values)
|
return self.dense(output), (keys, values)
|
||||||
|
|
||||||
|
|
||||||
class PhiMLP(nn.Module):
|
class PhiMLP(nn.Module):
|
||||||
@ -144,11 +139,16 @@ class PhiModel(nn.Module):
|
|||||||
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||||
self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def __call__(self, x, mask, cache):
|
def __call__(self, x, cache):
|
||||||
x = self.embed_tokens(x)
|
x = self.embed_tokens(x)
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if x.shape[1] > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
for e, layer in enumerate(self.layers):
|
for e, layer in enumerate(self.layers):
|
||||||
x, cache[e] = layer(x, mask, cache[e])
|
x, cache[e] = layer(x, mask, cache[e])
|
||||||
return self.final_layernorm(x), cache
|
return self.final_layernorm(x), cache
|
||||||
@ -164,15 +164,9 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
mask: mx.array = None,
|
|
||||||
cache: mx.array = None,
|
cache: mx.array = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
mask = None
|
y, cache = self.model(x, cache)
|
||||||
if x.shape[1] > 1:
|
|
||||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
|
||||||
mask = mask.astype(x.dtype)
|
|
||||||
|
|
||||||
y, cache = self.model(x, mask, cache)
|
|
||||||
return self.lm_head(y), cache
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -68,7 +68,6 @@ class RoPEAttention(nn.Module):
|
|||||||
keys = self.rope(keys)
|
keys = self.rope(keys)
|
||||||
|
|
||||||
queries = queries.astype(mx.float32)
|
queries = queries.astype(mx.float32)
|
||||||
keys = keys.astype(mx.float32)
|
|
||||||
|
|
||||||
# Finally perform the attention computation
|
# Finally perform the attention computation
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
@ -81,6 +81,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
|
|||||||
"*.py",
|
"*.py",
|
||||||
"tokenizer.model",
|
"tokenizer.model",
|
||||||
"*.tiktoken",
|
"*.tiktoken",
|
||||||
|
"*.txt",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -396,7 +397,6 @@ def fetch_from_hub(
|
|||||||
model_path: Path, lazy: bool = False
|
model_path: Path, lazy: bool = False
|
||||||
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
|
||||||
model = load_model(model_path, lazy)
|
model = load_model(model_path, lazy)
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model_path)
|
config = AutoConfig.from_pretrained(model_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user