mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
fix tests
This commit is contained in:
parent
8444ff0f6a
commit
79075b7a21
@ -258,9 +258,9 @@ def main():
|
|||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
quantized_kv_start=args.quantized_kv_start,
|
|
||||||
kv_group_size=args.kv_group_size,
|
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
|
kv_group_size=args.kv_group_size,
|
||||||
|
quantized_kv_start=args.quantized_kv_start,
|
||||||
)
|
)
|
||||||
if not args.verbose:
|
if not args.verbose:
|
||||||
print(response)
|
print(response)
|
||||||
|
@ -94,7 +94,12 @@ class PhiAttention(nn.Module):
|
|||||||
|
|
||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
output = scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
queries.astype(mx.float32),
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
cache=cache,
|
||||||
|
scale=scale,
|
||||||
|
mask=mask,
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
|
|
||||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||||
|
@ -72,7 +72,12 @@ class RoPEAttention(nn.Module):
|
|||||||
scale = math.sqrt(1 / queries.shape[-1])
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
|
||||||
output = scaled_dot_product_attention(
|
output = scaled_dot_product_attention(
|
||||||
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
|
queries.astype(mx.float32),
|
||||||
|
keys,
|
||||||
|
values,
|
||||||
|
cache=cache,
|
||||||
|
scale=scale,
|
||||||
|
mask=mask,
|
||||||
).astype(values.dtype)
|
).astype(values.dtype)
|
||||||
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
output = output.moveaxis(2, 1).reshape(B, L, -1)
|
||||||
|
|
||||||
|
@ -96,6 +96,7 @@ class Attention(nn.Module):
|
|||||||
queries,
|
queries,
|
||||||
keys,
|
keys,
|
||||||
values,
|
values,
|
||||||
|
cache=cache,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
mask=attention_mask,
|
mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
@ -185,9 +185,9 @@ def generate_step(
|
|||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
logit_bias: Optional[Dict[int, float]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
quantized_kv_start: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
kv_bits: int = 8,
|
quantized_kv_start: int = 0,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
@ -216,6 +216,11 @@ def generate_step(
|
|||||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
A list of functions that take tokens and logits and return the processed
|
A list of functions that take tokens and logits and return the processed
|
||||||
logits. Default: ``None``.
|
logits. Default: ``None``.
|
||||||
|
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||||
|
None implies no cache quantization. Default: ``None``.
|
||||||
|
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||||
|
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||||
|
when ``kv_bits`` is non-None. Default: ``0``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
@ -279,7 +284,6 @@ def generate_step(
|
|||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
|
|
||||||
nonlocal prompt_cache
|
|
||||||
logits = model(y[None], cache=prompt_cache)
|
logits = model(y[None], cache=prompt_cache)
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user