diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 5464dd1a..ffbe7556 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -82,17 +82,21 @@ class APIHandler(BaseHTTPRequestHandler): self.created = int(time.time()) super().__init__(*args, **kwargs) - def _set_completion_headers(self, status_code: int = 200): - self.send_response(status_code) - self.send_header("Content-type", "application/json") + def _set_cors_headers(self): self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Methods", "*") self.send_header("Access-Control-Allow-Headers", "*") + def _set_completion_headers(self, status_code: int = 200): + self.send_response(status_code) + self.send_header("Content-type", "application/json") + self._set_cors_headers() + def _set_stream_headers(self, status_code: int = 200): self.send_response(status_code) self.send_header("Content-type", "text/event-stream") self.send_header("Cache-Control", "no-cache") + self._set_cors_headers() def do_OPTIONS(self): self._set_completion_headers(204)