This commit is contained in:
Awni Hannun 2024-10-14 10:47:58 -07:00
parent 1b05b51dc5
commit 5c4e6ce279

View File

@@ -462,12 +462,11 @@ class APIHandler(BaseHTTPRequestHandler):
 top_tokens = []
 prompt = self.get_prompt_cache(prompt)
-prompt = mx.array(prompt)
 for _, (token, logprobs) in zip(
 range(self.max_tokens),
 generate_step(
-prompt=prompt,
+prompt=mx.array(prompt),
 model=self.model,
 temp=self.temperature,
 top_p=self.top_p,
@@ -554,7 +553,6 @@ class APIHandler(BaseHTTPRequestHandler):
 logging.debug(f"Starting stream:")
 prompt = self.get_prompt_cache(prompt)
-prompt = mx.array(prompt)
 for _, (token, _) in zip(
 range(self.max_tokens),
@@ -636,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
 }
 return response
-def handle_chat_completions(self) -> mx.array:
+def handle_chat_completions(self) -> List[int]:
 """
 Handle a chat completion request.
@@ -667,7 +665,7 @@ class APIHandler(BaseHTTPRequestHandler):
 return prompt
-def handle_text_completions(self) -> mx.array:
+def handle_text_completions(self) -> List[int]:
 """
 Handle a text completion request.