mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fix tests
This commit is contained in:
parent
8444ff0f6a
commit
79075b7a21
@ -258,9 +258,9 @@ def main():
|
||||
top_p=args.top_p,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
kv_group_size=args.kv_group_size,
|
||||
kv_bits=args.kv_bits,
|
||||
kv_group_size=args.kv_group_size,
|
||||
quantized_kv_start=args.quantized_kv_start,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
@ -94,7 +94,12 @@ class PhiAttention(nn.Module):
|
||||
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
output = scaled_dot_product_attention(
|
||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
||||
queries.astype(mx.float32),
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
).astype(values.dtype)
|
||||
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
@ -72,7 +72,12 @@ class RoPEAttention(nn.Module):
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
||||
queries.astype(mx.float32),
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=scale,
|
||||
mask=mask,
|
||||
).astype(values.dtype)
|
||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||
|
||||
|
@ -96,6 +96,7 @@ class Attention(nn.Module):
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self.scale,
|
||||
mask=attention_mask,
|
||||
)
|
||||
|
@ -185,9 +185,9 @@ def generate_step(
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
quantized_kv_start: Optional[int] = None,
|
||||
kv_bits: Optional[int] = None,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int = 8,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
@ -216,6 +216,11 @@ def generate_step(
|
||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
|
||||
Yields:
|
||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||
@ -279,7 +284,6 @@ def generate_step(
|
||||
|
||||
def _step(y):
|
||||
|
||||
nonlocal prompt_cache
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user