This commit is contained in:
Alex Barron 2024-10-31 12:22:36 -07:00
parent 1d53354b51
commit 8444ff0f6a
2 changed files with 6 additions and 6 deletions

View File

@@ -160,7 +160,7 @@ def main():
max_msg_len = max(max_msg_len, len(msg)) max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
cache = maybe_quantize_kv_cache( maybe_quantize_kv_cache(
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
) )

View File

@@ -165,10 +165,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
and not isinstance(prompt_cache[0], cache.QuantizedKVCache) and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
and prompt_cache[0].offset > quantized_kv_start and prompt_cache[0].offset > quantized_kv_start
): ):
return [ for i in range(len(prompt_cache)):
c.to_quantized(group_size=kv_group_size, bits=kv_bits) for c in prompt_cache prompt_cache[i] = prompt_cache[i].to_quantized(
] group_size=kv_group_size, bits=kv_bits
return prompt_cache )
def generate_step( def generate_step(
@@ -290,7 +290,7 @@ def generate_step(
for processor in logits_processor: for processor in logits_processor:
logits = processor(tokens, logits) logits = processor(tokens, logits)
prompt_cache = maybe_quantize_kv_cache( maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits prompt_cache, quantized_kv_start, kv_group_size, kv_bits
) )