mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix the test and cpu edge case
This commit is contained in:
@@ -81,6 +81,7 @@ inline void segmented_mm(
|
|||||||
uint32_t k_end =
|
uint32_t k_end =
|
||||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||||
if (k_end <= k_start) {
|
if (k_end <= k_start) {
|
||||||
|
std::fill_n(out + i * M * N, M * N, T(0));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
a_copy[ndim - 1] = k_end - k_start;
|
a_copy[ndim - 1] = k_end - k_start;
|
||||||
|
|||||||
@@ -1207,38 +1207,38 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
(10, 10, 1000),
|
(10, 10, 1000),
|
||||||
(1000, 1000, 1000),
|
(1000, 1000, 1000),
|
||||||
]
|
]
|
||||||
segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
|
all_segments = [[0, 0, 1.0], [0, 0.5, 1.0], [r / 9 for r in range(10)]]
|
||||||
|
|
||||||
for M, N, K in shapes:
|
for M, N, K in shapes:
|
||||||
for s in segments:
|
for s in all_segments:
|
||||||
segments = []
|
segments = []
|
||||||
for i in range(len(s) - 1):
|
for i in range(len(s) - 1):
|
||||||
segments.append([s[i], s[i + 1]])
|
segments.append([s[i], s[i + 1]])
|
||||||
segments = mx.array(segments)
|
segments = mx.array(segments)
|
||||||
segments = mx.maximum(K - 1, segments.astype(mx.uint32))
|
segments = mx.minimum(K - 1, (K * segments).astype(mx.uint32))
|
||||||
a = mx.random.normal((M, K))
|
a = mx.random.normal((M, K))
|
||||||
b = mx.random.normal((K, N))
|
b = mx.random.normal((K, N))
|
||||||
c1 = segmented_mm_ref(a, b, segments)
|
c1 = segmented_mm_ref(a, b, segments)
|
||||||
c2 = mx.segmented_mm(a, b, segments)
|
c2 = mx.segmented_mm(a, b, segments)
|
||||||
self.assertTrue(mx.allclose(c1, c2))
|
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||||
|
|
||||||
a = mx.random.normal((K, M))
|
a = mx.random.normal((K, M))
|
||||||
b = mx.random.normal((K, N))
|
b = mx.random.normal((K, N))
|
||||||
c1 = segmented_mm_ref(a.T, b, segments)
|
c1 = segmented_mm_ref(a.T, b, segments)
|
||||||
c2 = mx.segmented_mm(a.T, b, segments)
|
c2 = mx.segmented_mm(a.T, b, segments)
|
||||||
self.assertTrue(mx.allclose(c1, c2))
|
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||||
|
|
||||||
a = mx.random.normal((M, K))
|
a = mx.random.normal((M, K))
|
||||||
b = mx.random.normal((N, K))
|
b = mx.random.normal((N, K))
|
||||||
c1 = segmented_mm_ref(a, b.T, segments)
|
c1 = segmented_mm_ref(a, b.T, segments)
|
||||||
c2 = mx.segmented_mm(a, b.T, segments)
|
c2 = mx.segmented_mm(a, b.T, segments)
|
||||||
self.assertTrue(mx.allclose(c1, c2))
|
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||||
|
|
||||||
a = mx.random.normal((K, M))
|
a = mx.random.normal((K, M))
|
||||||
b = mx.random.normal((N, K))
|
b = mx.random.normal((N, K))
|
||||||
c1 = segmented_mm_ref(a.T, b.T, segments)
|
c1 = segmented_mm_ref(a.T, b.T, segments)
|
||||||
c2 = mx.segmented_mm(a.T, b.T, segments)
|
c2 = mx.segmented_mm(a.T, b.T, segments)
|
||||||
self.assertTrue(mx.allclose(c1, c2))
|
self.assertTrue(mx.allclose(c1, c2, atol=1e-4))
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
a = mx.ones((2, 10, 10))
|
a = mx.ones((2, 10, 10))
|
||||||
|
|||||||
Reference in New Issue
Block a user