Pre-commit formatting

This commit is contained in:
Joona Havukainen
2025-07-06 14:27:40 -07:00
parent af74818528
commit 5b089dc5da
4 changed files with 11 additions and 11 deletions

View File

@@ -50,12 +50,14 @@ def time_maximum():
mx.eval(a, b)
time_fn(mx.maximum, a, b)
def time_max():
a = mx.random.uniform(shape=(32, 1024, 1024))
a[1,1] = mx.nan
a[1, 1] = mx.nan
mx.eval(a)
time_fn(mx.max, a, 0)
def time_negative():
a = mx.random.uniform(shape=(10000, 1000))
mx.eval(a)

View File

@@ -187,10 +187,10 @@ struct Max {
template <typename T>
T simd_reduce_impl(T val) {
if(simd_any(val != val)) {
if (simd_any(val != val)) {
return static_cast<T>(NAN);
}
return simd_max(val);
return simd_max(val);
}
static constexpr constant U init = Limits<U>::min;
@@ -208,21 +208,19 @@ struct Max {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
if(metal::isnan(a) || metal::isnan(b)) {
if (metal::isnan(a) || metal::isnan(b)) {
return static_cast<T>(NAN);
} else {
return a > b ? a : b;
}
}
template <>
complex64_t operator()(complex64_t a, complex64_t b) {
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) || metal::isnan(b.imag)) {
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) ||
metal::isnan(b.imag)) {
return static_cast<complex64_t>(NAN);
}
return a > b ? a : b;
}
};

View File

@@ -168,7 +168,7 @@ class TestReduce(mlx_tests.MLXTestCase):
for dtype in dtypes:
with self.subTest(dtype=dtype):
x = (mx.random.normal((4, 4))).astype(getattr(mx, dtype))
indices = mx.random.randint(0, 4, shape=(6,)).reshape(3,2)
indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)
for idx in indices:
x[*idx] = mx.nan
x_np = np.array(x)
@@ -179,5 +179,6 @@ class TestReduce(mlx_tests.MLXTestCase):
ref = getattr(np, op)(x_np, axis=axis)
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
if __name__ == "__main__":
mlx_tests.MLXTestRunner(failfast=True)

View File

@@ -1027,7 +1027,6 @@ TEST_CASE("test reduction ops") {
x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f});
CHECK(isnan(max(x).item<float>()));
}
// Test logsumexp