fix tests

This commit is contained in:
Alex Barron 2024-10-31 12:37:15 -07:00
parent 8444ff0f6a
commit 79075b7a21
5 changed files with 22 additions and 7 deletions

View File

@ -258,9 +258,9 @@ def main():
top_p=args.top_p, top_p=args.top_p,
max_kv_size=args.max_kv_size, max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None, prompt_cache=prompt_cache if using_cache else None,
quantized_kv_start=args.quantized_kv_start,
kv_group_size=args.kv_group_size,
kv_bits=args.kv_bits, kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
) )
if not args.verbose: if not args.verbose:
print(response) print(response)

View File

@ -94,7 +94,12 @@ class PhiAttention(nn.Module):
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention( output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype) ).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1) output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@ -72,7 +72,12 @@ class RoPEAttention(nn.Module):
scale = math.sqrt(1 / queries.shape[-1]) scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention( output = scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype) ).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1) output = output.moveaxis(2, 1).reshape(B, L, -1)

View File

@ -96,6 +96,7 @@ class Attention(nn.Module):
queries, queries,
keys, keys,
values, values,
cache=cache,
scale=self.scale, scale=self.scale,
mask=attention_mask, mask=attention_mask,
) )

View File

@ -185,9 +185,9 @@ def generate_step(
prompt_cache: Optional[Any] = None, prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
quantized_kv_start: Optional[int] = None, kv_bits: Optional[int] = None,
kv_group_size: int = 64, kv_group_size: int = 64,
kv_bits: int = 8, quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -216,6 +216,11 @@ def generate_step(
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed A list of functions that take tokens and logits and return the processed
logits. Default: ``None``. logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@ -279,7 +284,6 @@ def generate_step(
def _step(y): def _step(y):
nonlocal prompt_cache
logits = model(y[None], cache=prompt_cache) logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :] logits = logits[:, -1, :]