Compare commits

...

2 Commits

Author SHA1 Message Date
Angelos Katharopoulos
3336a35512 Fix the segments type in the test 2025-07-07 17:25:19 -07:00
Angelos Katharopoulos
1c589298ec Address comments 2025-07-07 17:03:28 -07:00
3 changed files with 6 additions and 2 deletions

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"

View File

@@ -4676,6 +4676,11 @@ array segmented_mm(
throw std::invalid_argument(msg.str());
}
if (!issubdtype(segments.dtype(), integer)) {
throw std::invalid_argument(
"[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
segments = astype(segments, uint32, s);

View File

@@ -1247,7 +1247,7 @@ class TestBlas(mlx_tests.MLXTestCase):
a = mx.ones((10, 1000))
s = mx.random.randint(0, 16, shape=(1000,))
s = mx.zeros(16).at[s].add(1)
s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
s = mx.sort(s)
s = mx.cumsum(s)
s = mx.concatenate([mx.array([0]), s])