Merge branch 'ml-explore:main' into adding-orpo-training

This commit is contained in:
Gökdeniz Gülmez 2025-02-04 11:04:40 +01:00 committed by GitHub
commit c33c245c11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 52 additions and 32 deletions

View File

@ -44,7 +44,8 @@ def shard_and_load(repo):
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"], allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
) )
# Lazy load and shard model # Lazy load and shard model to figure out
# which weights we need
model, _ = load_model(model_path, lazy=True, strict=False) model, _ = load_model(model_path, lazy=True, strict=False)
group = mx.distributed.init(backend="mpi") group = mx.distributed.init(backend="mpi")
@ -62,8 +63,11 @@ def shard_and_load(repo):
# Download weights for local shard # Download weights for local shard
download(args.model, allow_patterns=local_files) download(args.model, allow_patterns=local_files)
# Load and shard the model, and load the weights
tokenizer = load_tokenizer(model_path) tokenizer = load_tokenizer(model_path)
model, _ = load_model(model_path) model, _ = load_model(model_path, lazy=True, strict=False)
model.model.pipeline(group)
mx.eval(model.parameters())
# Synchronize processes before generation to avoid timeout if downloading # Synchronize processes before generation to avoid timeout if downloading
# model for the first time. # model for the first time.

View File

@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
self.num_layers = layers_per_rank self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx] self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__( def __call__(
self, self,

View File

@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
layers_per_rank = ( layers_per_rank = (
len(self.layers) + self.pipeline_size - 1 len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size ) // self.pipeline_size
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx] self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__( def __call__(
self, self,

View File

@ -1,3 +1,5 @@
# Copyright © 2025 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple

View File

@ -1,4 +1,4 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024-2025 Apple Inc.
import math import math
from dataclasses import dataclass from dataclasses import dataclass
@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias self.intermediate_size, self.hidden_size, bias=args.use_bias
) )
def ssm_step(self, x, state=None): def ssm_step(self, x, A, state=None):
A = -mx.exp(self.A_log)
D = self.D D = self.D
deltaBC = self.x_proj(x) deltaBC = self.x_proj(x)
delta, B, C = mx.split( delta, B, C = map(
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split(
deltaBC, deltaBC,
indices_or_sections=[ [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1, axis=-1,
),
) )
if self.use_bcdt_rms: if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C)) delta, B, C = map(self.mixer_norm, (delta, B, C))
@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
y = y + D * x y = y + D * x
return y, new_state return y, new_state
def __call__(self, x, cache): def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape B, T, D = x.shape
if cache is None: xz = self.in_proj(x)
cache = [None, None] x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
outputs = [] outputs = []
current_state = state_cache
y = []
for t in range(T): for t in range(T):
xt = x[:, t, :] y_t, current_state = self.ssm_step(x[:, t], A, current_state)
xz = self.in_proj(xt) y.append(y_t)
x_t, z_t = xz.split(indices_or_sections=2, axis=1) y = mx.stack(y, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) z = self.out_proj(nn.silu(z) * y)
x_t = conv_out.squeeze(1) return z, (new_conv_cache, current_state)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1]) def __call__(self, x, cache):
z_t = nn.silu(z_t) if cache is None:
output_t = y_t * z_t conv_cache, state_cache = None, None
output_t = self.out_proj(output_t) else:
outputs.append(output_t) conv_cache, state_cache = cache[0], cache[1]
output = mx.stack(outputs, axis=1)
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output return output

View File

@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2025 Apple Inc.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union

View File

@ -140,8 +140,8 @@ def evaluate(
loss: callable = default_loss, loss: callable = default_loss,
iterate_batches: callable = iterate_batches, iterate_batches: callable = iterate_batches,
): ):
all_losses = 0 all_losses = mx.array(0.0)
ntokens = 0 ntokens = mx.array(0)
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)