1 Commit

Author:  Angelos Katharopoulos
SHA1:    b9eff0d744
Message: Improve printing for FLUX distributed training
Date:    2025-01-13 22:47:54 -08:00

11 changed files with 101 additions and 378 deletions

View File

@@ -261,19 +261,23 @@ if __name__ == "__main__":
     generate_progress_images(0, flux, args)

     grads = None
-    losses = []
+    batch_cnt = 0
+    total_loss = 0
     tic = time.time()
     for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
         loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
-        mx.eval(loss, grads, state)
-        losses.append(loss.item())
+        total_loss = total_loss + loss
+        batch_cnt += 1
+        mx.eval(total_loss, grads, state)

-        if (i + 1) % 10 == 0:
+        if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
             toc = time.time()
             peak_mem = mx.metal.get_peak_memory() / 1024**3
+            total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
+            total_loss = total_loss.item() / batch_cnt
             print(
-                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
-                f"It/s: {10 / (toc - tic):.3f} "
+                f"Iter: {i + 1} Loss: {total_loss:.3f} "
+                f"It/s: {batch_cnt / (toc - tic):.3f} "
                 f"Peak mem: {peak_mem:.3f} GB",
                 flush=True,
             )

@@ -285,7 +289,8 @@ if __name__ == "__main__":
             save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)

         if (i + 1) % 10 == 0:
-            losses = []
+            total_loss = 0
+            batch_cnt = 0
             tic = time.time()

     save_adapters("final_adapters.safetensors", flux, args)

View File

@@ -10,7 +10,7 @@ import logging
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import lm_eval
 import mlx.core as mx
@@ -20,10 +20,11 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm

-from .models.base import create_causal_mask
 from .models.cache import make_prompt_cache
 from .utils import load, stream_generate

+PAD = 0
+

 def _len_longest_common_prefix(a, b):
     l = 0
@@ -42,14 +43,31 @@ def _rstrip_until(s, untils):
     return s[: min(f)]


-def _pad_inputs(inputs):
-    lengths = np.array([len(x) for x in inputs])
-    maxlen = lengths.max()
-    padded = np.stack(
-        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
+def _pad_inputs(
+    inputs,
+    maxlen,
+    genlen=0,
+    pad_left=False,
+    pad_multiple=32,
+    truncate=False,
+):
+    # pad the prompts to the left with at least genlen tokens.
+    actual_maxlen = max(len(p) for p in inputs) + genlen
+    if actual_maxlen > maxlen:
+        if not truncate:
+            raise ValueError("Inputs are too long.")
+        else:  # drop begining
+            actual_maxlen = maxlen
+            inputs = [p[max(0, len(p) - maxlen) :] for p in inputs]
+
+    if pad_multiple > 0:
+        maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple
+        maxlen *= pad_multiple
+
+    assert PAD == 0
+    lr = np.array((1, 0) if pad_left else (0, 1))
+    return np.stack(
+        [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs],
         axis=0,
     )
-    return mx.array(padded), mx.array(lengths)


 @register_model("mlxlm")
@@ -65,33 +83,32 @@ class MLXLM(LM):
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template and (
+        self.use_chat_template = use_chat_template or (
             self.tokenizer.chat_template is not None
         )

-    def _score_fn(self, inputs, step_size: int = 64):
-        inputs, lengths = _pad_inputs(inputs)
+    def _score_fn(self, inputs, tokenize=True, step_size=32):
+        if tokenize:
+            inputs = self._tokenize(inputs)
+        inputs = _pad_inputs(inputs, self._max_tokens, truncate=False)
+        inputs = mx.array(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]

         cache = make_prompt_cache(self._model)

+        mask = targets != PAD
+
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
-            inp = inputs[:, i : i + step_size]
-            T = inp.shape[1]
-            offset = cache[0].offset
-            mask = create_causal_mask(T, offset, lengths=lengths)
-            mask = mask == 0
-            logits = self._model(inp, cache=cache, mask=mask)
+            logits = self._model(inputs[:, i : i + step_size], cache=cache)

             log_probs = nn.log_softmax(logits.astype(mx.float32))

             score = mx.take_along_axis(
                 log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1
             )[..., 0]
-            ig = targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
-            ig = mx.where(mx.arange(T) + offset < lengths[:, None], ig, False)
+            ig = mask[:, i : i + step_size] * (
+                targets[:, i : i + step_size] == mx.argmax(logits, axis=-1)
+            )

             mx.eval(score, ig)
             mx.metal.clear_cache()
@@ -102,32 +119,38 @@ class MLXLM(LM):
         scores = mx.concatenate(scores, axis=1)
         is_greedy = mx.concatenate(is_greedy, axis=1)

-        return scores, lengths, is_greedy
+        return scores, mask.sum(axis=-1), is_greedy

-    def _loglikelihood(self, texts, score_spans=None):
-        all_scores = mx.zeros(len(texts))
-        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
-        for i in tqdm(range(0, len(texts), self._batch_size)):
-            batch = texts[i : i + self._batch_size]
-            scores, lengths, is_greedy = self._score_fn(batch)
-
-            ind = np.arange(scores.shape[-1])
-            if score_spans is not None:
-                spans = score_spans[i : i + self._batch_size]
-                lengths = [end - start for start, end in spans]
-                masks = mx.array(
-                    np.array([(ind >= start) & (ind < end) for start, end in spans])
-                )
-            else:
-                masks = ind[None] < lengths[:, None]
-
-            scores = (masks * scores).sum(axis=-1)
-            is_greedy = (masks * is_greedy).sum(axis=-1)
-
-            all_scores[i : i + self._batch_size] = scores
-            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths
-
-        return all_scores, all_is_greedy
+    def _loglikelihood(self, texts, score_spans=None, tokenize=True):
+        # sort by length to get batches with little padding.
+        sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i]))
+        sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))]
+        sorted_spans = None
+        if score_spans is not None:
+            sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))]
+
+        results = []
+        for i in tqdm(range(0, len(sorted_inputs), self._batch_size)):
+            batch = sorted_inputs[i : i + self._batch_size]
+            scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize)
+            for j in range(len(batch)):
+                if sorted_spans is None:  # full sequence score
+                    mask = mx.arange(scores[j].shape[-1]) < length
+                    score = (scores[j].astype(mx.float32) * mask).sum(axis=-1)
+                    ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1)
+                else:  # subsequence score
+                    start, end = sorted_spans[i + j]
+                    score = scores[j][start:end].astype(mx.float32).sum()
+                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
+                    length = end - start
+
+                results.append((score.item(), ig.item(), length))

+        # reorder the outputs
+        inv_sort = np.argsort(sorted_indices)
+        results = [results[inv_sort[i]] for i in range(len(results))]
+
+        return results

     def _tokenize(self, texts):
         return [
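Note: the new _loglikelihood sorts the requests by decreasing length so each batch needs little padding, then restores the original order with the inverse permutation obtained from np.argsort of the sort indices. The trick in isolation, on toy data:

import numpy as np

texts = ["a", "dddd", "bb", "ccc"]
order = sorted(range(len(texts)), key=lambda i: -len(texts[i]))   # [1, 3, 2, 0]
batched = [texts[i] for i in order]                               # longest first

# pretend we scored each item; here the "score" is just its length
scored = [len(t) for t in batched]

inv = np.argsort(order)                  # original position -> slot in `scored`
restored = [scored[inv[i]] for i in range(len(scored))]
print(restored)                          # [1, 4, 2, 3], back in the original order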
@@ -199,53 +222,16 @@ class MLXLM(LM):
+ "completion longer than context." + "completion longer than context."
) )
num_results = len(shortened)
# sort by length to get batches with little padding.
sorted_indices = sorted(range(len(shortened)), key=lambda i: -len(shortened[i]))
shortened = [shortened[i] for i in sorted_indices]
completion_spans = [completion_spans[i] for i in sorted_indices]
group = mx.distributed.init()
# split strided so we have approximately the same lengths on each node
shortened = shortened[group.rank() :: group.size()]
completion_spans = completion_spans[group.rank() :: group.size()]
# model scoring, returns num_requests x (logp, is_greedy, length). # model scoring, returns num_requests x (logp, is_greedy, length).
scores, is_greedy = self._loglikelihood( results = self._loglikelihood(
shortened, shortened,
score_spans=completion_spans, score_spans=completion_spans,
tokenize=False,
) )
return [(r[0], r[1] == r[2]) for r in results]
# all gather the results across groups
if group.size() > 1:
per_group = int(np.ceil(num_results / group.size()))
scores = mx.pad(scores, ((0, per_group - len(scores)),))
is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
scores = mx.distributed.all_gather(scores[mx.newaxis], stream=mx.cpu)
is_greedy = mx.distributed.all_gather(is_greedy[mx.newaxis], stream=mx.cpu)
scores = scores.T.reshape(-1)
is_greedy = is_greedy.T.reshape(-1)
scores = np.array(scores[:num_results])
is_greedy = np.array(is_greedy[:num_results])
results = [(score, ig) for score, ig in zip(scores, is_greedy)]
inv_sort = np.argsort(sorted_indices)
results = [results[inv_sort[i]] for i in range(len(inv_sort))]
return results
tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
if len(chat_history) == 0:
return ""
return lm_eval.models.huggingface.HFLM.apply_chat_template(
chat_history, add_generation_prompt
)
def loglikelihood_rolling(self, requests) -> list[float]: def loglikelihood_rolling(self, requests) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -282,9 +268,8 @@ class MLXLM(LM):
         logging.info(
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
-        inputs = self._tokenize([req.args[0] for req in requests])
-        scores, _ = self._loglikelihood(inputs)
-        return scores.tolist()
+        inputs = [req.args[0] for req in requests]
+        return [t[0] for t in self._loglikelihood(inputs)]

     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -347,7 +332,7 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=None,
+        default=1.0,
         help="Limit the number of examples per task.",
         type=float,
     )
@@ -361,8 +346,11 @@ def main():
     )
     parser.add_argument(
         "--apply-chat-template",
-        action="store_true",
-        help="Specifies whether to apply a chat template to the prompt.",
+        action=argparse.BooleanOptionalAction,
+        help="Specifies whether to apply a chat template to the prompt. If "
+        "the model has a chat template, this defaults to `True`, "
+        "otherwise `False`.",
+        default=None,
     )
     args = parser.parse_args()
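Note: argparse.BooleanOptionalAction (Python 3.9+) generates both --apply-chat-template and --no-apply-chat-template, and default=None adds a third "unspecified" state that the MLXLM constructor can resolve against whether the tokenizer actually ships a chat template. A minimal sketch of that three-state pattern; the resolution shown is an illustration of the idea, not the exact code in the file:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--apply-chat-template",
    action=argparse.BooleanOptionalAction,
    default=None,   # None means "not specified on the command line"
)
args = parser.parse_args([])               # nothing passed -> None
tokenizer_has_template = True              # stand-in for tokenizer.chat_template is not None

use_chat_template = args.apply_chat_template or (
    args.apply_chat_template is None and tokenizer_has_template
)
print(use_chat_template)   # True: fall back to the tokenizer's template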

View File

@@ -22,11 +22,6 @@ import mlx.core as mx
 from mlx_lm import load, stream_generate

 parser = argparse.ArgumentParser(description="LLM pipelined inference example")
-parser.add_argument(
-    "--model",
-    default="mlx-community/DeepSeek-R1-3bit",
-    help="HF repo or path to local model.",
-)
 parser.add_argument(
     "--prompt",
     "-p",
@@ -42,7 +37,9 @@ parser.add_argument(
 )
 args = parser.parse_args()

-model, tokenizer = load(args.model, lazy=True)
+model_repo = "mlx-community/DeepSeek-V3-3bit"
+
+model, tokenizer = load(model_repo, lazy=True)

 messages = [{"role": "user", "content": args.prompt}]
 prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

View File

@@ -78,7 +78,6 @@ def build_parser():
"--train", "--train",
action="store_true", action="store_true",
help="Do training", help="Do training",
default=None,
) )
parser.add_argument( parser.add_argument(
"--data", "--data",
@@ -136,7 +135,6 @@ def build_parser():
"--test", "--test",
action="store_true", action="store_true",
help="Evaluate on the test set after training", help="Evaluate on the test set after training",
default=None,
) )
parser.add_argument( parser.add_argument(
"--test-batches", "--test-batches",
@@ -158,7 +156,6 @@ def build_parser():
"--grad-checkpoint", "--grad-checkpoint",
action="store_true", action="store_true",
help="Use gradient checkpointing to reduce memory use.", help="Use gradient checkpointing to reduce memory use.",
default=None,
) )
parser.add_argument("--seed", type=int, help="The PRNG seed") parser.add_argument("--seed", type=int, help="The PRNG seed")
return parser return parser

View File

@@ -400,8 +400,6 @@ class DeepseekV3Model(nn.Module):
         pipeline_rank = self.pipeline_rank
         pipeline_size = self.pipeline_size
-        # Hack to avoid time-outs during prompt-processing
-        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
         if mask is None:
             mask = create_attention_mask(h, cache)
@@ -409,21 +407,18 @@ class DeepseekV3Model(nn.Module):
             cache = [None] * len(self.layers)

         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
-            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1))

         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)

         # Send to the next process in the pipeline
         if pipeline_rank != 0:
-            h = mx.distributed.send(
-                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
-            )
+            h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)

         # Broadcast h while keeping it in the graph
-        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
+        h = mx.distributed.all_gather(h)[: h.shape[0]]

         return self.norm(h)
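Note: the pipeline loop itself is unchanged apart from dropping the explicit stream: a rank receives the hidden state from the next-higher rank, runs its shard of layers, sends the result to the previous rank, and the final all_gather keeps the send in the compute graph while broadcasting the activations. A schematic sketch of the same communication pattern with a trivial per-rank transform standing in for the layer shard; run it with more than one process (e.g. with mpirun):

# Schematic sketch of the recv -> compute -> send -> all_gather pipeline pattern.
import mlx.core as mx

group = mx.distributed.init()
rank, size = group.rank(), group.size()

h = mx.ones((1, 4, 8))  # stand-in hidden state

# receive from the next-higher rank (the stage holding the earlier layers)
if rank < size - 1:
    h = mx.distributed.recv_like(h, rank + 1)

h = 2 * h  # stand-in for this rank's shard of transformer layers

# send to the previous rank, which runs the later layers
if rank != 0:
    h = mx.distributed.send(h, (rank - 1) % size)

# all_gather keeps the send in the graph and gives every rank the result
h = mx.distributed.all_gather(h)[: h.shape[0]]
mx.eval(h)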

View File

@@ -1,241 +0,0 @@
-# Copyright © 2023-2024 Apple Inc.
-
-from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
-
-
-@dataclass
-class ModelArgs(BaseModelArgs):
-    model_type: str
-    hidden_size: int
-    num_hidden_layers: int
-    intermediate_size: int
-    num_attention_heads: int
-    rms_norm_eps: float
-    vocab_size: int
-    bias: bool = False
-    qkv_bias: bool = False
-    max_position_embeddings: int = 32768
-    num_key_value_heads: int = None
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = False
-
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.rope_scaling:
-            required_keys = {"factor", "rope_type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]:
-                raise ValueError(
-                    "rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic"
-                )
-
-
-class DynamicNTKScalingRoPE(nn.Module):
-    """Implements the rotary positional encoding with Dynamic NTK scaling."""
-
-    def __init__(
-        self,
-        dims: int,
-        max_position_embeddings: int = 2048,
-        traditional: bool = False,
-        base: float = 10000,
-        scale: float = 1.0,
-    ):
-        super().__init__()
-        self.max_position_embeddings = max_position_embeddings
-        self.original_base = base
-        self.dims = dims
-        self.traditional = traditional
-        self.scale = scale
-
-    def extra_repr(self):
-        return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
-
-    def __call__(self, x, offset: int = 0):
-        seq_len = x.shape[1] + offset
-        if seq_len > self.max_position_embeddings:
-            base = self.original_base * (
-                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
-            ) ** (self.dims / (self.dims - 2))
-        else:
-            base = self.original_base
-
-        return mx.fast.rope(
-            x,
-            self.dims,
-            traditional=self.traditional,
-            base=base,
-            scale=self.scale,
-            offset=offset,
-        )
-
-
-class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        dim = args.hidden_size
-        qkv_bias = args.qkv_bias
-        self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-        self.n_kv_groups = n_heads // args.num_key_value_heads
-
-        self.head_dim = head_dim = args.hidden_size // n_heads
-        self.scale = head_dim**-0.5
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias)
-
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None
-            and args.rope_scaling["rope_type"] == "linear"
-            else 2.0
-        )
-
-        self.rope = DynamicNTKScalingRoPE(
-            head_dim,
-            max_position_embeddings=args.max_position_embeddings,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
-
-
-class MLP(nn.Module):
-    def __init__(self, dim, hidden_dim, bias):
-        super().__init__()
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
-
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.self_attn = Attention(args)
-        self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
-
-
-class InternLM2Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        assert args.vocab_size > 0
-        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [
-            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
-        ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        mask: mx.array = None,
-        cache=None,
-    ):
-        h = self.embed_tokens(inputs)
-
-        if mask is None:
-            mask = create_attention_mask(h, cache)
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, cache=c)
-
-        return self.norm(h)
-
-
-class Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.model_type = args.model_type
-        self.model = InternLM2Model(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        mask: mx.array = None,
-        cache=None,
-    ):
-        out = self.model(inputs, mask, cache)
-        if self.args.tie_word_embeddings:
-            out = self.model.embed_tokens.as_linear(out)
-        else:
-            out = self.lm_head(out)
-        return out
-
-    def sanitize(self, weights):
-        # Remove unused precomputed rotary freqs
-        return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
-
-    @property
-    def layers(self):
-        return self.model.layers

View File

@@ -170,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     if prompt_feature and completion_feature:
         return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
     elif text_feature:
-        return Dataset(ds, tokenizer, text_key=text_feature)
+        return Dataset(train_ds, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Specify either a prompt and completion feature or a text "

View File

@@ -159,8 +159,8 @@ def evaluate(
         ntokens += toks
         mx.eval(all_losses, ntokens)

-    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
-    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)

     return (all_losses / ntokens).item()
@@ -272,9 +272,9 @@ def train(
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()

-            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
+            train_loss = mx.distributed.all_sum(losses).item()
             train_loss /= steps * mx.distributed.init().size()
-            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
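Note: both sides aggregate the same way, the right-hand version simply lets all_sum pick its default stream. The reported number is the per-step average across every worker: sum the local loss totals over ranks, then divide by steps times the world size. A minimal numeric sketch of that normalization; the 20.0 total and the four-worker setup are made up:

import mlx.core as mx

steps = 10
losses = mx.array(20.0)                              # local sum of losses on this rank
train_loss = mx.distributed.all_sum(losses).item()   # 80.0 with 4 workers
train_loss /= steps * mx.distributed.init().size()   # 80 / (10 * 4) = 2.0
print(train_loss)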

View File

@@ -100,7 +100,6 @@ def linear_to_lora_layers(
"minicpm", "minicpm",
"deepseek", "deepseek",
"olmo2", "olmo2",
"internlm3",
]: ]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"]) keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]: if model.model_type in ["mixtral", "phimoe"]:

View File

@@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule
 @contextmanager
 def swapped_with_identity(obj, func):
     old_func = getattr(obj, func)
-    setattr(obj, func, lambda x, **kwargs: x)
+    setattr(obj, func, lambda x: x)
     yield
     setattr(obj, func, old_func)
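Note: the test helper temporarily replaces an attribute with an identity function; dropping **kwargs from the lambda is consistent with the stream keyword being removed from the patched calls elsewhere in this diff. Hypothetical usage of the context manager above, with a stand-in target object:

import types

fake_module = types.SimpleNamespace(all_sum=lambda x: x + 1)

with swapped_with_identity(fake_module, "all_sum"):
    assert fake_module.all_sum(3) == 3      # swapped with the identity
assert fake_module.all_sum(3) == 4          # original restored on exit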

View File

@@ -927,23 +927,6 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

-    def test_internlm3(self):
-        from mlx_lm.models import internlm3
-
-        args = internlm3.ModelArgs(
-            model_type="internlm3",
-            hidden_size=1024,
-            num_hidden_layers=4,
-            intermediate_size=2048,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-        )
-        model = internlm3.Model(args)
-        self.model_test_runner(
-            model, args.model_type, args.vocab_size, args.num_hidden_layers
-        )
-

 if __name__ == "__main__":
     unittest.main()