[MLX LM] Sampler refactor + a few improvements (#1094)

* starting

* refactor sampler/processor and a few improvements

* fix stream

* fix stream generate

* fix EOS handling in stream generate
This commit is contained in:
Awni Hannun
2024-11-07 16:15:24 -08:00
committed by GitHub
parent ed9e81dd58
commit 657b4cc0aa
10 changed files with 259 additions and 239 deletions

View File

@@ -152,6 +152,7 @@ def main():
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
current = time.time()
@@ -165,14 +166,13 @@ def main():
 )
 print()
-print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
 print("Saving...")
 metadata = {}
 metadata["model"] = args.model
 metadata["chat_template"] = tokenizer.chat_template
 metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
 save_prompt_cache(args.prompt_cache_file, cache, metadata)