Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: `9acec364c2...v0.27.1` (4 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 4ad53414dd |  |
|  | d1165b215e |  |
|  | dcb8319f3d |  |
|  | 5597fa089c |  |
README.md (21 lines changed)
````diff
@@ -11,10 +11,10 @@ brought to you by Apple machine learning research.
 
 Some key features of MLX include:
 
 - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
   also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
   [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
   the Python API. MLX has higher-level packages like `mlx.nn` and
   `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
   more complex models.
 
@@ -68,18 +68,23 @@ in the documentation.
 
 ## Installation
 
-MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
+macOS, run:
 
-**With `pip`**:
-
-```
+```bash
 pip install mlx
 ```
 
-**With `conda`**:
+To install the CUDA backend on Linux, run:
 
-```
-conda install -c conda-forge mlx
+```bash
+pip install "mlx[cuda]"
+```
+
+To install a CPU-only Linux package, run:
+
+```bash
+pip install "mlx[cpu]"
 ```
 
 Checkout the
````
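A quick way to confirm an install worked (a minimal sketch, not part of this diff; it assumes any of the three packages above installed cleanly):

```python
# Smoke test after `pip install mlx` (or "mlx[cuda]" / "mlx[cpu]").
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
print(mx.sum(a))       # array(6, dtype=float32)
print(mx.__version__)  # should report 0.27.1 for this release
```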
The documentation's install page gets matching updates:

```diff
@@ -13,7 +13,7 @@ silicon computer is
 
     pip install mlx
 
-To install from PyPI you must meet the following requirements:
+To install from PyPI your system must meet the following requirements:
 
 - Using an M series chip (Apple silicon)
 - Using a native Python >= 3.9
@@ -26,13 +26,22 @@ To install from PyPI you must meet the following requirements:
 CUDA
 ^^^^
 
-MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
-and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
+MLX has a CUDA backend which you can install with:
 
 .. code-block:: shell
 
     pip install "mlx[cuda]"
 
+To install the CUDA package from PyPi your system must meet the following
+requirements:
+
+- Nvidia architecture >= SM 7.0 (Volta)
+- Nvidia driver >= 550.54.14
+- CUDA toolkit >= 12.0
+- Linux distribution with glibc >= 2.35
+- Python >= 3.9
+
+
 
 CPU-only (Linux)
 ^^^^^^^^^^^^^^^^
@@ -42,6 +51,13 @@ For a CPU-only version of MLX that runs on Linux use:
 
     pip install "mlx[cpu]"
 
+To install the CPU-only package from PyPi your system must meet the following
+requirements:
+
+- Linux distribution with glibc >= 2.35
+- Python >= 3.9
+
+
 Troubleshooting
 ^^^^^^^^^^^^^^^
 
```
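The new Linux requirements can be sanity-checked before installing; a standard-library-only sketch, with the thresholds taken from the diff above:

```python
# Check the documented Linux wheel requirements: Python >= 3.9, glibc >= 2.35.
import platform
import sys

assert sys.version_info >= (3, 9), "MLX wheels require Python >= 3.9"

libc, ver = platform.libc_ver()
if libc == "glibc":
    major, minor = (int(p) for p in ver.split(".")[:2])
    assert (major, minor) >= (2, 35), "Linux MLX wheels require glibc >= 2.35"
```

The driver and GPU-architecture requirements are left out of the sketch; `nvidia-smi` reports both.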
In the Metal backend, `qvm_split_k` learns to handle 1D inputs:

```diff
@@ -265,9 +265,15 @@ void qvm_split_k(
   MTL::Size group_dims = MTL::Size(bk, 2, 1);
   MTL::Size grid_dims = MTL::Size(M, N / bn, B);
 
-  int x_batch_ndims = x.ndim() - 2;
   auto x_shape = x.shape();
   auto x_strides = x.strides();
+  if (x_shape.size() == 1) {
+    x_shape.insert(x_shape.begin(), 1);
+    x_strides.insert(x_strides.begin(), 0);
+  }
+
+  int x_ndim = x_shape.size();
+  int x_batch_ndims = x_ndim - 2;
   int w_batch_ndims = w.ndim() - 2;
   auto w_shape = w.shape();
   auto w_strides = w.strides();
@@ -278,7 +284,7 @@ void qvm_split_k(
   x_shape.insert(x_shape.end() - 2, split_k);
   x_shape.back() /= split_k;
   x_strides.insert(x_strides.end() - 2, split_D);
-  x_strides[x.ndim() - 1] = split_D;
+  x_strides[x_ndim - 1] = split_D;
   x_batch_ndims += 1;
 
   w_shape.insert(w_shape.end() - 2, split_k);
@@ -291,6 +297,9 @@ void qvm_split_k(
   int final_block_size = K - (split_k - 1) * split_D;
 
   auto temp_shape = out.shape();
+  if (temp_shape.size() == 1) {
+    temp_shape.insert(temp_shape.begin(), 1);
+  }
   temp_shape.insert(temp_shape.end() - 2, split_k);
   array intermediate(temp_shape, x.dtype(), nullptr, {});
   intermediate.set_data(allocator::malloc(intermediate.nbytes()));
```
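The promotion is the heart of the fix: a 1D `x` of shape `(K,)` is viewed as `(1, K)` by inserting a leading unit dimension with stride 0, so the split-k bookkeeping (which assumes at least two axes) works unchanged, and `x_ndim` is cached from the promoted shape instead of re-reading `x.ndim()`. A minimal Python sketch of that promotion, using plain lists rather than MLX arrays:

```python
# Sketch of the 1D -> 2D promotion in qvm_split_k above: shape (K,) becomes
# a (1, K) view; the new leading axis gets stride 0 since it is never stepped.
def promote_to_2d(shape, strides):
    if len(shape) == 1:
        shape = [1] + shape
        strides = [0] + strides
    return shape, strides

shape, strides = promote_to_2d([2048], [1])
assert (shape, strides) == ([1, 2048], [0, 1])

# Batch dims are all axes except the trailing matrix dims, hence ndim - 2.
x_batch_ndims = len(shape) - 2  # 0 for the promoted vector
```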
The version header is bumped from v0.26.5 to v0.27.1:

```diff
@@ -3,8 +3,8 @@
 #pragma once
 
 #define MLX_VERSION_MAJOR 0
-#define MLX_VERSION_MINOR 26
-#define MLX_VERSION_PATCH 5
+#define MLX_VERSION_MINOR 27
+#define MLX_VERSION_PATCH 1
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
 
```
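`MLX_VERSION_NUMERIC` packs the three components into one comparable integer; for this release the arithmetic works out as follows (a worked example, not repo code):

```python
# 100000 * major + 1000 * minor + patch, for v0.27.1:
major, minor, patch = 0, 27, 1
assert 100000 * major + 1000 * minor + patch == 27001
```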
The CUDA wheel-repair script excludes one more shared library:

```diff
@@ -5,6 +5,7 @@ auditwheel repair dist/* \
   --exclude libcublas* \
   --exclude libnvrtc* \
   --exclude libcuda* \
+  --exclude libcudnn* \
   -w wheel_tmp
 
 
```
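The added `--exclude libcudnn*` keeps the cuDNN shared library out of the repaired CUDA wheel, alongside the already-excluded cuBLAS, NVRTC, and CUDA driver libraries; at runtime those are supplied by the `nvidia-*-cu12` pip dependencies pinned in `setup.py` (see the setup.py hunk below).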
The CPU wheel-repair script pins the platform tag:

```diff
@@ -2,6 +2,7 @@
 
 auditwheel repair dist/* \
   --plat manylinux_2_35_x86_64 \
+  --only-plat \
   --exclude libmlx* \
   -w wheel_tmp
```
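`--only-plat` tells `auditwheel` not to retag the wheel with a different manylinux policy it happens to satisfy, so the package ships exactly as `manylinux_2_35_x86_64`, consistent with the glibc >= 2.35 requirement documented above.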
A regression test covers the 1D-vector case in `quantized_matmul`:

```diff
@@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
         self.assertEqual(y_q.shape, y_hat.shape)
         self.assertLess((y_q - y_hat).abs().max(), 2e-3)
 
+        # Test with 1D vector
+        group_size = 32
+        bits = 8
+        N = 2048
+        x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
+        w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 2e-3)
+
     def test_throw(self):
         x = mx.random.normal(shape=(10, 512))
         w = mx.random.normal(shape=(32, 512))
```
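With `group_size=32` and `bits=8`, each run of 32 consecutive weights shares one scale and one bias, and every weight is stored in 8 bits. A standalone illustration of the round trip those parameters imply (an informal sketch, not part of the test suite):

```python
# Quantize, dequantize, and inspect the reconstruction error.
import mlx.core as mx

w = mx.random.normal(shape=(64, 64))
w_q, scales, biases = mx.quantize(w, group_size=32, bits=8)
w_hat = mx.dequantize(w_q, scales, biases, group_size=32, bits=8)
print((w - w_hat).abs().max())  # small quantization error
```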
setup.py (2 lines changed)
```diff
@@ -289,7 +289,7 @@ if __name__ == "__main__":
         install_requires += [
             "nvidia-cublas-cu12==12.9.*",
             "nvidia-cuda-nvrtc-cu12==12.9.*",
-            "nvidia-cudnn-cu12==12.9.*",
+            "nvidia-cudnn-cu12==9.*",
         ]
     else:
         name = "mlx-cpu"
```
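The cuDNN pin moves from `12.9.*` to `9.*`: unlike cuBLAS and NVRTC, the `nvidia-cudnn-cu12` wheel follows cuDNN's own release series (9.x) rather than the CUDA toolkit version, so a `12.9.*` pin would not match any published cuDNN release.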