diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
index 1d8215dd..657fa02e 100644
--- a/llms/mlx_lm/models/plamo2.py
+++ b/llms/mlx_lm/models/plamo2.py
@@ -2,7 +2,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
     mamba_enabled: bool = True
     intermediate_size: int = 13312
     vocab_size: int = 32000
-    max_position_embeddings: int = 10 * 1024 * 1024
 
 
 class RMSNorm(nn.Module):
@@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
         )
 
 
+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
@@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(
 
 
 def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
-    batch, seqlen, dim = x.shape
-    width = weight.shape[1]
+    _, seqlen, dim = x.shape
     state_len = conv_state.shape[-2]
     x = mx.concatenate([conv_state, x], axis=-2)
     conv_state = x[:, -state_len:]
@@ -392,8 +400,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
 
-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
 
         if cache is not None:
             q = self.rope(q, offset=cache.offset)
@@ -556,7 +564,6 @@ class PlamoModel(nn.Module):
 
 
 class Model(nn.Module):
-
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
@@ -567,7 +574,7 @@ class Model(nn.Module):
 
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(
-                config.hidden_size, vocab_size, bias=False
+                config.hidden_size, self.vocab_size, bias=False
             )
 
     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
diff --git a/lora/lora.py b/lora/lora.py
index 723e783d..6f91ccca 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import sys
 import time
 from pathlib import Path
 
@@ -14,6 +15,9 @@
 import utils as lora_utils
 from mlx.utils import tree_flatten
 from models import LoRALinear
 
+# Disable output buffering to see print statements in real-time
+sys.stdout.reconfigure(line_buffering=True)
+
 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
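
Note on the `_rms_norm` change in `plamo2.py`: `mx.fast.layer_norm(x, None, None, eps)` subtracts the per-feature mean before rescaling, while the new `_rms_norm` helper only divides by the root mean square, so the two agree only on zero-mean inputs. A minimal standalone sketch of the difference (not part of the patch; the shapes and the shifted input are illustrative):

```python
import mlx.core as mx


def _rms_norm(x: mx.array, eps: float) -> mx.array:
    # Same math as the patched helper: upcast to float32,
    # scale by 1 / RMS, cast back to the input dtype.
    x32 = x.astype(mx.float32)
    variance = mx.power(x32, 2).mean(-1, keepdims=True)
    return (x32 * mx.rsqrt(variance + eps)).astype(x.dtype)


# Shift the input so its feature-wise mean is nonzero: layer norm
# removes the shift before normalizing, RMS norm does not, so the
# two outputs diverge.
x = mx.random.normal((2, 8)) + 1.0
ln = mx.fast.layer_norm(x, None, None, 1e-6)  # centers, then normalizes
rn = _rms_norm(x, 1e-6)                       # rescales only
print(mx.abs(ln - rn).max())  # clearly nonzero
```

In `lora/lora.py`, `sys.stdout.reconfigure(line_buffering=True)` (available on `io.TextIOWrapper` streams since Python 3.7) flushes stdout at every newline, so progress prints appear immediately even when output is piped; calling `print(..., flush=True)` at each site would have the same effect without the global reconfigure.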