Compare commits

..

2 Commits

Author SHA1 Message Date
Awni Hannun
dcb8319f3d update install docs and requirements (#2419) 2025-07-25 12:13:19 -07:00
Awni Hannun
5597fa089c Fix qvm splitk (#2415) 2025-07-25 11:50:24 -07:00
5 changed files with 57 additions and 13 deletions

View File

@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
Some key features of MLX include:
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
the Python API. MLX has higher-level packages like `mlx.nn` and
the Python API. MLX has higher-level packages like `mlx.nn` and
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
more complex models.
@@ -68,18 +68,23 @@ in the documentation.
## Installation
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:
**With `pip`**:
```
```bash
pip install mlx
```
**With `conda`**:
To install the CUDA backend on Linux, run:
```bash
pip install "mlx[cuda]"
```
conda install -c conda-forge mlx
To install a CPU-only Linux package, run:
```bash
pip install "mlx[cpu]"
```
Checkout the

View File

@@ -13,7 +13,7 @@ silicon computer is
pip install mlx
To install from PyPI you must meet the following requirements:
To install from PyPI your system must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
@@ -26,13 +26,22 @@ To install from PyPI you must meet the following requirements:
CUDA
^^^^
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
MLX has a CUDA backend which you can install with:
.. code-block:: shell
pip install "mlx[cuda]"
To install the CUDA package from PyPi your system must meet the following
requirements:
- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9
CPU-only (Linux)
^^^^^^^^^^^^^^^^
@@ -42,6 +51,13 @@ For a CPU-only version of MLX that runs on Linux use:
pip install "mlx[cpu]"
To install the CPU-only package from PyPi your system must meet the following
requirements:
- Linux distribution with glibc >= 2.35
- Python >= 3.9
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -265,9 +265,15 @@ void qvm_split_k(
MTL::Size group_dims = MTL::Size(bk, 2, 1);
MTL::Size grid_dims = MTL::Size(M, N / bn, B);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape();
auto x_strides = x.strides();
if (x_shape.size() == 1) {
x_shape.insert(x_shape.begin(), 1);
x_strides.insert(x_strides.begin(), 0);
}
int x_ndim = x_shape.size();
int x_batch_ndims = x_ndim - 2;
int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape();
auto w_strides = w.strides();
@@ -278,7 +284,7 @@ void qvm_split_k(
x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x.ndim() - 1] = split_D;
x_strides[x_ndim - 1] = split_D;
x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k);
@@ -291,6 +297,9 @@ void qvm_split_k(
int final_block_size = K - (split_k - 1) * split_D;
auto temp_shape = out.shape();
if (temp_shape.size() == 1) {
temp_shape.insert(temp_shape.begin(), 1);
}
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));

View File

@@ -2,6 +2,7 @@
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--only-plat \
--exclude libmlx* \
-w wheel_tmp

View File

@@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
# Test with 1D vector
group_size = 32
bits = 8
N = 2048
x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
def test_throw(self):
x = mx.random.normal(shape=(10, 512))
w = mx.random.normal(shape=(32, 512))