Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
This commit is contained in:
Awni Hannun
2024-11-23 11:47:06 -08:00
committed by GitHub
parent 004eb4cc9d
commit 0f135396ae
13 changed files with 184 additions and 197 deletions

View File

@@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate
@@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler):
text = ""
tic = time.perf_counter()
for n, (segment, token, logprobs) in enumerate(
stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
temp=self.temperature,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
sampler = make_sampler(self.temperature)
logits_processors = make_logits_processors(
self.logit_bias, self.repetition_penalty, self.repetition_context_size
)
for gen_response in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=self.prompt_cache.cache,
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
segment = gen_response.text
text += segment
logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs
tokens.append(token)
if self.logprobs > 0:
@@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.prompt_cache.tokens.extend(tokens)
gen_time = time.perf_counter() - tic
prompt_tps = len(prompt) / prompt_time
gen_tps = len(tokens) / gen_time
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")
if self.stream:
response = self.generate_response(segment, finish_reason)