mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
8ea5729ee4
...
3336a35512
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3336a35512 | ||
|
|
1c589298ec |
@@ -1,6 +1,5 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|||||||
@@ -4676,6 +4676,11 @@ array segmented_mm(
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!issubdtype(segments.dtype(), integer)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
|
||||||
|
}
|
||||||
|
|
||||||
a = astype(a, out_type, s);
|
a = astype(a, out_type, s);
|
||||||
b = astype(b, out_type, s);
|
b = astype(b, out_type, s);
|
||||||
segments = astype(segments, uint32, s);
|
segments = astype(segments, uint32, s);
|
||||||
|
|||||||
@@ -1247,7 +1247,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
a = mx.ones((10, 1000))
|
a = mx.ones((10, 1000))
|
||||||
s = mx.random.randint(0, 16, shape=(1000,))
|
s = mx.random.randint(0, 16, shape=(1000,))
|
||||||
s = mx.zeros(16).at[s].add(1)
|
s = mx.zeros(16, dtype=s.dtype).at[s].add(1)
|
||||||
s = mx.sort(s)
|
s = mx.sort(s)
|
||||||
s = mx.cumsum(s)
|
s = mx.cumsum(s)
|
||||||
s = mx.concatenate([mx.array([0]), s])
|
s = mx.concatenate([mx.array([0]), s])
|
||||||
|
|||||||
Reference in New Issue
Block a user