mlx-examples/wwdc25/Get_started_with_MLX_for_Apple_silicon.ipynb
2025-06-10 10:23:25 -07:00

686 lines
103 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "3f964709",
"metadata": {},
"source": [
"# Get started with MLX for Apple silicon"
]
},
{
"cell_type": "markdown",
"id": "b11ab931",
"metadata": {},
"source": [
"### Basic operations"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0e068977",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result c: array([5, 7, 9], dtype=int32)\n",
"Shape: (3,)\n",
"Data type: mlx.core.int32\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Two small vectors to operate on\n",
"a = mx.array([1, 2, 3])\n",
"b = mx.array([4, 5, 6])\n",
"\n",
"# Add them with the + operator\n",
"c = a + b\n",
"\n",
"# Shape and dtype are plain attributes of the result\n",
"shape, dtype = c.shape, c.dtype\n",
"\n",
"print(f\"Result c: {c}\")\n",
"print(f\"Shape: {shape}\")\n",
"print(f\"Data type: {dtype}\")"
]
},
{
"cell_type": "markdown",
"id": "852c80f9",
"metadata": {},
"source": [
"### Unified Memory"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2c344ed4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c computed on the GPU: array([5, 7, 9], dtype=int32)\n",
"d computed on the CPU: array([4, 10, 18], dtype=int32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"a, b = mx.array([1, 2, 3]), mx.array([4, 5, 6])\n",
"\n",
"# The same two arrays feed both operations; only the stream argument differs\n",
"c = mx.add(a, b, stream=mx.gpu)\n",
"d = mx.multiply(a, b, stream=mx.cpu)\n",
"\n",
"print(f\"c computed on the GPU: {c}\")\n",
"print(f\"d computed on the CPU: {d}\")"
]
},
{
"cell_type": "markdown",
"id": "b1c809aa",
"metadata": {},
"source": [
"### Lazy computation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "83cb860d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([5, 7, 9], dtype=int32)\n",
"Evaluate c by converting to list: [5, 7, 9]\n",
"Evaluate c using print: array([5, 7, 9], dtype=int32)\n",
"Evaluate c using mx.eval(): array([5, 7, 9], dtype=int32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Inputs\n",
"a = mx.array([1, 2, 3])\n",
"b = mx.array([4, 5, 6])\n",
"\n",
"# Build the sum\n",
"c = a + b\n",
"\n",
"# 1) Printing evaluates c first\n",
"print(c)\n",
"\n",
"# 2) Converting to a Python list also evaluates c\n",
"c_list = c.tolist()\n",
"\n",
"# 3) mx.eval() evaluates c explicitly\n",
"mx.eval(c)\n",
"\n",
"print(f\"Evaluate c by converting to list: {c_list}\")\n",
"print(f\"Evaluate c using print: {c}\")\n",
"print(f\"Evaluate c using mx.eval(): {c}\")"
]
},
{
"cell_type": "markdown",
"id": "dc742c8a",
"metadata": {},
"source": [
"### Function transformation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f3e5fde1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-0.841471, dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"\n",
"def sin(x):\n",
"    \"\"\"Return mx.sin(x).\"\"\"\n",
"    return mx.sin(x)\n",
"\n",
"\n",
"# First derivative of sin\n",
"dfdx = mx.grad(sin)\n",
"\n",
"# Second derivative: apply mx.grad twice to the same function\n",
"d2fdx2 = mx.grad(mx.grad(sin))\n",
"\n",
"# Computes the second derivative of sine at 1.0\n",
"d2fdx2(mx.array(1.0))"
]
},
{
"cell_type": "markdown",
"id": "c4850f6d",
"metadata": {},
"source": [
"#### Visualizing `sin`, `grad(sin)`, `grad(grad(sin))`\n",
"Plot should show `sin(x)`, `cos(x)` and `-sin(x)`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6cb0c908",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import mlx.core as mx\n",
"\n",
"\n",
"def plot_waves(x, fn1, fn2, fn3, labels):\n",
"    \"\"\"Plot fn1(x), fn2(x) and fn3(x) against x on one figure.\n",
"\n",
"    x is passed unchanged to each function and used as the horizontal axis.\n",
"    labels supplies the three legend entries, in the same order as the functions.\n",
"    \"\"\"\n",
"    # Evaluate all three functions over the same x values\n",
"    y_1 = fn1(x)\n",
"    y_2 = fn2(x)\n",
"    y_3 = fn3(x)\n",
"\n",
"    # Create the plot\n",
"    plt.figure(figsize=(10, 6))\n",
"    # x is plotted as given; the call below passes radians (0 to 2*pi)\n",
"    plt.plot(x, y_1, label=labels[0], marker='o', linestyle='-', markersize=5, markevery=20)\n",
"    plt.plot(x, y_2, label=labels[1], marker='x', linestyle='--', markersize=5, markevery=20)\n",
"    plt.plot(x, y_3, label=labels[2], marker='*', linestyle='-.', markersize=5, markevery=20)\n",
"\n",
"    # Add labels\n",
"    plt.xlabel('X-axis (Radians)')\n",
"    plt.ylabel('Y-axis')\n",
"\n",
"    plt.legend()\n",
"    plt.grid(True)\n",
"    plt.show()\n",
"\n",
"\n",
"x = mx.linspace(0, 2 * mx.pi, 400)\n",
"\n",
"# dfdx and d2fdx2 come from the previous cell; vmap applies them across x\n",
"cos = mx.vmap(dfdx)\n",
"negative_sin = mx.vmap(d2fdx2)\n",
"\n",
"plot_waves(x, sin, cos, negative_sin, [\"sin(x)\",\"dfdx(sin(x))\", \"d2fdx2(sin(x))\"])"
]
},
{
"cell_type": "markdown",
"id": "bbc810b7",
"metadata": {},
"source": [
"### Neural Networks in MLX and PyTorch"
]
},
{
"cell_type": "markdown",
"id": "be4d8dd4",
"metadata": {},
"source": [
"MLX Neural Network"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fd53cc03",
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"\n",
"class MLP(nn.Module):\n",
"    \"\"\"A simple MLP.\"\"\"\n",
"\n",
"    def __init__(self, dim, h_dim):\n",
"        \"\"\"Create two linear layers mapping dim -> h_dim -> dim.\"\"\"\n",
"        super().__init__()\n",
"        self.linear1 = nn.Linear(dim, h_dim)\n",
"        self.linear2 = nn.Linear(h_dim, dim)\n",
"\n",
"    def __call__(self, x):\n",
"        \"\"\"Apply linear1, ReLU, then linear2 to x and return the result.\"\"\"\n",
"        x = self.linear1(x)\n",
"        x = nn.relu(x)\n",
"        x = self.linear2(x)\n",
"        return x"
]
},
{
"cell_type": "markdown",
"id": "ace491fb",
"metadata": {},
"source": [
"MLX Training loop"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "80f3d568",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Loss after 5 steps: 1.0052\n"
]
}
],
"source": [
"n_epochs = 5\n",
"input_dim, hidden_dim, num_samples = 10, 50, 1000\n",
"\n",
"model = MLP(input_dim, hidden_dim)\n",
"\n",
"\n",
"def loss_fn(model, X, y):\n",
"    \"\"\"Return the MSE loss between model(X) and y.\"\"\"\n",
"    return nn.losses.mse_loss(model(X), y)\n",
"\n",
"\n",
"loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n",
"optimizer = optim.Adam(learning_rate=0.01)\n",
"\n",
"# Random inputs and random targets, both of shape (num_samples, input_dim)\n",
"X_train = mx.random.normal([num_samples, input_dim])\n",
"y_train = mx.random.normal([num_samples, input_dim])\n",
"\n",
"for epoch in range(n_epochs):\n",
"    loss, grads = loss_and_grad_fn(model, X_train, y_train)\n",
"    model.update(optimizer.apply_gradients(grads, model))\n",
"    # Explicitly evaluate the updated parameters and optimizer state each step\n",
"    mx.eval(model.parameters(), optimizer.state)\n",
"\n",
"# Report the step count from n_epochs so the message cannot drift from the loop\n",
"print(f\"Final Loss after {n_epochs} steps: {loss.item():.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "5fa466a7",
"metadata": {},
"source": [
"PyTorch Neural Network"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c39f647",
"metadata": {},
"outputs": [],
"source": [
"# NOTE: these imports rebind nn and optim, which earlier cells bound to\n",
"# mlx.nn and mlx.optimizers; the class below likewise replaces the MLX MLP.\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"\n",
"class MLP(nn.Module):\n",
"    \"\"\"A simple MLP.\"\"\"\n",
"\n",
"    def __init__(self, dim, h_dim):\n",
"        \"\"\"Create two linear layers mapping dim -> h_dim -> dim.\"\"\"\n",
"        super().__init__()\n",
"        self.linear1 = nn.Linear(dim, h_dim)\n",
"        self.linear2 = nn.Linear(h_dim, dim)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Apply linear1, ReLU, then linear2 to x and return the result.\"\"\"\n",
"        x = self.linear1(x)\n",
"        x = x.relu()\n",
"        x = self.linear2(x)\n",
"        return x"
]
},
{
"cell_type": "markdown",
"id": "1e568017",
"metadata": {},
"source": [
"PyTorch Training loop"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0d1b3dc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Loss after 5 steps: 1.0028\n"
]
}
],
"source": [
"n_epochs = 5\n",
"input_dim, hidden_dim, num_samples = 10, 50, 1000\n",
"\n",
"model = MLP(input_dim, hidden_dim)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
"\n",
"# Random inputs and random targets, both of shape (num_samples, input_dim)\n",
"X_train = torch.randn([num_samples, input_dim])\n",
"y_train = torch.randn([num_samples, input_dim])\n",
"\n",
"for epoch in range(n_epochs):\n",
"    outputs = model(X_train)\n",
"    loss = criterion(outputs, y_train)\n",
"    # Clear gradients from the previous step before backpropagating\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"# Report the step count from n_epochs so the message cannot drift from the loop\n",
"print(f\"Final Loss after {n_epochs} steps: {loss.item():.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "b9ace438",
"metadata": {},
"source": [
"### Compiling MLX functions"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e1a6d2f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gelu: array([-0.169571, -0.0094711, 0.120888, -0.122945], dtype=float32)\n",
"compiled gelu: array([-0.169571, -0.0094711, 0.120888, -0.122945], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"import math\n",
"\n",
"\n",
"def gelu(x):\n",
"    \"\"\"Return x * (1 + erf(x / sqrt(2))) / 2.\"\"\"\n",
"    return x * (1 + mx.erf(x / math.sqrt(2))) / 2\n",
"\n",
"\n",
"@mx.compile\n",
"def compiled_gelu(x):\n",
"    \"\"\"Same computation as gelu, wrapped with mx.compile.\"\"\"\n",
"    # Delegate to gelu so the two versions cannot drift apart\n",
"    return gelu(x)\n",
"\n",
"\n",
"x = mx.random.normal(shape=(4,))\n",
"\n",
"out = gelu(x)\n",
"compiled_out = compiled_gelu(x)\n",
"print(f\"gelu: {out}\")\n",
"print(f\"compiled gelu: {compiled_out}\")"
]
},
{
"cell_type": "markdown",
"id": "3ead2025",
"metadata": {},
"source": [
"### MLX Fast package"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "efe967cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMS norm: array([-1.47364, 0.545241, -0.767421, 0.140266], dtype=float32)\n",
"Fast RMS norm: array([-1.47364, 0.545241, -0.767421, 0.140266], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"\n",
"def rms_norm(x, weight, eps=1e-5):\n",
"    \"\"\"Divide x by the root of its mean square over the last axis, then scale by weight.\n",
"\n",
"    x is cast to float32 before normalizing; the result is cast back to x.dtype.\n",
"    eps is added to the mean square before the reciprocal square root.\n",
"    \"\"\"\n",
"    x_f32 = x.astype(mx.float32)\n",
"    mean_sq = mx.mean(mx.square(x_f32), axis=-1, keepdims=True)\n",
"    normed = x_f32 * mx.rsqrt(mean_sq + eps)\n",
"    return (weight * normed).astype(x.dtype)\n",
"\n",
"\n",
"feature_dim = 4\n",
"\n",
"x = mx.random.normal((feature_dim,))\n",
"weight = mx.random.normal((feature_dim,))\n",
"\n",
"# Hand-written version next to the built-in from mx.fast, on the same inputs\n",
"y = rms_norm(x, weight, eps=1e-5)\n",
"y_fast = mx.fast.rms_norm(x, weight, eps=1e-5)\n",
"\n",
"print(f\"RMS norm: {y}\")\n",
"print(f\"Fast RMS norm: {y_fast}\")"
]
},
{
"cell_type": "markdown",
"id": "5c8d147a",
"metadata": {},
"source": [
"### Custom Metal kernels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9529b127",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([2.71828, 7.38906, 20.0855], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Kernel body: each thread writes exp(inp[elem]) to out[elem]\n",
"source = \"\"\"\n",
" uint elem = thread_position_in_grid.x;\n",
" out[elem] = metal::exp(inp[elem]);\n",
"\"\"\"\n",
"\n",
"# Build the kernel, naming the input and output buffers used in the source\n",
"kernel = mx.fast.metal_kernel(\n",
"    name=\"myexp\",\n",
"    input_names=[\"inp\"],\n",
"    output_names=[\"out\"],\n",
"    source=source,\n",
")\n",
"\n",
"# Run the kernel on a sample input; the grid has one entry per element of x\n",
"x = mx.array([1.0, 2.0, 3.0])\n",
"outputs = kernel(\n",
"    inputs=[x],\n",
"    grid=(x.size, 1, 1),\n",
"    threadgroup=(256, 1, 1),\n",
"    output_shapes=[x.shape],\n",
"    output_dtypes=[x.dtype],\n",
")\n",
"\n",
"# The call returns a sequence of outputs; this kernel declares exactly one\n",
"out = outputs[0]\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"id": "93d536ac",
"metadata": {},
"source": [
"### Quantization"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f04fe5fc",
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"\n",
"x = mx.random.normal([1024])\n",
"weight = mx.random.normal([1024, 1024])\n",
"\n",
"# The same settings must be passed to quantize, quantized_matmul and dequantize\n",
"quant_args = dict(bits=4, group_size=32)\n",
"\n",
"# Quantize the dense weight; returns the packed weight plus scales and biases\n",
"quantized_weight, scales, biases = mx.quantize(weight, **quant_args)\n",
"\n",
"# Multiply x by the quantized weight directly\n",
"y = mx.quantized_matmul(\n",
"    x,\n",
"    quantized_weight,\n",
"    scales=scales,\n",
"    biases=biases,\n",
"    **quant_args,\n",
")\n",
"\n",
"# Rebuild a dense weight from the quantized representation\n",
"w_orig = mx.dequantize(\n",
"    quantized_weight,\n",
"    scales=scales,\n",
"    biases=biases,\n",
"    **quant_args,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "096d593a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (layers.0): Embedding(100, 32)\n",
" (layers.1): Linear(input_dims=32, output_dims=32, bias=True)\n",
" (layers.2): Linear(input_dims=32, output_dims=32, bias=True)\n",
" (layers.3): Linear(input_dims=32, output_dims=1, bias=True)\n",
")\n",
"Sequential(\n",
" (layers.0): QuantizedEmbedding(100, 32, group_size=32, bits=4)\n",
" (layers.1): QuantizedLinear(input_dims=32, output_dims=32, bias=True, group_size=32, bits=4)\n",
" (layers.2): QuantizedLinear(input_dims=32, output_dims=32, bias=True, group_size=32, bits=4)\n",
" (layers.3): QuantizedLinear(input_dims=32, output_dims=1, bias=True, group_size=32, bits=4)\n",
")\n"
]
}
],
"source": [
"import mlx.nn as nn\n",
"\n",
"# An embedding followed by three linear layers\n",
"layers = [\n",
"    nn.Embedding(100, 32),\n",
"    nn.Linear(32, 32),\n",
"    nn.Linear(32, 32),\n",
"    nn.Linear(32, 1),\n",
"]\n",
"model = nn.Sequential(*layers)\n",
"\n",
"# Layer listing before quantization\n",
"print(model)\n",
"\n",
"# nn.quantize modifies the model in place; the second print shows the swapped layers\n",
"nn.quantize(model, group_size=32, bits=4)\n",
"\n",
"# Layer listing after quantization\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3b92700c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([1], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Initialize the distributed backend and get the communication group\n",
"group = mx.distributed.init()\n",
"\n",
"# Query this process's rank and the group size\n",
"rank = group.rank()\n",
"world_size = group.size()\n",
"\n",
"x = mx.array([1.0])\n",
"\n",
"# Presumably sums x across every process in the group - confirm against the MLX docs;\n",
"# the recorded output equals x\n",
"x_sum = mx.distributed.all_sum(x)\n",
"\n",
"print(x_sum)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mlx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}