mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Pre-commit formatting
This commit is contained in:
@@ -50,12 +50,14 @@ def time_maximum():
|
||||
mx.eval(a, b)
|
||||
time_fn(mx.maximum, a, b)
|
||||
|
||||
|
||||
def time_max():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1,1] = mx.nan
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
|
||||
@@ -187,10 +187,10 @@ struct Max {
|
||||
|
||||
template <typename T>
|
||||
T simd_reduce_impl(T val) {
|
||||
if(simd_any(val != val)) {
|
||||
if (simd_any(val != val)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
return simd_max(val);
|
||||
return simd_max(val);
|
||||
}
|
||||
|
||||
static constexpr constant U init = Limits<U>::min;
|
||||
@@ -208,21 +208,19 @@ struct Max {
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
|
||||
if(metal::isnan(a) || metal::isnan(b)) {
|
||||
if (metal::isnan(a) || metal::isnan(b)) {
|
||||
return static_cast<T>(NAN);
|
||||
} else {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t a, complex64_t b) {
|
||||
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) || metal::isnan(b.imag)) {
|
||||
if (metal::isnan(a.real) || metal::isnan(a.imag) || metal::isnan(b.real) ||
|
||||
metal::isnan(b.imag)) {
|
||||
return static_cast<complex64_t>(NAN);
|
||||
}
|
||||
}
|
||||
return a > b ? a : b;
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
@@ -168,7 +168,7 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
for dtype in dtypes:
|
||||
with self.subTest(dtype=dtype):
|
||||
x = (mx.random.normal((4, 4))).astype(getattr(mx, dtype))
|
||||
indices = mx.random.randint(0, 4, shape=(6,)).reshape(3,2)
|
||||
indices = mx.random.randint(0, 4, shape=(6,)).reshape(3, 2)
|
||||
for idx in indices:
|
||||
x[*idx] = mx.nan
|
||||
x_np = np.array(x)
|
||||
@@ -179,5 +179,6 @@ class TestReduce(mlx_tests.MLXTestCase):
|
||||
ref = getattr(np, op)(x_np, axis=axis)
|
||||
self.assertTrue(np.array_equal(out, ref, equal_nan=True))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner(failfast=True)
|
||||
|
||||
@@ -1027,7 +1027,6 @@ TEST_CASE("test reduction ops") {
|
||||
|
||||
x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f});
|
||||
CHECK(isnan(max(x).item<float>()));
|
||||
|
||||
}
|
||||
|
||||
// Test logsumexp
|
||||
|
||||
Reference in New Issue
Block a user