fix tests

Alex Barron 2024-10-31 12:37:15 -07:00
parent 8444ff0f6a
commit 79075b7a21
5 changed files with 22 additions and 7 deletions

View File

@@ -258,9 +258,9 @@ def main():
top_p=args.top_p,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
quantized_kv_start=args.quantized_kv_start,
kv_group_size=args.kv_group_size,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
)
if not args.verbose:
print(response)
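
Note: a rough sketch of how these KV-cache options might be exposed on the command line and forwarded to the generation call follows; the flag names and defaults are assumptions for illustration, not taken from this diff.

import argparse

# Hypothetical CLI wiring; flag names and defaults are assumed for illustration.
parser = argparse.ArgumentParser(description="Generation with optional KV cache quantization")
parser.add_argument("--kv-bits", type=int, default=None,
                    help="Bits per element for KV cache quantization (None disables it)")
parser.add_argument("--kv-group-size", type=int, default=64,
                    help="Quantization group size for the KV cache")
parser.add_argument("--quantized-kv-start", type=int, default=0,
                    help="Generation step at which the cache starts being quantized")
args = parser.parse_args()

# These would then be forwarded as in the call site above:
#   kv_bits=args.kv_bits, kv_group_size=args.kv_group_size,
#   quantized_kv_start=args.quantized_kv_start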

View File

@@ -94,7 +94,12 @@ class PhiAttention(nn.Module):
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@@ -72,7 +72,12 @@ class RoPEAttention(nn.Module):
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@@ -96,6 +96,7 @@ class Attention(nn.Module):
queries,
keys,
values,
cache=cache,
scale=self.scale,
mask=attention_mask,
)
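
Note: each of these attention call sites now hands the layer cache to the shared scaled_dot_product_attention helper, which lets the helper choose an attention path based on the cache it receives. A minimal sketch of such a cache-aware dispatch is below, assuming the quantized cache exposes bits and group_size and packs keys/values as (weights, scales, biases) tuples; the real helper likely attends over the quantized tensors directly rather than dequantizing up front.

import mlx.core as mx

def cache_aware_sdpa(queries, keys, values, *, cache=None, scale=1.0, mask=None):
    # Assumed interface: a quantized cache exposes `bits`/`group_size` and
    # stores keys/values as (weights, scales, biases) tuples.
    if cache is not None and getattr(cache, "bits", None) is not None:
        # Sketch only: dequantize and reuse the fused kernel.
        keys = mx.dequantize(*keys, group_size=cache.group_size, bits=cache.bits)
        values = mx.dequantize(*values, group_size=cache.group_size, bits=cache.bits)
    # Default path: MLX's fused attention kernel.
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )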

View File

@@ -185,9 +185,9 @@ def generate_step(
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
quantized_kv_start: Optional[int] = None,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
kv_bits: int = 8,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -216,6 +216,11 @@ def generate_step(
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
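
Note: for reference, a hedged usage sketch of these parameters with generate_step; the load helper, model path, and token-loop details are assumptions for illustration, not code from this commit.

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

# Model path is a placeholder; any compatible model would do.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

# Quantize the KV cache to 8 bits in groups of 64, starting at step 0.
max_tokens = 64
for n, (token, _) in enumerate(
    generate_step(prompt, model, kv_bits=8, kv_group_size=64, quantized_kv_start=0)
):
    if n >= max_tokens or token.item() == tokenizer.eos_token_id:
        break
    print(tokenizer.decode([token.item()]), end="", flush=True)
print()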
@@ -279,7 +284,6 @@ def generate_step(
def _step(y):
nonlocal prompt_cache
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]