mlx-examples/wwdc25/Get_started_with_MLX_for_Apple_silicon.ipynb

686 lines
103 KiB
Plaintext
Raw Permalink Normal View History

2025-06-11 01:23:25 +08:00
{
"cells": [
{
"cell_type": "markdown",
"id": "3f964709",
"metadata": {},
"source": [
"# Get started with MLX for Apple silicon"
]
},
{
"cell_type": "markdown",
"id": "b11ab931",
"metadata": {},
"source": [
"### Basics operations"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0e068977",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Result c: array([5, 7, 9], dtype=int32)\n",
"Shape: (3,)\n",
"Data type: mlx.core.int32\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Make an array\n",
"a = mx.array([1, 2, 3])\n",
"\n",
"# Make another array\n",
"b = mx.array([4, 5, 6])\n",
"\n",
"# Do an operation\n",
"c = a + b\n",
"\n",
"# Access information about the array\n",
"shape = c.shape\n",
"dtype = c.dtype\n",
"\n",
"print(f\"Result c: {c}\")\n",
"print(f\"Shape: {shape}\")\n",
"print(f\"Data type: {dtype}\")"
]
},
{
"cell_type": "markdown",
"id": "852c80f9",
"metadata": {},
"source": [
"### Unified Memory"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2c344ed4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c computed on the GPU: array([5, 7, 9], dtype=int32)\n",
"d computed on the CPU: array([4, 10, 18], dtype=int32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"a = mx.array([1, 2, 3])\n",
"b = mx.array([4, 5, 6])\n",
"\n",
"c = mx.add(a, b, stream=mx.gpu)\n",
"d = mx.multiply(a, b, stream=mx.cpu)\n",
"\n",
"print(f\"c computed on the GPU: {c}\")\n",
"print(f\"d computed on the CPU: {d}\")"
]
},
{
"cell_type": "markdown",
"id": "b1c809aa",
"metadata": {},
"source": [
"### Lazy computation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "83cb860d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([5, 7, 9], dtype=int32)\n",
"Evaluate c by converting to list: [5, 7, 9]\n",
"Evaluate c using print: array([5, 7, 9], dtype=int32)\n",
"Evaluate c using mx.eval(): array([5, 7, 9], dtype=int32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Make an array\n",
"a = mx.array([1, 2, 3])\n",
"\n",
"# Make another array\n",
"b = mx.array([4, 5, 6])\n",
"\n",
"# Do an operation\n",
"c = a + b\n",
"\n",
"# Evaluates c before printing it\n",
"print(c)\n",
"\n",
"# Also evaluates c\n",
"c_list = c.tolist()\n",
"\n",
"# Also evaluates c\n",
"mx.eval(c)\n",
"\n",
"print(f\"Evaluate c by converting to list: {c_list}\")\n",
"print(f\"Evaluate c using print: {c}\")\n",
"print(f\"Evaluate c using mx.eval(): {c}\")"
]
},
{
"cell_type": "markdown",
"id": "dc742c8a",
"metadata": {},
"source": [
"### Function transformation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f3e5fde1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(-0.841471, dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"def sin(x):\n",
" return mx.sin(x)\n",
"\n",
"dfdx = mx.grad(sin)\n",
"\n",
"def sin(x):\n",
" return mx.sin(x)\n",
"\n",
"d2fdx2 = mx.grad(mx.grad(mx.sin))\n",
"\n",
"# Computes the second derivative of sine at 1.0\n",
"d2fdx2(mx.array(1.0))"
]
},
{
"cell_type": "markdown",
"id": "c4850f6d",
"metadata": {},
"source": [
"#### Visualizing `sin`, `grad(sin)`, `grad(grad(sin))`\n",
"Plot should show `sin(x)`, `cos(x)` and `-sin(x)`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6cb0c908",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2IAAAINCAYAAABcesypAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Qd8U9UXB/Bfk+5JW2jLKLQUKHuVvfdGhiJDQJEhCCJ7CaKCIjJlIwqIKEP4g8jee+9daMvuHnTPpP/PuaGlhRYSbPvykvP1E/Pykryc8jLeeffec03S09PTwRhjjDHGGGOswCgK7qUYY4wxxhhjjBFOxBhjjDHGGGOsgHEixhhjjDHGGGMFjBMxxhhjjDHGGCtgnIgxxhhjjDHGWAHjRIwxxhhjjDHGChgnYowxxhhjjDFWwDgRY4wxxhhjjLECZlrQL2iI1Go1AgMDYWdnBxMTE6nDYYwxxhhjjEkkPT0dsbGxKFasGBSK3Nu9OBHLA5SEubu7Sx0GY4wxxhhjTE88efIEJUqUyPV+TsTyALWEZfxj29vbSxpLamoq9u/fjzZt2sDMzEzSWJjueP/JG+8/eeP9J2+8/+SP96G88f57KSYmRjTSZOQIueFELA9kdEekJEwfEjFra2sRh7F/COSI95+88f6TN95/8sb7T/54H8ob77/XvW3IEhfrYIwxxhhjjLECxokYY4wxxhhjjBUwTsQYY4wxxhhjrIDxGDHGGGNal+NNS0uDSqWCoY5vMDU1RVJSksH+jVJTKpXi35inemGMMU7EGGOMaSElJQVBQUFISEiAISeabm5uogIuJwr5hwbzFy1aFObm5lKHwhhjkuJEjDHG2FsnrX/w4IFozaDJKekA2hATFfo74+LiYGtr+8YJONm7J7qU0IeFhYn3U9myZfnfmTFm1DgRY4wx9kZ08ExJCs2JQq0Zhor+RvpbLS0tOUHIJ1ZWVqKs9aNHjzL/rRljzFjxLw1jjDGtcHLC8gK/jxhjTIO/DRljjDHGGGOsgHEixhhjjDHGGGMFjBMxxhhjRueTTz5B165ddX7eoUOHUKFCBa3L29++fRslSpRAfHz8O0TJGGPMkHGxDsYYYwXiQXg8Nl98gqdRiSjhaIUPa7nDs7CNJLH8/PPPooqfriZMmICpU6eKCpLaqFixIurVq4f58+dj2rRp7xApY4wxQ8WJGGOMsXxHCdikrddF2XtKgOh65TF/zH6/KnrUci/weBwcHHR+zsmTJ+Hv74/3339fp+cNGDAAgwcPxuTJk8VkxowxxpjsuiYeP34cnTt3FvPY0I/49u3b3/qco0ePombNmrCwsECZMmWwdu3a1x6zdOlSeHh4iDK6devWxfnz5/PpL2CMMcNAyVRCSppWl9tB0SIJU6cDKnV6tuuJW6/jTlC01tvStRVry5YtqFKliiib7uzsjFatWolugq92TWzWrBm+/PJLfP311yhcuLCY2Pmbb77Jtq2NGzeidevWmSXXKRbaXtu2bTPjioyMFF0RaTsZ6Dm0/tixY//xX50xxpghkdWpOfrxrFatGj799FN07979rY+nCSM7duyIoUOH4s8//xR9+wcNGoSiRYuKH06yadMmjBkzBitWrBBJ2MKFC8V9vr6+cHFxKYC/ijGmb93W2NslpqpQ8et9/3k7lIy1//mk1o+//V1bWJtr99MVFBSE3r1746effkK3bt0QGxuLEydO5JrMrVu3Dp9//jnOnDmDc+fOiWStYcOGIpEi9Nw+ffpkPp5OCP7+++8i0Vu0aJFI5Oj3pnjx4tkSMZoAu3r16uL5LVu21PpvZYwxZthklYi1b99eXLRFyZWnpyfmzZsnbtMAa+pasmDBgsxEjPrtU5cR6jqS8Zxdu3Zh9erVmDRpEmTjyCxAoQQajH657thPgFoFNJ8sZWSMya7bGjMMlIilpaWJE3elSpUS6yhpyk3VqlUxceJE2Nvbw9vbG0uWLBEn8DISMZqEmHpkZEVJ18qVK9G/f38EBwdj9+7duHLlymtdEOl59HzG9N2ZwDP48fyPmFRnEuoXqy91OIwZNFklYrqis5rUbSQrSsBGjRolllNSUnDp0iXRbz/rRJP0HHqurFASduR7KFKT4RadBsU/O4GbmwGfAUDUI8CmCGBuLXWUjAmUbMWnqBAZl4KbgdGie5popMhoqXhxPWHrdbjZW6JGKUfYWhj015XsWJkpReuUNuYfuIc1Jx9ClUNLlNLEBAMaeWBM63Jav662qAcFtUBR8kXf/W3atMEHH3wAR0fHHB//apJGvSdCQ0MzbycmJmZ2S8yqR48e2LZtG3788UcsX74cZcuWfT1uKyskJCRoHTtj+SlFlYKIxAhEJUchKS1JXALjA8V385b7WxAQHYCFlxeKx7lau6KcUzkoTGQ1moUZMlUqEB8OxIcBKXGAkxdg5wo5MugjGzo76eqafcfQ7ZiYGPGDGhUVJUoQ5/SYu3fv5rrd5ORkcclA2yOpqaniIokGo6FQqaA8/iPqZl1/aY3mQse25jZQV/sI6jY/vLw/5CZQuBygNC/wkNnrMt4/kr2P8lhyqgp3Q+JwKzAGAeHxeBSRgMeRCXj2PAnJaeq3Pp+O2/ut1ozZtDZXooitBTycreFZ2Bqli9ignIstKha1h5W59gfn+cnQ9l8G+nvoAE2tVotLBktT7Q7Metd2x+qTD3K8Lx3p6FPHXettURzajhOjltV9+/bh9OnTOHDgABYvXoyvvvpKnGjL2E7Wv8fMzCzzNTLW029ExjKNHYuIiMj2HEIJFp3Uo0qK9+7de+1+Qs/z8vLK8T5jQ/8G9G9M7yttq08a8+fvXanUKjyMeQjfKF88in0klukSmhCK6JTotz7/dsRtjDg8QiyvarUKPi4+Yvl6+HVcDr0Mb0dvVHCqgEIWhfIsZt6H8pZf+8/k2WUoLv4CRPjDJPoxTBIist2f1mkR0qu97DauD7T9NzDoRCy/zJo1C99+++1r6/fv3w9raylbnSqis4kSinQV6DAlzqIoTNXJME+LhTI9FSYp8fB/HIQ7u3eLR1ukPke7myOhMjFDjJU7Im3KIMyuEiJsyyNNaSXh38HooFGO4lMBvxgT3I82QUCsCYISaQyQSa6PNzNJB0yAVHFsmtPj0kVFITVMkJCiwqPIBHE5dv/lIxRIRzEboKRtOsrap6OcQzpsNcfTkpHr/ssNdbOj4hVxcXGiJ4GunM2B6e3L4Ns9fmIv0/dTxjWtdzJTZZ7Qyg/U0kUXGsNF3Q+p6Ab9SFK3xYzXpeWMv43GkmWso8dlPKZy5cq4du3aa7HSOGPy999/48MPPxSFP5o0aZLtMTdu3BBjlvPz75QL+nemk6FUgIv+jfOaoX3+tJWWngYF/fei5Wp34m6cTj6d6+OVUMLaxBrmJuYwgxlS0lMQmR6Z42NvnLuBEGWIWD6UeAhHko9k3lfIpBDcTd1R2rQ0PE094axwFidB/gtj3YeG4l32n0m6CoUSHqBI7C04x/nCv0hbhDpUE/e5xFxHff8t2R6fDhMkm9qJ49U7N+8i8Jnm2FZfaNsDwqATMTpwCAnRfHFkoNvU/5+6idCZOLrk9Bh6bm6oK2PGDy+hH1Z3d3fR7YW2LRXFibkiCVOZmEKZngbruh9D3Xgc1HR2l5pu40PhaWYNT7ui4vEmTy8g3a8QlEnP4ZgQIC5eYfuRbqJEenEfqOsOR3r5jpL9PcaIDvroC4zGpGScnddndFbbNyQOB+6E4tDdUNwOis3sXZjBycYMlYvZo6yLLUo5W6OkkzXcHa1Q2NZcdDObd8APv76h29qgRp4Y1tQT4XEpCI5JwoPwBFHYwz88HneCYhEam4yn8cDTeBOcDqFWEKBSUXs0LuuMthVdUbGo3X8+KDDU/aetpKQkPHnyBLa2tjl2zdNGv0b2aFyhGDZffJqlIEsJeDjnX0EWKrhx+PBhsT+o+BLdDg8PF4UzqOWKEsyM72xapqIaxM5O856hdbQfMx5DiRQV9Mj6PU9jiqkY1KlTp0SF3nHjxmH48OG4evVqZhfIhw8fivFqnTp1kvQ3Qp/eT/QbTMnqu76fjOnzp40
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_waves(x, fn1, fn2, fn3, labels):\n",
"\n",
" # Generate y values for sin(x) and dfdx(sin(x))\n",
" y_1 = fn1(x)\n",
" y_2 = fn2(x)\n",
" y_3 = fn3(x)\n",
"\n",
" # Create the plot\n",
" plt.figure(figsize=(10, 6))\n",
" # Note: x is already in degrees here for plotting\n",
" plt.plot(x, y_1, label=labels[0], marker='o', linestyle='-', markersize=5, markevery=20)\n",
" plt.plot(x, y_2, label=labels[1], marker='x', linestyle='--', markersize=5, markevery=20)\n",
" plt.plot(x, y_3, label=labels[2], marker='*', linestyle='-.', markersize=5, markevery=20)\n",
"\n",
" # Add labels\n",
" plt.xlabel('X-axis (Degrees)') # Changed label here\n",
" plt.ylabel('Y-axis')\n",
"\n",
" plt.legend()\n",
" plt.grid(True)\n",
" plt.show()\n",
"\n",
"\n",
"x = mx.linspace(0, 2 * mx.pi, 400)\n",
"\n",
"cos = mx.vmap(dfdx)\n",
"negative_sin = mx.vmap(d2fdx2)\n",
"\n",
"plot_waves(x, sin, cos, negative_sin, [\"sin(x)\",\"dfdx(sin(x))\", \"d2fdx2(sin(x))\"])"
]
},
{
"cell_type": "markdown",
"id": "bbc810b7",
"metadata": {},
"source": [
"### Neural Networks in MLX and Pytorch"
]
},
{
"cell_type": "markdown",
"id": "be4d8dd4",
"metadata": {},
"source": [
"MLX Neural Network"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fd53cc03",
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"import mlx.nn as nn\n",
"import mlx.optimizers as optim\n",
"\n",
"class MLP(nn.Module):\n",
" \"\"\"A simple MLP.\"\"\"\n",
"\n",
" def __init__(self, dim, h_dim):\n",
" super().__init__()\n",
" self.linear1 = nn.Linear(dim, h_dim)\n",
" self.linear2 = nn.Linear(h_dim, dim)\n",
"\n",
" def __call__(self, x):\n",
" x = self.linear1(x)\n",
" x = nn.relu(x)\n",
" x = self.linear2(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "ace491fb",
"metadata": {},
"source": [
"MLX Training loop"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "80f3d568",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Loss after 5 steps: 1.0052\n"
]
}
],
"source": [
"n_epochs = 5\n",
"input_dim, hidden_dim, num_samples = 10, 50, 1000\n",
"\n",
"model = MLP(input_dim, hidden_dim)\n",
"\n",
"def loss_fn(model, X, y):\n",
" return nn.losses.mse_loss(model(X), y)\n",
"\n",
"loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n",
"optimizer = optim.Adam(learning_rate=0.01)\n",
"\n",
"X_train = mx.random.normal([num_samples, input_dim])\n",
"y_train = mx.random.normal([num_samples, input_dim])\n",
"\n",
"for epoch in range(n_epochs):\n",
" loss, grads = loss_and_grad_fn(model, X_train, y_train)\n",
" model.update(optimizer.apply_gradients(grads, model))\n",
" mx.eval(model.parameters(), optimizer.state) \n",
"\n",
"print(f\"Final Loss after 5 steps: {loss.item():.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "5fa466a7",
"metadata": {},
"source": [
"PyTorch Neural Network"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1c39f647",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"class MLP(nn.Module):\n",
" \"\"\"A simple MLP.\"\"\"\n",
"\n",
" def __init__(self, dim, h_dim):\n",
" super().__init__()\n",
" self.linear1 = nn.Linear(dim, h_dim)\n",
" self.linear2 = nn.Linear(h_dim, dim)\n",
"\n",
" def forward(self, x):\n",
" x = self.linear1(x)\n",
" x = x.relu()\n",
" x = self.linear2(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "1e568017",
"metadata": {},
"source": [
"PyTorch Training loop"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0d1b3dc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Final Loss after 5 steps: 1.0028\n"
]
}
],
"source": [
"n_epochs = 5\n",
"input_dim, hidden_dim, num_samples = 10, 50, 1000\n",
"model = MLP(input_dim, hidden_dim)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
"\n",
"X_train = torch.randn([num_samples, input_dim])\n",
"y_train = torch.randn([num_samples, input_dim])\n",
"\n",
"for epoch in range(n_epochs):\n",
" outputs = model(X_train)\n",
" loss = criterion(outputs, y_train)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"print(f\"Final Loss after 5 steps: {loss.item():.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "b9ace438",
"metadata": {},
"source": [
"### Compiling MLX functions"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e1a6d2f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gelu: array([-0.169571, -0.0094711, 0.120888, -0.122945], dtype=float32)\n",
"compiled gelu: array([-0.169571, -0.0094711, 0.120888, -0.122945], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"import math\n",
"\n",
"def gelu(x):\n",
" return x * (1 + mx.erf(x / math.sqrt(2))) / 2\n",
"\n",
"@mx.compile\n",
"def compiled_gelu(x):\n",
" return x * (1 + mx.erf(x / math.sqrt(2))) / 2\n",
"\n",
"x = mx.random.normal(shape=(4,))\n",
"\n",
"out = gelu(x)\n",
"compiled_out = compiled_gelu(x)\n",
"print(f\"gelu: {out}\")\n",
"print(f\"compiled gelu: {compiled_out}\")"
]
},
{
"cell_type": "markdown",
"id": "3ead2025",
"metadata": {},
"source": [
"### MLX Fast package"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "efe967cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMS norm: array([-1.47364, 0.545241, -0.767421, 0.140266], dtype=float32)\n",
"Fast RMS norm: array([-1.47364, 0.545241, -0.767421, 0.140266], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"def rms_norm(x, weight, eps=1e-5):\n",
" y = x.astype(mx.float32)\n",
" y = y * mx.rsqrt(mx.mean(\n",
" mx.square(y),\n",
" axis=-1,\n",
" keepdims=True,\n",
" ) + eps)\n",
" return (weight * y).astype(x.dtype)\n",
"\n",
"feature_dim = 4\n",
"\n",
"x = mx.random.normal((feature_dim,))\n",
"weight = mx.random.normal((feature_dim,))\n",
"\n",
"y = rms_norm(x, weight, eps=1e-5)\n",
"y_fast = mx.fast.rms_norm(x, weight, eps=1e-5)\n",
"\n",
"print(f\"RMS norm: {y}\")\n",
"print(f\"Fast RMS norm: {y_fast}\")"
]
},
{
"cell_type": "markdown",
"id": "5c8d147a",
"metadata": {},
"source": [
"### Custom Metal kernels"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9529b127",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([2.71828, 7.38906, 20.0855], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"# Build the kernel\n",
"source = \"\"\"\n",
" uint elem = thread_position_in_grid.x;\n",
" out[elem] = metal::exp(inp[elem]);\n",
"\"\"\"\n",
"kernel = mx.fast.metal_kernel(\n",
" name=\"myexp\",\n",
" input_names=[\"inp\"],\n",
" output_names=[\"out\"],\n",
" source=source,\n",
")\n",
"\n",
"# Call the kernel on a sample input\n",
"x = mx.array([1.0, 2.0, 3.0])\n",
"out = kernel(\n",
" inputs=[x],\n",
" grid=(x.size, 1, 1),\n",
" threadgroup=(256, 1, 1),\n",
" output_shapes=[x.shape],\n",
" output_dtypes=[x.dtype],\n",
")[0]\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"id": "93d536ac",
"metadata": {},
"source": [
"### Quantization"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f04fe5fc",
"metadata": {},
"outputs": [],
"source": [
"import mlx.core as mx\n",
"\n",
"x = mx.random.normal([1024])\n",
"weight = mx.random.normal([1024, 1024])\n",
"\n",
"quantized_weight, scales, biases = mx.quantize(\n",
" weight, bits=4, group_size=32,\n",
")\n",
"\n",
"y = mx.quantized_matmul(\n",
" x,\n",
" quantized_weight,\n",
" scales=scales,\n",
" biases=biases,\n",
" bits=4,\n",
" group_size=32,\n",
")\n",
"\n",
"w_orig = mx.dequantize(\n",
" quantized_weight,\n",
" scales=scales,\n",
" biases=biases,\n",
" bits=4,\n",
" group_size=32,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "096d593a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (layers.0): Embedding(100, 32)\n",
" (layers.1): Linear(input_dims=32, output_dims=32, bias=True)\n",
" (layers.2): Linear(input_dims=32, output_dims=32, bias=True)\n",
" (layers.3): Linear(input_dims=32, output_dims=1, bias=True)\n",
")\n",
"Sequential(\n",
" (layers.0): QuantizedEmbedding(100, 32, group_size=32, bits=4)\n",
" (layers.1): QuantizedLinear(input_dims=32, output_dims=32, bias=True, group_size=32, bits=4)\n",
" (layers.2): QuantizedLinear(input_dims=32, output_dims=32, bias=True, group_size=32, bits=4)\n",
" (layers.3): QuantizedLinear(input_dims=32, output_dims=1, bias=True, group_size=32, bits=4)\n",
")\n"
]
}
],
"source": [
"import mlx.nn as nn\n",
"\n",
"model = nn.Sequential(\n",
" nn.Embedding(100, 32),\n",
" nn.Linear(32, 32),\n",
" nn.Linear(32, 32),\n",
" nn.Linear(32, 1),\n",
")\n",
"\n",
"print(model)\n",
"\n",
"nn.quantize(\n",
" model,\n",
" bits=4,\n",
" group_size=32,\n",
")\n",
"\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3b92700c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"array([1], dtype=float32)\n"
]
}
],
"source": [
"import mlx.core as mx\n",
"\n",
"group = mx.distributed.init()\n",
"\n",
"world_size = group.size()\n",
"rank = group.rank()\n",
"\n",
"x = mx.array([1.0])\n",
"\n",
"x_sum = mx.distributed.all_sum(x)\n",
"\n",
"print(x_sum)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mlx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}