Fix server for openai package (#877)

* fix

* fixes for 9b
Awni Hannun 2024-07-08 12:34:31 -07:00 committed by GitHub
parent 20e221f7f7
commit 68e88d42fb
2 changed files with 27 additions and 13 deletions


@@ -14,7 +14,6 @@ class ModelArgs(BaseModelArgs):
     hidden_size: int
     attention_bias: bool
     conv1d_width: int
-    embeddings_scale_by_sqrt_dim: bool
     hidden_size: int
     intermediate_size: int
     logits_soft_cap: float
@@ -25,7 +24,14 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float
     attention_window_size: int
     vocab_size: int
-    _block_types: List[str]
+    embeddings_scale_by_sqrt_dim: bool = True
+    block_types: Optional[List[str]] = None
+    _block_types: Optional[List[str]] = None
+
+    def __post_init__(self):
+        # For some reason these have different names in 2B and 9B
+        if self.block_types is None:
+            self.block_types = self._block_types


 def create_window_causal_mask(N: int, window_size: int):
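As a standalone illustration of the __post_init__ normalization added above (one checkpoint's config stores the layer pattern under "_block_types", the other under "block_types", and both end up on block_types): the mini dataclass below is not part of the diff, and the BlockTypesArgs name and the example block lists are invented for the sketch.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BlockTypesArgs:
    # Same two fields and the same normalization as ModelArgs above.
    block_types: Optional[List[str]] = None
    _block_types: Optional[List[str]] = None

    def __post_init__(self):
        # Fall back to the underscore-prefixed key when the other is absent.
        if self.block_types is None:
            self.block_types = self._block_types


args_from_underscore_key = BlockTypesArgs(_block_types=["recurrent", "recurrent", "attention"])
args_from_plain_key = BlockTypesArgs(block_types=["recurrent", "recurrent", "attention"])
assert args_from_underscore_key.block_types == args_from_plain_key.block_types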
@@ -202,6 +208,8 @@ class RGLRU(nn.Module):
         # Apply gamma normalization to the input.
         multiplier = mx.sqrt(1 - a_square)
+        if cache is None:
+            multiplier[:, 0, :] = 1.0
         normalized_x = gated_x * multiplier.astype(x.dtype)

         y, last_h = rnn_scan(
@@ -404,8 +412,8 @@ class ResidualBlock(nn.Module):
         raw_x = x
         inputs_normalized = self.temporal_pre_norm(raw_x)
-        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
+        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
         residual = x + raw_x
         x = self.channel_pre_norm(residual)
@@ -427,7 +435,7 @@ class Griffin(nn.Module):
         )
         self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
-        block_types = config._block_types
+        block_types = config.block_types

         self.layers = [
             ResidualBlock(
@@ -461,14 +469,7 @@ class Griffin(nn.Module):
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
-        x = self.final_norm(x)
-        logits = self.embed_tokens.as_linear(x)
-        c = self.config.logits_soft_cap
-        if c:
-            logits = mx.tanh(logits / c) * c
-        return logits
+        return self.final_norm(x)


 class Model(nn.Module):
@@ -476,13 +477,23 @@ class Model(nn.Module):
     def __init__(self, config):
         self.args = config
         self.model = Griffin(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

     def __call__(self, tokens: mx.array, cache=None) -> mx.array:
         """
         Args:
             tokens: Sequence of input tokens.
         """
-        return self.model(tokens, cache=cache)
+        logits = self.model(tokens, cache=cache)
+        if "lm_head" in self:
+            logits = self.lm_head(logits)
+        else:
+            logits = self.model.embed_tokens.as_linear(logits)
+
+        c = self.args.logits_soft_cap
+        if c:
+            logits = mx.tanh(logits / c) * c
+        return logits

     @property
     def layers(self):
@@ -493,6 +504,8 @@ class Model(nn.Module):
         for k, v in weights.items():
             if "conv_1d.weight" in k and v.ndim == 3:
                 weights[k] = v.squeeze(1).T
+        if "lm_head.weight" not in weights:
+            self.pop("lm_head")
         return weights

     def make_cache(self):
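The logit soft cap that moved from Griffin into Model.__call__ above can be illustrated on its own: tanh(logits / c) * c squashes logits smoothly into (-c, c), leaving small values almost untouched and saturating extreme ones. The snippet below is a standalone sketch (it only assumes mlx is installed; c = 30.0 and the logit values are made up for the example).

import mlx.core as mx

# Illustration of the soft cap applied in Model.__call__ above:
# values are squashed into (-c, c); logits well below c pass through
# almost unchanged, extreme logits saturate near +/-c.
c = 30.0
logits = mx.array([-100.0, -10.0, 0.0, 10.0, 100.0])
capped = mx.tanh(logits / c) * c
print(capped)  # roughly [-29.9, -9.6, 0.0, 9.6, 29.9]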


@@ -173,6 +173,7 @@ class APIHandler(BaseHTTPRequestHandler):
         endpoints = {
             "/v1/completions": self.handle_text_completions,
             "/v1/chat/completions": self.handle_chat_completions,
+            "/chat/completions": self.handle_chat_completions,
         }

         if self.path not in endpoints:
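With the extra "/chat/completions" route, the server answers the openai Python client whether or not the configured base_url carries the "/v1" prefix, since the client simply appends "/chat/completions" to whatever base it is given. A minimal sketch, assuming the server is running locally on port 8080 (adjust to match how it was launched) and openai>=1.0 is installed; the model name and api_key value are placeholders.

from openai import OpenAI

# The client appends "/chat/completions" to base_url, so both
# "http://localhost:8080" and "http://localhost:8080/v1" now resolve.
client = OpenAI(base_url="http://localhost:8080", api_key="not-needed")

response = client.chat.completions.create(
    model="local-model",  # placeholder; use whatever the server was started with
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
)
print(response.choices[0].message.content)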