[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
Awni Hannun 2024-02-22 12:40:55 -08:00 committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions
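
For context: "precompiled" here means the float32 normalization math is wrapped in mx.compile with shapeless=True, so the function is traced once and reused across batch and sequence lengths instead of living inside each model as eager per-call ops. A minimal sketch of the pattern this diff applies (illustrative only; the names mirror the new layers.py below):

from functools import partial

import mlx.core as mx

# shapeless=True: the compiled graph is not specialized to one input shape,
# so a single compiled rms_norm serves every (batch, seq_len, dims).
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)

weight = mx.ones((512,))
for seq_len in (1, 7, 128):  # different shapes reuse the same compiled function
    mx.eval(rms_norm(mx.random.normal((1, seq_len, 512)), weight, 1e-5))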

View File

@@ -1,4 +1,4 @@
from .convert import convert
from .utils import generate, load
__version__ = "0.0.13"
__version__ = "0.0.14"

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
@@ -22,18 +23,21 @@ class ModelArgs(BaseModelArgs):
    rope_traditional: bool = False
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return (1.0 + weight) * x.astype(weight.dtype)
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return (1 + self.weight) * output
        return rms_norm(x, self.weight, self.eps)
class Attention(nn.Module):
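
This model keeps a local compiled rms_norm rather than importing the shared one from layers.py because it scales by (1.0 + weight) instead of weight, i.e. the learned weight is stored as an offset from one. A purely illustrative check that the two parameterizations agree (toy shapes, not from the repo):

import mlx.core as mx

x_normed = mx.random.normal((2, 8))  # stand-in for already-normalized activations
w_offset = mx.zeros((8,))            # this file stores the scale as (gamma - 1)
w_plain = 1.0 + w_offset             # what the shared rms_norm in layers.py would store
print(mx.allclose((1.0 + w_offset) * x_normed, w_plain * x_normed))  # True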

View File

@@ -0,0 +1,51 @@
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
    def __call__(self, x):
        return rms_norm(x, self.weight, self.eps)
@partial(mx.compile, shapeless=True)
def ln_norm(x, eps, weight=None, bias=None):
    t = x.dtype
    x = x.astype(mx.float32)
    means = mx.mean(x, axis=-1, keepdims=True)
    var = mx.var(x, axis=-1, keepdims=True)
    x = (x - means) * mx.rsqrt(var + eps)
    x = x.astype(t)
    return weight * x + bias if weight is not None else x
class LayerNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims
    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
    def __call__(self, x: mx.array) -> mx.array:
        if "weight" in self:
            return ln_norm(x, self.eps, self.weight, self.bias)
        else:
            return ln_norm(x, self.eps)
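
The two wrappers are drop-in replacements for the per-model classes deleted below. A minimal usage sketch; the mlx_lm.models.layers import path is inferred from the relative "from .layers import ..." lines in the model files and is otherwise an assumption:

import mlx.core as mx

from mlx_lm.models.layers import LayerNorm, RMSNorm  # path assumed, see note above

x = mx.random.normal((2, 16, 512))       # (batch, sequence, hidden)

rms = RMSNorm(512)                       # weight initialized to ones
ln = LayerNorm(512, affine=True)         # learnable weight and bias
ln_plain = LayerNorm(512, affine=False)  # no parameters: ln_norm runs without weight/bias

for y in (rms(x), ln(x), ln_plain(x)):
    print(y.shape)                       # each call preserves the input shape: (2, 16, 512)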

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -30,20 +31,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output
class MixtralAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
try:
    import hf_olmo
@@ -37,11 +38,6 @@ class ModelArgs(BaseModelArgs):
            )
class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -27,11 +28,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class PhiAttention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()

View File

@@ -1,17 +1,13 @@
import glob
import inspect
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
from .layers import LayerNorm
@dataclass
@@ -37,11 +33,6 @@ class ModelArgs:
    )
class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int, rotary_dim: int):
        super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -22,20 +23,6 @@ class ModelArgs(BaseModelArgs):
    rope_traditional: bool = False
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.variance_epsilon = eps
    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.variance_epsilon)
    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output
class Attention(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -26,20 +27,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps
    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -24,11 +25,6 @@ class ModelArgs(BaseModelArgs):
    use_qkv_bias: bool
class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()

View File

@@ -1,4 +1,4 @@
mlx>=0.1
mlx>=0.4
numpy
transformers>=4.38.0
protobuf
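
The mlx floor moves from 0.1 to 0.4, most likely so that shapeless mx.compile is available at runtime. A quick, hedged way to confirm an environment matches this release (version attributes as exposed by the packages themselves):

import mlx.core as mx
import mlx_lm

print("mlx:", mx.__version__)         # expected >= 0.4
print("mlx-lm:", mlx_lm.__version__)  # expected 0.0.14 for this commit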