Compare commits

..

8 Commits

Author SHA1 Message Date
dependabot[bot]
c2764d1073 Bump actions/download-artifact from 6 to 7 (#2912)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-15 06:10:16 -08:00
dependabot[bot]
093a62d2ed Bump actions/upload-artifact from 5 to 6 (#2911)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-15 06:09:55 -08:00
Awni Hannun
1b591ec736 No VJP for mask or sinks in attention (#2909)
2025-12-13 19:48:39 -08:00
Awni Hannun
47d2505ea9 Fix attention for large sizes (#2903)
2025-12-13 06:54:30 -08:00
Cheng
bedefed784 Fix ccache getting disabled (#2905)
2025-12-13 13:00:51 +09:00
Melissa Kilby
ccaaa7d6df fix: possible heap-buffer-overflow in RandomBits::eval_cpu (#2877)
2025-12-12 02:11:18 -08:00
Awni Hannun
f3e5ca5414 [CUDA] Add host nodes to subgraph types for graph update (#2901)
2025-12-11 19:13:44 -08:00
Awni Hannun
81dfe5f137 Fix grad in place updates (#2899)
2025-12-11 14:44:58 -08:00
12 changed files with 100 additions and 48 deletions

View File

@@ -10,23 +10,29 @@ inputs:
description: 'Version of python to set up'
required: false
default: '3.10'
use-ccache:
description: 'Whether to enable ccache'
required: false
default: 'true'
runs:
using: "composite"
steps:
- name: Use ccache
if: ${{ runner.arch == 'x86_64' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}-py${{ inputs.python-version }}
max-size: 1GB
- name: Install common dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev zip
- name: Use ccache
if: ${{ inputs.use-ccache == 'true' }}
uses: hendrikmuhs/ccache-action@v1.2
with:
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ inputs.toolkit }}
max-size: 1GB
# ccache-action bug: running "apt-get update" fails on large arm runner.
update-package-index: false
- uses: actions/setup-python@v6
with:
python-version: ${{ inputs.python-version }}

View File

@@ -23,14 +23,14 @@ jobs:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: "x86_64"
- name: Upload mlx artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: linux-wheels-${{ matrix.python_version }}
path: wheelhouse/mlx-*.whl
retention-days: 7
- name: Upload mlx-cpu artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: mlx-cpu
path: wheelhouse/mlx_cpu-*.whl
@@ -89,7 +89,7 @@ jobs:
with:
toolkit: 'cuda-12.9'
- name: Upload artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: mlx-cuda
path: wheelhouse/mlx_cuda-*.whl

View File

@@ -57,19 +57,20 @@ jobs:
- uses: ./.github/actions/setup-linux
with:
python-version: ${{ matrix.python_version }}
use-ccache: false
- uses: ./.github/actions/build-linux-release
with:
build-backend: ${{ matrix.python-version == '3.10' }}
arch: ${{ matrix.arch }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
overwrite: true
name: linux-wheels-${{ matrix.python_version }}-${{ matrix.arch }}
path: wheelhouse/mlx-*.whl
- name: Upload CPU artifacts
if: matrix.python_version == '3.10'
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
overwrite: true
name: mlx-cpu-${{ matrix.arch }}
@@ -113,14 +114,14 @@ jobs:
macos-target: 15.0
build-backend: ${{ matrix.python-version == '3.10' }}
- name: Upload MLX artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
overwrite: true
name: mac-wheels-${{ matrix.python-version }}
path: dist/mlx-*.whl
- name: Upload Metal artifacts
if: matrix.python-version == '3.10'
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
overwrite: true
name: mlx-metal
@@ -141,12 +142,13 @@ jobs:
- uses: ./.github/actions/setup-linux
with:
toolkit: ${{ matrix.toolkit }}
use-ccache: false
- name: Build Python package
uses: ./.github/actions/build-cuda-release
with:
arch: ${{ matrix.arch }}
- name: Upload artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
overwrite: true
name: mlx-cuda
@@ -162,12 +164,12 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx
steps:
- uses: actions/download-artifact@v6
- uses: actions/download-artifact@v7
with:
pattern: linux-wheels-*
merge-multiple: true
path: dist
- uses: actions/download-artifact@v6
- uses: actions/download-artifact@v7
with:
pattern: mac-wheels-*
merge-multiple: true
@@ -189,7 +191,7 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx-cuda
steps:
- uses: actions/download-artifact@v6
- uses: actions/download-artifact@v7
with:
name: mlx-cuda
path: dist
@@ -210,7 +212,7 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx-cpu
steps:
- uses: actions/download-artifact@v6
- uses: actions/download-artifact@v7
with:
pattern: mlx-cpu-*
merge-multiple: true
@@ -232,7 +234,7 @@ jobs:
name: pypi
url: https://pypi.org/p/mlx-metal
steps:
- uses: actions/download-artifact@v6
- uses: actions/download-artifact@v7
with:
name: mlx-metal
path: dist

View File

@@ -291,6 +291,17 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
num_keys,
kshape = keys.shape(),
kstrides = keys.strides()]() mutable {
auto copy_remaining = [&](char* cptr, size_t loc, uint32_t v) {
if (4 * loc + 4 <= bytes_per_key) {
reinterpret_cast<uint32_t*>(cptr)[loc] = v;
} else {
std::copy(
reinterpret_cast<char*>(&v),
reinterpret_cast<char*>(&v) + bytes_per_key - 4 * loc,
cptr + 4 * loc);
}
};
size_t out_skip = (bytes_per_key + 4 - 1) / 4;
auto half_size = out_skip / 2;
bool even = out_skip % 2 == 0;
@@ -310,18 +321,12 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
if (count.first < half_size) {
auto rb = random::threefry2x32_hash(key, count);
ptr[count.first++] = rb.first;
if (bytes_per_key % 4 > 0) {
std::copy(
reinterpret_cast<char*>(&rb.second),
reinterpret_cast<char*>(&rb.second) + bytes_per_key % 4,
cptr + 4 * count.second);
} else {
ptr[count.second] = rb.second;
}
copy_remaining(cptr, count.second, rb.second);
}
if (!even) {
count.second = 0;
ptr[half_size] = random::threefry2x32_hash(key, count).first;
copy_remaining(
cptr, half_size, random::threefry2x32_hash(key, count).first);
}
}
});
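The copy_remaining helper introduced above (ccaaa7d6df, #2877) bounds the final write so that a trailing partial 32-bit word never spills past bytes_per_key. A minimal Python sketch of the case that exercises this path, assuming mlx.core.random.bits with its usual width and key parameters; the shape and widths below are illustrative only:

import mlx.core as mx

# Element counts and widths whose product is not a multiple of 4 bytes hit the
# partial-word tail that copy_remaining now clamps (values chosen for illustration).
key = mx.random.key(0)
for width in (1, 2):
    bits = mx.random.bits((7,), width=width, key=key)
    print(width, bits.dtype, bits.shape)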

View File

@@ -348,6 +348,9 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
key += subkey;
break;
}
case cudaGraphNodeTypeHost:
key += "H";
break;
case cudaGraphNodeTypeMemset:
key += "M";
break;

View File

@@ -347,7 +347,7 @@ template <
MMAFrag_mask_t::load_safe(
mfrag,
mask,
int(mask_params->M_strides[2]),
int64_t(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -346,7 +346,7 @@ template <
MSubTile mfrag;
mfrag.load_safe(
mask,
int(mask_params->M_strides[2]),
int64_t(mask_params->M_strides[2]),
Int<1>{},
params->qL,
params->kL,

View File

@@ -105,17 +105,20 @@ struct BaseMMAFrag<T, 8, 8> {
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
src += off_x * str_x + off_y * str_y;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_y + j) * str_y]);
dst[i * kElemCols + j] = static_cast<T>(src[0]);
} else {
dst[i * kElemCols + j] = T(0);
}
src += str_y;
}
src -= kElemCols * str_y;
src += str_x;
}
}
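This hunk, together with the two int → int64_t stride casts above, is part of the attention fix for large sizes (47d2505ea9, #2903): the offset is now accumulated by stepping the pointer rather than recomputing (off_x + i) * str_x products, whose 32-bit intermediate can overflow once sequence lengths get large. A back-of-the-envelope illustration in Python, with made-up sizes:

# Illustration only: a signed 32-bit offset overflows once row * stride exceeds
# 2**31 - 1, which is reachable with long-context attention masks.
seq_len = 50_000            # hypothetical key length (columns of an [L, L] mask)
row_stride = seq_len        # elements between consecutive mask rows
row = 48_000                # hypothetical query index
offset = row * row_stride   # 2_400_000_000
print(offset > 2**31 - 1)   # True: would wrap around in 32-bit arithmetic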

View File

@@ -880,6 +880,11 @@ std::vector<array> ScaledDotProductAttention::vjp(
std::vector<array> returned_vjps;
for (int arg : argnums) {
if (arg >= 3) {
throw std::invalid_argument(
"[scale_dot_product_attention] Does not support VJP with respect "
" to mask or attention sinks.");
}
returned_vjps.push_back(std::move(vjps[arg]));
}
return returned_vjps;
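With this guard (1b591ec736, #2909), the fused attention primitive raises instead of returning an undefined cotangent when a gradient is requested for the mask or attention sinks; gradients with respect to the query, key, and value inputs are unaffected. A rough Python sketch of the behavior, assuming mx.fast.scaled_dot_product_attention with its usual (q, k, v, *, scale, mask) signature; shapes are illustrative:

import mlx.core as mx

q = mx.random.normal((1, 4, 8, 16))
k = mx.random.normal((1, 4, 8, 16))
v = mx.random.normal((1, 4, 8, 16))
mask = mx.zeros((8, 8))

def loss(m):
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=m)
    return out.sum()

# Differentiating with respect to q, k, or v still works; asking for the
# mask's gradient is expected to raise after this change:
# mx.grad(loss)(mask)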

View File

@@ -124,37 +124,53 @@ auto py_value_and_grad(
// Collect the arrays
std::vector<mx::array> arrays;
std::vector<nb::object> array_objects;
auto flatten_with_objects = [&arrays, &array_objects](
auto tree, bool strict) {
tree_visit(tree, [&](nb::handle obj) {
if (nb::isinstance<mx::array>(obj)) {
arrays.push_back(nb::cast<mx::array>(obj));
array_objects.push_back(nb::borrow<nb::object>(obj));
} else if (strict) {
throw std::invalid_argument(
"[tree_flatten] The argument should contain only arrays");
}
});
};
std::vector<int> counts(1, 0);
std::vector<int> gradient_indices;
for (int i = 0, j = 0; i < args.size(); ++i) {
bool needs_grad = (j < argnums.size() && argnums[j] == i);
auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
auto pre_size = arrays.size();
flatten_with_objects(args[i], /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
gradient_indices.resize(old_size + argsi.size());
auto delta_size = arrays.size() - pre_size;
gradient_indices.resize(old_size + delta_size);
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
arrays.size());
pre_size);
j++;
counts.push_back(argsi.size());
counts.push_back(delta_size);
}
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
}
for (auto item : kwargs) {
bool needs_grad =
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
auto pre_size = arrays.size();
flatten_with_objects(item.second, /* strict = */ needs_grad);
if (needs_grad) {
auto old_size = gradient_indices.size();
gradient_indices.resize(old_size + argsk.size());
auto delta_size = arrays.size() - pre_size;
gradient_indices.resize(old_size + delta_size);
std::iota(
gradient_indices.begin() + old_size,
gradient_indices.end(),
arrays.size());
counts.push_back(argsk.size());
pre_size);
counts.push_back(delta_size);
}
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
}
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
@@ -163,7 +179,7 @@ auto py_value_and_grad(
nb::object py_value_out;
auto value_and_grads = mx::value_and_grad(
[&fun,
&arrays,
&array_objects,
&args,
&kwargs,
&py_value_out,
@@ -183,8 +199,9 @@ auto py_value_and_grad(
tree_visit_update(tree, [&](nb::handle node) {
auto replace_arr = nb::cast<mx::array>(node);
if (replace_arr.id() == a[index].id()) {
return nb::cast(arrays[index++]);
return array_objects[index++];
} else {
index++;
return nb::cast(replace_arr);
}
});

View File

@@ -780,9 +780,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
return arrs[0]
arrs = [mx.array(1.0)]
init_id = id(arrs[0])
arr = arrs[0]
mx.grad(fun)(arrs)
self.assertEqual(init_id, id(arrs[0]))
self.assertEqual(id(arr), id(arrs[0]))
def fun(arrs):
arrs[1] = sum(arrs)
return arrs[1]
arrs = [mx.array(1.0), mx.array(1.0), mx.array(1.0)]
a_0, a_1, a_2 = arrs
mx.grad(fun)(arrs)
self.assertEqual(id(a_0), id(arrs[0]))
self.assertNotEqual(id(a_1), id(arrs[1]))
self.assertEqual(id(a_2), id(arrs[2]))
def test_grad_with_inplace_update(self):
def loss_fn(model):

View File

@@ -744,7 +744,6 @@ class TestVmap(mlx_tests.MLXTestCase):
return Vector([t[0] + 10, t[1] * 10])
x = State(mx.array(1), mx.array(2))
print(f"{transform(x)=}")
vmap_transform = mx.vmap(transform)
vmap_transform_tuple = mx.vmap(transform_tuple)