Merge branch 'ml-explore:main' into adding-dpo-training

Commit: 2954398002
Author: Gökdeniz Gülmez, 2025-03-03 22:13:31 +01:00 (committed by GitHub)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
2 changed files with 19 additions and 8 deletions

Changed file 1 of 2:

@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional
 import mlx.core as mx
 import mlx.nn as nn
@@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
     mamba_enabled: bool = True
     intermediate_size: int = 13312
     vocab_size: int = 32000
-    max_position_embeddings: int = 10 * 1024 * 1024


 class RMSNorm(nn.Module):
@@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
         )


+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
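
Aside on the added helper: _rms_norm upcasts to float32, rescales each vector by the reciprocal root mean square of its last axis, and casts back to the input dtype. A minimal sanity-check sketch, assuming MLX's mx.fast.rms_norm with its (array, weight, eps) signature; the shapes and tolerance below are illustrative only:

    import mlx.core as mx

    # _rms_norm with eps should match mx.fast.rms_norm with a unit weight,
    # since both only rescale by the last-axis root mean square.
    x = mx.random.normal((2, 4, 8)).astype(mx.float16)
    a = _rms_norm(x, 1e-6)
    b = mx.fast.rms_norm(x.astype(mx.float32), mx.ones((8,)), 1e-6).astype(mx.float16)
    print(mx.allclose(a, b, atol=1e-3))  # expected: True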
@@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(
 def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
-    batch, seqlen, dim = x.shape
-    width = weight.shape[1]
+    _, seqlen, dim = x.shape
     state_len = conv_state.shape[-2]
     x = mx.concatenate([conv_state, x], axis=-2)
     conv_state = x[:, -state_len:]
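
Aside: this change only drops the unused batch and width bindings; the state update itself is unchanged. The cached window rolls forward by concatenating the new tokens onto the state along the time axis and keeping the most recent state_len steps. A small sketch with illustrative shapes:

    import mlx.core as mx

    conv_state = mx.zeros((1, 3, 4))  # (batch, state_len, dim) cached window
    x = mx.ones((1, 2, 4))            # (batch, seqlen, dim) new tokens
    x = mx.concatenate([conv_state, x], axis=-2)
    conv_state = x[:, -3:]            # keep the last state_len time steps
    print(conv_state.shape)           # (1, 3, 4)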
@@ -392,8 +400,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]

         if cache is not None:
             q = self.rope(q, offset=cache.offset)
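
Aside: this is a behavioral fix, not just a refactor. mx.fast.layer_norm subtracts the per-feature mean before rescaling, while RMS normalization only divides by the root mean square, so the two disagree whenever the last-axis mean is nonzero. A quick illustration with made-up values:

    import mlx.core as mx

    v = mx.array([[1.0, 2.0, 3.0, 4.0]])
    ln = mx.fast.layer_norm(v, None, None, 1e-6)  # centered, then scaled
    rms = v * mx.rsqrt(mx.power(v, 2).mean(-1, keepdims=True) + 1e-6)
    print(ln, rms)  # different: the layer_norm output has zero mean, rms does not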
@@ -556,7 +564,6 @@ class PlamoModel(nn.Module):
 class Model(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
@@ -567,7 +574,7 @@ class Model(nn.Module):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(
-                config.hidden_size, vocab_size, bias=False
+                config.hidden_size, self.vocab_size, bias=False
             )

     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
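
Aside: assuming vocab_size is not defined at module scope in this file, the old line would raise NameError the first time a model is constructed with tie_word_embeddings=False; self.vocab_size reads the value already stored on the model instance instead.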

Changed file 2 of 2:

@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import sys
 import time
 from pathlib import Path
@@ -14,6 +15,9 @@
 from mlx.utils import tree_flatten
 from models import LoRALinear

+# Disable output buffering to see print statements in real-time
+sys.stdout.reconfigure(line_buffering=True)
+

 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
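
Aside on the buffering change: sys.stdout.reconfigure(line_buffering=True) (Python 3.7+) makes every newline-terminated print flush immediately, which matters when training logs are redirected to a file or pipe and stdout would otherwise be block-buffered. A minimal sketch of the effect; the loop is illustrative, not from the training script:

    import sys
    import time

    sys.stdout.reconfigure(line_buffering=True)
    for step in range(3):
        print(f"step {step}")  # reaches the terminal/file at the newline
        time.sleep(0.5)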