This commit is contained in:
Awni Hannun 2024-10-14 10:47:58 -07:00
parent 1b05b51dc5
commit 5c4e6ce279

View File

@@ -462,12 +462,11 @@ class APIHandler(BaseHTTPRequestHandler):
top_tokens = [] top_tokens = []
prompt = self.get_prompt_cache(prompt) prompt = self.get_prompt_cache(prompt)
prompt = mx.array(prompt)
for _, (token, logprobs) in zip( for _, (token, logprobs) in zip(
range(self.max_tokens), range(self.max_tokens),
generate_step( generate_step(
prompt=prompt, prompt=mx.array(prompt),
model=self.model, model=self.model,
temp=self.temperature, temp=self.temperature,
top_p=self.top_p, top_p=self.top_p,
@@ -554,7 +553,6 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting stream:") logging.debug(f"Starting stream:")
prompt = self.get_prompt_cache(prompt) prompt = self.get_prompt_cache(prompt)
prompt = mx.array(prompt)
for _, (token, _) in zip( for _, (token, _) in zip(
range(self.max_tokens), range(self.max_tokens),
@@ -636,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
} }
return response return response
def handle_chat_completions(self) -> mx.array: def handle_chat_completions(self) -> List[int]:
""" """
Handle a chat completion request. Handle a chat completion request.
@@ -667,7 +665,7 @@ class APIHandler(BaseHTTPRequestHandler):
return prompt return prompt
def handle_text_completions(self) -> mx.array: def handle_text_completions(self) -> List[int]:
""" """
Handle a text completion request. Handle a text completion request.