Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 10:41:18 +08:00)

Commit c33c245c11: Merge branch 'ml-explore:main' into adding-orpo-training
@@ -44,7 +44,8 @@ def shard_and_load(repo):
         allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
     )

-    # Lazy load and shard model
+    # Lazy load and shard model to figure out
+    # which weights we need
     model, _ = load_model(model_path, lazy=True, strict=False)

     group = mx.distributed.init(backend="mpi")
@@ -62,8 +63,11 @@ def shard_and_load(repo):
     # Download weights for local shard
     download(args.model, allow_patterns=local_files)

+    # Load and shard the model, and load the weights
     tokenizer = load_tokenizer(model_path)
-    model, _ = load_model(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())

     # Synchronize processes before generation to avoid timeout if downloading
     # model for the first time.
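Taken together, the two shard_and_load hunks above switch to a two-pass load: a first lazy pass discovers which weight files the local rank needs, only those are downloaded, and a second lazy pass drops the remote layers before any weights are materialized. A minimal sketch of the second pass, reusing the names from the diff context (load_model, model_path, group, and mlx.core as mx):

    model, _ = load_model(model_path, lazy=True, strict=False)
    model.model.pipeline(group)    # keep only the layers owned by this rank
    mx.eval(model.parameters())    # materialize just the local shard's weights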
@@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
         self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx

     def __call__(
         self,
@@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
-        self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx

     def __call__(
         self,
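The identical change in DeepseekV2Model and DeepseekV3Model fixes the layer count on the edge rank: with ceiling division, end_idx can run past the end of the layer list, so the slice self.layers[: self.end_idx] leaves the last-assigned rank with fewer than layers_per_rank layers. Deriving num_layers from the slice bounds accounts for that. A worked example with hypothetical numbers (61 layers split across a 4-way pipeline):

    layers_per_rank = (61 + 4 - 1) // 4           # 16
    start_idx = (4 - 0 - 1) * layers_per_rank     # rank 0 owns the tail: 48
    end_idx = start_idx + layers_per_rank         # 64, but only 61 layers exist
    # layers[:64] keeps 61 entries, so this rank really owns 61 - 48 = 13
    num_layers = 61 - start_idx                   # 13, not the hard-coded 16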
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple

@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.

 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )

-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
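Two things change in ssm_step: the state-transition term A, which depends only on the learned A_log parameter, is now passed in by the caller instead of being recomputed for every token, and the optional mixer_norm is folded into the split via map. The positional indices argument to mx.split is equivalent to the old keyword form; a runnable sketch with hypothetical sizes (time_step_rank=4, ssm_state_size=8):

    import mlx.core as mx

    deltaBC = mx.zeros((2, 4 + 8 + 8))                  # (batch, r + 2*s)
    delta, B, C = mx.split(deltaBC, [4, 4 + 8], axis=-1)
    print(delta.shape, B.shape, C.shape)                # (2, 4) (2, 8) (2, 8)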
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
         y = y + D * x
         return y, new_state

-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)

+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
+
         outputs = []
+        current_state = state_cache
+        y = []
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output

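The __call__ rewrite moves the heavy work into _process_sequence: the input projection, its split into x and z, and the causal convolution now run once over the whole (B, T, D) sequence, and only the recurrent ssm_step stays in the per-token Python loop. Cache handling also becomes explicit, writing the updated convolution and SSM states back only when a MambaCache is passed. A usage sketch under those assumptions (shapes hypothetical; it assumes MambaCache can be constructed with no arguments, starting with empty state):

    cache = MambaCache()           # cache[0] = conv state, cache[1] = SSM state
    y = block(x, cache)            # x: (B, T, D); prefill the whole prompt
    y_next = block(x_next, cache)  # x_next: (B, 1, D); decode reuses the state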
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.

 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
@@ -140,8 +140,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

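Promoting all_losses and ntokens from Python scalars to mx.array matters once evaluation is distributed: array accumulators can be reduced across workers with mx.distributed.all_sum, which plain ints cannot. A minimal sketch of how the surrounding loop plausibly uses them (the batch iterator and the loss return signature are assumptions, not the file's actual code):

    import mlx.core as mx

    all_losses = mx.array(0.0)
    ntokens = mx.array(0)
    for batch in batches:                   # hypothetical iterator
        losses, toks = loss(model, *batch)  # assumed: per-batch mean loss, token count
        all_losses += losses * toks
        ntokens += toks
        mx.eval(all_losses, ntokens)        # evaluate eagerly to bound the graph
    all_losses = mx.distributed.all_sum(all_losses)
    ntokens = mx.distributed.all_sum(ntokens)
    avg_loss = (all_losses / ntokens).item()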