mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 01:41:19 +08:00
[mlx-lm] Add precompiled normalizations (#451)
* add precompiled normalizations
* nits
This commit is contained in:
parent 97c09a863d
commit f24edfa9dc
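Not part of the diff: the pattern this commit introduces is to wrap the normalization math in mx.compile with shapeless=True, so the graph is traced once and reused across input shapes (presumably why requirements.txt moves to mlx>=0.4 below). A minimal sketch, with arbitrary example shapes:

from functools import partial

import mlx.core as mx

# Same rms_norm as the new llms/mlx_lm/models/layers.py: compile once, reuse everywhere.
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    x = x.astype(mx.float32)  # accumulate in float32, as the old per-model classes did
    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
    return weight * x.astype(weight.dtype)

weight = mx.ones((8,))
y = rms_norm(mx.random.normal((2, 8)), weight, 1e-5)     # first call triggers compilation
z = rms_norm(mx.random.normal((4, 3, 8)), weight, 1e-5)  # new shape, same compiled graph (shapeless=True)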
@@ -1,4 +1,4 @@
 from .convert import convert
 from .utils import generate, load
 
-__version__ = "0.0.13"
+__version__ = "0.0.14"
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from functools import partial
 from typing import Dict, Optional, Tuple, Union
 
 import mlx.core as mx
@@ -22,18 +23,21 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
 
 
+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return (1.0 + weight) * x.astype(weight.dtype)
+
+
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
         self.weight = mx.ones((dims,))
         self.eps = eps
 
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
     def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return (1 + self.weight) * output
+        return rms_norm(x, self.weight, self.eps)
 
 
 class Attention(nn.Module):
llms/mlx_lm/models/layers.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+from functools import partial
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@partial(mx.compile, shapeless=True)
+def rms_norm(x, weight, eps):
+    x = x.astype(mx.float32)
+    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
+    return weight * x.astype(weight.dtype)
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return rms_norm(x, self.weight, self.eps)
+
+
+@partial(mx.compile, shapeless=True)
+def ln_norm(x, eps, weight=None, bias=None):
+    t = x.dtype
+    x = x.astype(mx.float32)
+    means = mx.mean(x, axis=-1, keepdims=True)
+    var = mx.var(x, axis=-1, keepdims=True)
+    x = (x - means) * mx.rsqrt(var + eps)
+    x = x.astype(t)
+    return weight * x + bias if weight is not None else x
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.bias = mx.zeros((dims,))
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if "weight" in self:
+            return ln_norm(x, self.eps, self.weight, self.bias)
+        else:
+            return ln_norm(x, self.eps)
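Not part of the diff: a minimal usage sketch of the new shared module, assuming the package is importable as mlx_lm and using illustrative shapes.

import mlx.core as mx

from mlx_lm.models.layers import LayerNorm, RMSNorm

x = mx.random.normal((2, 7, 16)).astype(mx.float16)
print(RMSNorm(16)(x).shape)                  # (2, 7, 16); the math runs in float32, scaled by `weight`
print(LayerNorm(16, affine=False)(x).dtype)  # float16: ln_norm casts back to the input dtype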
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -30,20 +31,6 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class MixtralAttention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 try:
     import hf_olmo
@@ -37,11 +38,6 @@ class ModelArgs(BaseModelArgs):
         )
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 
 @dataclass
@@ -27,11 +28,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class PhiAttention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -1,17 +1,13 @@
-import glob
 import inspect
-import json
 import math
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-from huggingface_hub import snapshot_download
-from mlx.utils import tree_unflatten
-from transformers import AutoTokenizer
+
+from .layers import LayerNorm
 
 
 @dataclass
@@ -37,11 +33,6 @@ class ModelArgs:
         )
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class RoPEAttention(nn.Module):
     def __init__(self, dims: int, num_heads: int, rotary_dim: int):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -22,20 +23,6 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.variance_epsilon = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.variance_epsilon)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -26,20 +27,6 @@ class ModelArgs(BaseModelArgs):
            self.num_key_value_heads = self.num_attention_heads
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import RMSNorm
 
 
 @dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-6):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -6,6 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .layers import LayerNorm
 
 
 @dataclass
@@ -24,11 +25,6 @@ class ModelArgs(BaseModelArgs):
     use_qkv_bias: bool
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.4
 numpy
 transformers>=4.38.0
 protobuf