mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 20:46:50 +08:00
Add DeciLM/Nemotron-NAS architecture support for MLX
This commit introduces native MLX support for DeciLM models, including NVIDIA's Nemotron series that use Neural Architecture Search (NAS) optimizations. Key features: - Support for dummy layers (no-op attention/FFN components) - FFN fusion for improved efficiency - Variable Grouped Query Attention (VGQA) with different KV heads per layer - Block configuration handling for NAS architectures - Full conversion pipeline from HuggingFace to MLX format - Comprehensive test suite Tested with: - nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 (Q5: 3.86 tokens/sec on M3 Ultra) - Memory usage: ~175GB peak for 253B model This enables running massive NAS-optimized models on Apple Silicon that were previously incompatible with MLX due to their unique architecture. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
977cd30242
commit
5adbd358b5
148
llms/decilm/README.md
Normal file
148
llms/decilm/README.md
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
# DeciLM / Nemotron-NAS Support for MLX
|
||||||
|
|
||||||
|
This module provides native MLX support for DeciLM architecture models, including NVIDIA's Nemotron series. DeciLM uses Neural Architecture Search (NAS) to create highly optimized transformer variants that achieve superior performance through architectural innovations.
|
||||||
|
|
||||||
|
## Architecture Features
|
||||||
|
|
||||||
|
DeciLM uses Neural Architecture Search (NAS) optimization with:
|
||||||
|
|
||||||
|
1. **Dummy Layers**: Layers where attention or FFN components are completely removed
|
||||||
|
2. **FFN Fusion**: Multiple sequential FFN layers fused into wider parallel layers
|
||||||
|
3. **Variable Grouped Query Attention (VGQA)**: Different number of KV heads per layer (1-8)
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1
|
||||||
|
- nvidia/Llama-3_1-Nemotron-51B-Instruct
|
||||||
|
- Other DeciLM-based models
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Converting Models
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python convert.py \
|
||||||
|
--hf-path nvidia/Llama-3_1-Nemotron-Ultra-253B-v1 \
|
||||||
|
--mlx-path ./nemotron-253b-mlx \
|
||||||
|
--quantize --q-bits 5
|
||||||
|
```
|
||||||
|
|
||||||
|
### Loading and Generation
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mlx_lm import load, generate
|
||||||
|
from decilm import Model, DeciLMArgs
|
||||||
|
|
||||||
|
# Load pre-converted model
|
||||||
|
model, tokenizer = load("./nemotron-253b-mlx")
|
||||||
|
|
||||||
|
# Generate text
|
||||||
|
response = generate(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
prompt="Explain quantum computing in simple terms",
|
||||||
|
max_tokens=500,
|
||||||
|
temperature=0.7,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Command Line Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Using mlx_lm CLI
|
||||||
|
mlx_lm.generate \
|
||||||
|
--model ./nemotron-253b-mlx \
|
||||||
|
--prompt "Your prompt here" \
|
||||||
|
--max-tokens 1000 \
|
||||||
|
--temperature 0.8
|
||||||
|
|
||||||
|
# Start API server
|
||||||
|
mlx_lm.server \
|
||||||
|
--model ./nemotron-253b-mlx \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8080
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
The implementation handles:
|
||||||
|
- Block configurations with variable architectures
|
||||||
|
- Dummy layer passthrough (no computation)
|
||||||
|
- FFN fusion for improved efficiency
|
||||||
|
- Per-layer attention head configuration
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
Tested on Mac Studio M3 Ultra (512GB RAM):
|
||||||
|
- Nemotron-253B Q5: ~3.86 tokens/sec generation
|
||||||
|
- Memory usage: ~175GB peak
|
||||||
|
|
||||||
|
## LM Studio Compatibility
|
||||||
|
|
||||||
|
⚠️ **Note**: DeciLM models are currently **NOT compatible with LM Studio** due to the NAS architecture with dummy layers. LM Studio expects standard transformer layers and encounters "NoneType object has no attribute 'shape'" errors with dummy components.
|
||||||
|
|
||||||
|
**Use `mlx_lm` CLI tools instead:**
|
||||||
|
```bash
|
||||||
|
# Generate text
|
||||||
|
uv run mlx_lm.generate \
|
||||||
|
--model /path/to/nemotron-mlx \
|
||||||
|
--prompt "Your prompt here" \
|
||||||
|
--max-tokens 1000
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
uv run mlx_lm.server \
|
||||||
|
--model /path/to/nemotron-mlx \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8080
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tokenizer Issues
|
||||||
|
|
||||||
|
If you encounter tokenizer issues, check the `USE-IF-MODEL-FAILED-TO-GENERATE` subfolder in the model directory for patched tokenizer configs and chat templates.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- **MLX**: >= 0.26.1
|
||||||
|
- **Python**: 3.11 - 3.12 (tested with CPython 3.12.11 via `uv`)
|
||||||
|
- **Memory**: Sufficient RAM for model size (e.g., ~175GB for Nemotron-253B)
|
||||||
|
- **mlx-lm**: Latest version for model inference
|
||||||
|
|
||||||
|
## Production Deployment
|
||||||
|
|
||||||
|
For production-grade API deployment, consider using [**lbrxServer**](https://github.com/LibraxisAI/lbrxServer):
|
||||||
|
- Robust API endpoints for various LLM architectures
|
||||||
|
- Native support for DeciLM/Nemotron models
|
||||||
|
- Built-in load balancing and request queuing
|
||||||
|
- Compatible with OpenAI API format
|
||||||
|
|
||||||
|
## Model Availability
|
||||||
|
|
||||||
|
Pre-converted DeciLM models for MLX:
|
||||||
|
- [LibraxisAI/Llama-3_1-Nemotron-Ultra-253B-v1-mlx-q5](https://huggingface.co/LibraxisAI/Llama-3_1-Nemotron-Ultra-253B-v1-mlx-q5) - 253B Q5 quantized
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Run the test suite:
|
||||||
|
```bash
|
||||||
|
cd tests
|
||||||
|
python -m pytest test_decilm.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
For integration testing with a real model:
|
||||||
|
```bash
|
||||||
|
python test_generation.py --model-path /path/to/decilm-model
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome! Please:
|
||||||
|
1. Fork the repository
|
||||||
|
2. Create a feature branch
|
||||||
|
3. Add tests for new functionality
|
||||||
|
4. Submit a pull request
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This module follows the same license as mlx-examples. Model weights are subject to their original licenses.
|
5
llms/decilm/__init__.py
Normal file
5
llms/decilm/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""DeciLM support for MLX."""
|
||||||
|
|
||||||
|
from .decilm import DeciLMArgs, Model
|
||||||
|
|
||||||
|
__all__ = ["DeciLMArgs", "Model"]
|
217
llms/decilm/convert.py
Normal file
217
llms/decilm/convert.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Convert DeciLM/Nemotron models to MLX format.
|
||||||
|
Handles NAS architecture with dummy layers and variable configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
from mlx_lm.utils import load_model as load_hf_model
|
||||||
|
from mlx_lm.utils import save_model, get_model_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_block_configs(config_path: Path) -> list:
|
||||||
|
"""Load block configurations from model config."""
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
block_configs = config.get("block_configs", [])
|
||||||
|
if not block_configs:
|
||||||
|
raise ValueError("No block_configs found in model config")
|
||||||
|
|
||||||
|
return block_configs
|
||||||
|
|
||||||
|
|
||||||
|
def convert_attention_weights(hf_weights: Dict, layer_idx: int, block_config: dict) -> Dict:
|
||||||
|
"""Convert attention layer weights, handling dummy layers."""
|
||||||
|
mlx_weights = {}
|
||||||
|
attn_config = block_config["attention"]
|
||||||
|
|
||||||
|
if attn_config.get("no_op", False):
|
||||||
|
# Dummy attention - no weights
|
||||||
|
return mlx_weights
|
||||||
|
|
||||||
|
# Standard attention weight conversion
|
||||||
|
prefix = f"model.layers.{layer_idx}.self_attn."
|
||||||
|
mlx_prefix = f"model.layers.{layer_idx}.self_attn."
|
||||||
|
|
||||||
|
# Convert projection weights
|
||||||
|
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
|
||||||
|
if f"{prefix}{proj}.weight" in hf_weights:
|
||||||
|
mlx_weights[f"{mlx_prefix}{proj}.weight"] = hf_weights[f"{prefix}{proj}.weight"]
|
||||||
|
|
||||||
|
return mlx_weights
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ffn_weights(hf_weights: Dict, layer_idx: int, block_config: dict) -> Dict:
|
||||||
|
"""Convert FFN layer weights, handling dummy layers."""
|
||||||
|
mlx_weights = {}
|
||||||
|
ffn_config = block_config["ffn"]
|
||||||
|
|
||||||
|
if ffn_config.get("no_op", False):
|
||||||
|
# Dummy FFN - no weights
|
||||||
|
return mlx_weights
|
||||||
|
|
||||||
|
# Standard FFN weight conversion
|
||||||
|
prefix = f"model.layers.{layer_idx}.mlp."
|
||||||
|
mlx_prefix = f"model.layers.{layer_idx}.mlp."
|
||||||
|
|
||||||
|
# Convert gate/up/down projections
|
||||||
|
for proj in ["gate_proj", "up_proj", "down_proj"]:
|
||||||
|
if f"{prefix}{proj}.weight" in hf_weights:
|
||||||
|
mlx_weights[f"{mlx_prefix}{proj}.weight"] = hf_weights[f"{prefix}{proj}.weight"]
|
||||||
|
|
||||||
|
return mlx_weights
|
||||||
|
|
||||||
|
|
||||||
|
def convert_weights(hf_weights: Dict, block_configs: list) -> Dict:
|
||||||
|
"""Convert all model weights from HF to MLX format."""
|
||||||
|
mlx_weights = {}
|
||||||
|
|
||||||
|
# Convert embeddings
|
||||||
|
if "model.embed_tokens.weight" in hf_weights:
|
||||||
|
mlx_weights["model.embed_tokens.weight"] = hf_weights["model.embed_tokens.weight"]
|
||||||
|
|
||||||
|
# Convert each layer based on its config
|
||||||
|
for i, block_config in enumerate(block_configs):
|
||||||
|
# Layer norms (always present)
|
||||||
|
for norm in ["input_layernorm", "post_attention_layernorm"]:
|
||||||
|
key = f"model.layers.{i}.{norm}.weight"
|
||||||
|
if key in hf_weights:
|
||||||
|
mlx_weights[key] = hf_weights[key]
|
||||||
|
|
||||||
|
# Attention weights
|
||||||
|
mlx_weights.update(convert_attention_weights(hf_weights, i, block_config))
|
||||||
|
|
||||||
|
# FFN weights
|
||||||
|
mlx_weights.update(convert_ffn_weights(hf_weights, i, block_config))
|
||||||
|
|
||||||
|
# Final norm and LM head
|
||||||
|
if "model.norm.weight" in hf_weights:
|
||||||
|
mlx_weights["model.norm.weight"] = hf_weights["model.norm.weight"]
|
||||||
|
if "lm_head.weight" in hf_weights:
|
||||||
|
mlx_weights["lm_head.weight"] = hf_weights["lm_head.weight"]
|
||||||
|
|
||||||
|
return mlx_weights
|
||||||
|
|
||||||
|
|
||||||
|
def save_config(config_path: Path, hf_config: Dict, block_configs: list):
|
||||||
|
"""Save MLX model configuration."""
|
||||||
|
mlx_config = {
|
||||||
|
"model_type": "decilm",
|
||||||
|
"hidden_size": hf_config["hidden_size"],
|
||||||
|
"num_hidden_layers": hf_config["num_hidden_layers"],
|
||||||
|
"intermediate_size": hf_config["intermediate_size"],
|
||||||
|
"num_attention_heads": hf_config["num_attention_heads"],
|
||||||
|
"num_key_value_heads": hf_config.get("num_key_value_heads", hf_config["num_attention_heads"]),
|
||||||
|
"vocab_size": hf_config["vocab_size"],
|
||||||
|
"rms_norm_eps": hf_config.get("rms_norm_eps", 1e-6),
|
||||||
|
"rope_theta": hf_config.get("rope_theta", 10000),
|
||||||
|
"rope_scaling": hf_config.get("rope_scaling"),
|
||||||
|
"block_configs": block_configs,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(mlx_config, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Convert DeciLM models to MLX")
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to HuggingFace model or repo ID",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mlx-path",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Output path for MLX model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize",
|
||||||
|
action="store_true",
|
||||||
|
help="Quantize the model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-bits",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of bits for quantization",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--q-group-size",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Group size for quantization",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"Loading model from {args.hf_path}")
|
||||||
|
model_path = get_model_path(args.hf_path)
|
||||||
|
|
||||||
|
# Load configurations
|
||||||
|
hf_config_path = model_path / "config.json"
|
||||||
|
with open(hf_config_path, 'r') as f:
|
||||||
|
hf_config = json.load(f)
|
||||||
|
|
||||||
|
block_configs = hf_config.get("block_configs", [])
|
||||||
|
if not block_configs:
|
||||||
|
raise ValueError("This doesn't appear to be a DeciLM model (no block_configs)")
|
||||||
|
|
||||||
|
print(f"Found {len(block_configs)} blocks with NAS configuration")
|
||||||
|
|
||||||
|
# Count dummy layers
|
||||||
|
dummy_attn = sum(1 for bc in block_configs if bc["attention"].get("no_op", False))
|
||||||
|
dummy_ffn = sum(1 for bc in block_configs if bc["ffn"].get("no_op", False))
|
||||||
|
print(f"Dummy layers: {dummy_attn} attention, {dummy_ffn} FFN")
|
||||||
|
|
||||||
|
# Load HF weights
|
||||||
|
print("Loading weights...")
|
||||||
|
model, _ = load_hf_model(args.hf_path)
|
||||||
|
hf_weights = dict(model.state_dict())
|
||||||
|
|
||||||
|
# Convert weights
|
||||||
|
print("Converting weights to MLX format...")
|
||||||
|
mlx_weights = convert_weights(hf_weights, block_configs)
|
||||||
|
|
||||||
|
# Quantize if requested
|
||||||
|
if args.quantize:
|
||||||
|
print(f"Quantizing to {args.q_bits} bits...")
|
||||||
|
mlx_weights = mx.quantize(
|
||||||
|
mlx_weights,
|
||||||
|
bits=args.q_bits,
|
||||||
|
group_size=args.q_group_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save model
|
||||||
|
output_path = Path(args.mlx_path)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
print(f"Saving to {output_path}")
|
||||||
|
|
||||||
|
# Save weights
|
||||||
|
mx.save_safetensors(str(output_path / "model.safetensors"), mlx_weights)
|
||||||
|
|
||||||
|
# Save config
|
||||||
|
save_config(output_path / "config.json", hf_config, block_configs)
|
||||||
|
|
||||||
|
# Copy tokenizer files
|
||||||
|
for file in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
|
||||||
|
src = model_path / file
|
||||||
|
if src.exists():
|
||||||
|
shutil.copy(src, output_path / file)
|
||||||
|
|
||||||
|
print("Conversion complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
223
llms/decilm/decilm.py
Normal file
223
llms/decilm/decilm.py
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
"""
|
||||||
|
DeciLM model implementation for MLX.
|
||||||
|
Supports Neural Architecture Search (NAS) optimized models with:
|
||||||
|
- Dummy layers (no-op attention/FFN)
|
||||||
|
- Variable Grouped Query Attention
|
||||||
|
- FFN Fusion
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Tuple, Any
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_lm.models.base import BaseModelArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeciLMArgs(BaseModelArgs):
|
||||||
|
"""Arguments for DeciLM model."""
|
||||||
|
model_type: str = "decilm"
|
||||||
|
hidden_size: int = 4096
|
||||||
|
num_hidden_layers: int = 32
|
||||||
|
intermediate_size: int = 11008
|
||||||
|
num_attention_heads: int = 32
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
rms_norm_eps: float = 1e-6
|
||||||
|
vocab_size: int = 32000
|
||||||
|
attention_bias: bool = False
|
||||||
|
rope_theta: float = 10000
|
||||||
|
rope_traditional: bool = False
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
# DeciLM specific
|
||||||
|
block_configs: Optional[list] = None # Per-layer configurations
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.num_key_value_heads is None:
|
||||||
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
|
|
||||||
|
class DummyAttention(nn.Module):
|
||||||
|
"""Dummy attention layer that passes input through unchanged."""
|
||||||
|
def __init__(self, args: DeciLMArgs):
|
||||||
|
super().__init__()
|
||||||
|
# No parameters - just pass through
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
# Return input unchanged
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DummyFFN(nn.Module):
|
||||||
|
"""Dummy FFN layer that passes input through unchanged."""
|
||||||
|
def __init__(self, args: DeciLMArgs):
|
||||||
|
super().__init__()
|
||||||
|
# No parameters - just pass through
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
# Return input unchanged
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VariableAttention(nn.Module):
|
||||||
|
"""Attention with variable number of KV heads per layer."""
|
||||||
|
def __init__(self, args: DeciLMArgs, n_kv_heads: int):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = args.num_attention_heads
|
||||||
|
self.n_kv_heads = n_kv_heads # Variable per layer
|
||||||
|
self.head_dim = args.hidden_size // args.num_attention_heads
|
||||||
|
|
||||||
|
self.scale = self.head_dim**-0.5
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=args.attention_bias)
|
||||||
|
self.k_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=args.attention_bias)
|
||||||
|
self.v_proj = nn.Linear(args.hidden_size, self.n_kv_heads * self.head_dim, bias=args.attention_bias)
|
||||||
|
self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=args.attention_bias)
|
||||||
|
|
||||||
|
rope_scale = 1.0
|
||||||
|
if args.rope_scaling:
|
||||||
|
rope_scale = args.rope_scaling.get("factor", 1.0)
|
||||||
|
|
||||||
|
self.rope = nn.RoPE(self.head_dim, traditional=args.rope_traditional, base=args.rope_theta, scale=rope_scale)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
B, L, _ = x.shape
|
||||||
|
|
||||||
|
queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
values = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# Apply RoPE
|
||||||
|
queries = self.rope(queries, offset=cache.offset if cache else 0)
|
||||||
|
keys = self.rope(keys, offset=cache.offset if cache else 0)
|
||||||
|
|
||||||
|
# Update cache if provided
|
||||||
|
if cache is not None:
|
||||||
|
keys, values = cache.update_and_fetch(keys, values)
|
||||||
|
|
||||||
|
# Repeat KV heads if needed
|
||||||
|
if self.n_kv_heads != self.n_heads:
|
||||||
|
n_rep = self.n_heads // self.n_kv_heads
|
||||||
|
keys = mx.repeat(keys, n_rep, axis=1)
|
||||||
|
values = mx.repeat(values, n_rep, axis=1)
|
||||||
|
|
||||||
|
# Compute attention
|
||||||
|
scores = (queries @ keys.transpose(0, 1, 3, 2)) * self.scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableFFN(nn.Module):
|
||||||
|
"""FFN with variable expansion ratio."""
|
||||||
|
def __init__(self, args: DeciLMArgs, ffn_mult: float):
|
||||||
|
super().__init__()
|
||||||
|
# Calculate intermediate size based on multiplier
|
||||||
|
intermediate_size = int(args.hidden_size * ffn_mult)
|
||||||
|
|
||||||
|
self.gate_proj = nn.Linear(args.hidden_size, intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(args.hidden_size, intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(intermediate_size, args.hidden_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class DeciLMBlock(nn.Module):
|
||||||
|
"""Transformer block with DeciLM variable architecture."""
|
||||||
|
def __init__(self, args: DeciLMArgs, block_config: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.block_config = block_config
|
||||||
|
|
||||||
|
# Layer norms always present
|
||||||
|
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
# Attention layer (can be dummy)
|
||||||
|
attn_config = block_config["attention"]
|
||||||
|
if attn_config.get("no_op", False):
|
||||||
|
self.self_attn = DummyAttention(args)
|
||||||
|
else:
|
||||||
|
n_kv_heads = attn_config.get("n_heads_in_group", args.num_key_value_heads)
|
||||||
|
self.self_attn = VariableAttention(args, n_kv_heads)
|
||||||
|
|
||||||
|
# FFN layer (can be dummy)
|
||||||
|
ffn_config = block_config["ffn"]
|
||||||
|
if ffn_config.get("no_op", False):
|
||||||
|
self.mlp = DummyFFN(args)
|
||||||
|
else:
|
||||||
|
ffn_mult = ffn_config.get("ffn_mult", 2.5)
|
||||||
|
self.mlp = VariableFFN(args, ffn_mult)
|
||||||
|
|
||||||
|
def __call__(self, x, mask=None, cache=None):
|
||||||
|
# Self attention (may be dummy/no-op)
|
||||||
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
|
h = x + r
|
||||||
|
|
||||||
|
# FFN (may be dummy/no-op)
|
||||||
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
|
out = h + r
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DeciLMModel(nn.Module):
|
||||||
|
"""DeciLM model with NAS-optimized architecture."""
|
||||||
|
def __init__(self, args: DeciLMArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||||
|
|
||||||
|
# Build layers with per-layer configs
|
||||||
|
self.layers = []
|
||||||
|
for i, block_config in enumerate(args.block_configs):
|
||||||
|
self.layers.append(DeciLMBlock(args, block_config))
|
||||||
|
|
||||||
|
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
|
|
||||||
|
def __call__(self, inputs, cache=None):
|
||||||
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if h.shape[1] > 1:
|
||||||
|
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||||
|
mask = mask.astype(h.dtype)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
for layer, c in zip(self.layers, cache):
|
||||||
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
return self.norm(h)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
"""Full DeciLM model for generation."""
|
||||||
|
def __init__(self, args: DeciLMArgs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.model = DeciLMModel(args)
|
||||||
|
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(self, inputs, cache=None):
|
||||||
|
out = self.model(inputs, cache)
|
||||||
|
return self.lm_head(out)
|
||||||
|
|
||||||
|
def sanitize(self, weights):
|
||||||
|
# Convert weights if needed
|
||||||
|
return weights
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.model.layers
|
118
llms/decilm/tests/test_decilm.py
Normal file
118
llms/decilm/tests/test_decilm.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
"""Tests for DeciLM implementation."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.append('..')
|
||||||
|
|
||||||
|
from decilm import DeciLMArgs, DummyAttention, DummyFFN, VariableAttention, DeciLMBlock
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeciLMComponents(unittest.TestCase):
|
||||||
|
"""Test DeciLM model components."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Set up test fixtures."""
|
||||||
|
self.args = DeciLMArgs(
|
||||||
|
hidden_size=4096,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
intermediate_size=11008,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dummy_attention(self):
|
||||||
|
"""Test dummy attention passthrough."""
|
||||||
|
dummy_attn = DummyAttention(self.args)
|
||||||
|
|
||||||
|
# Test input
|
||||||
|
x = mx.random.normal((2, 10, 4096))
|
||||||
|
|
||||||
|
# Should return input unchanged
|
||||||
|
output = dummy_attn(x)
|
||||||
|
self.assertTrue(mx.array_equal(output, x))
|
||||||
|
|
||||||
|
def test_dummy_ffn(self):
|
||||||
|
"""Test dummy FFN passthrough."""
|
||||||
|
dummy_ffn = DummyFFN(self.args)
|
||||||
|
|
||||||
|
# Test input
|
||||||
|
x = mx.random.normal((2, 10, 4096))
|
||||||
|
|
||||||
|
# Should return input unchanged
|
||||||
|
output = dummy_ffn(x)
|
||||||
|
self.assertTrue(mx.array_equal(output, x))
|
||||||
|
|
||||||
|
def test_variable_attention(self):
|
||||||
|
"""Test variable attention with different KV heads."""
|
||||||
|
# Test with 4 KV heads (less than Q heads)
|
||||||
|
var_attn = VariableAttention(self.args, n_kv_heads=4)
|
||||||
|
|
||||||
|
x = mx.random.normal((2, 10, 4096))
|
||||||
|
output = var_attn(x)
|
||||||
|
|
||||||
|
# Output shape should match input
|
||||||
|
self.assertEqual(output.shape, (2, 10, 4096))
|
||||||
|
|
||||||
|
def test_decilm_block_dummy(self):
|
||||||
|
"""Test DeciLM block with dummy components."""
|
||||||
|
# Config with dummy attention and FFN
|
||||||
|
block_config = {
|
||||||
|
"attention": {"no_op": True},
|
||||||
|
"ffn": {"no_op": True}
|
||||||
|
}
|
||||||
|
|
||||||
|
block = DeciLMBlock(self.args, block_config)
|
||||||
|
x = mx.random.normal((2, 10, 4096))
|
||||||
|
|
||||||
|
output = block(x)
|
||||||
|
|
||||||
|
# With both dummy, output should be close to input
|
||||||
|
# (only layer norms applied)
|
||||||
|
self.assertEqual(output.shape, x.shape)
|
||||||
|
|
||||||
|
def test_decilm_block_mixed(self):
|
||||||
|
"""Test DeciLM block with mixed dummy/active components."""
|
||||||
|
# Config with active attention but dummy FFN
|
||||||
|
block_config = {
|
||||||
|
"attention": {"no_op": False, "n_heads_in_group": 8},
|
||||||
|
"ffn": {"no_op": True}
|
||||||
|
}
|
||||||
|
|
||||||
|
block = DeciLMBlock(self.args, block_config)
|
||||||
|
x = mx.random.normal((2, 10, 4096))
|
||||||
|
|
||||||
|
output = block(x)
|
||||||
|
self.assertEqual(output.shape, x.shape)
|
||||||
|
|
||||||
|
def test_block_config_variations(self):
|
||||||
|
"""Test various block configurations."""
|
||||||
|
configs = [
|
||||||
|
# Standard block
|
||||||
|
{
|
||||||
|
"attention": {"no_op": False, "n_heads_in_group": 8},
|
||||||
|
"ffn": {"no_op": False, "ffn_mult": 2.5}
|
||||||
|
},
|
||||||
|
# Variable FFN multiplier
|
||||||
|
{
|
||||||
|
"attention": {"no_op": False, "n_heads_in_group": 8},
|
||||||
|
"ffn": {"no_op": False, "ffn_mult": 1.5}
|
||||||
|
},
|
||||||
|
# Different KV heads
|
||||||
|
{
|
||||||
|
"attention": {"no_op": False, "n_heads_in_group": 4},
|
||||||
|
"ffn": {"no_op": False, "ffn_mult": 2.5}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
x = mx.random.normal((1, 5, 4096))
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
block = DeciLMBlock(self.args, config)
|
||||||
|
output = block(x)
|
||||||
|
self.assertEqual(output.shape, x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user