# mlx/python/tests/test_double.py
# 2025-02-25 06:00:53 -08:00
#
# 189 lines
# 5.4 KiB
# Python

# Copyright © 2024 Apple Inc.
import math
import os
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestDouble(mlx_tests.MLXTestCase):
    """Tests for float64 (double precision) support.

    float64 inputs are always produced by casting on the CPU stream, because
    these tests expect float64 work on the GPU to raise ``ValueError``. On a
    GPU default device each test asserts that error; otherwise it runs the op
    in both float32 and float64, casts the float64 result back to float32 (on
    the CPU stream) and checks the two agree with ``mx.allclose``.
    """

    def test_unary_ops(self):
        """Unary ops on float64 match float32 on CPU and raise on GPU."""
        shape = (3, 3)
        x = mx.random.normal(shape=shape)

        if mx.default_device() == mx.gpu:
            # Casting to float64 on the default (GPU) stream must fail.
            with self.assertRaises(ValueError):
                x.astype(mx.float64)

        # Cast explicitly on the CPU stream so this works on either device.
        x_double = x.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.abs,
            mx.arccos,
            mx.arccosh,
            mx.arcsin,
            mx.arcsinh,
            mx.arctan,
            mx.arctanh,
            mx.ceil,
            mx.erf,
            mx.erfinv,
            mx.exp,
            mx.expm1,
            mx.floor,
            mx.log,
            mx.logical_not,
            mx.negative,
            mx.round,
            mx.sin,
            mx.sinh,
            mx.sqrt,
            mx.rsqrt,
            mx.tan,
            mx.tanh,
        ]
        for op in ops:
            if mx.default_device() == mx.gpu:
                with self.assertRaises(ValueError):
                    op(x_double)
                continue
            y = op(x)
            y_double = op(x_double)
            # Normal samples fall outside the domain of ops such as arccos,
            # log and sqrt, producing NaNs, so NaNs must compare equal.
            self.assertTrue(
                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
            )

    def test_binary_ops(self):
        """Binary ops on float64 match float32 on CPU and raise on GPU."""
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.add,
            mx.arctan2,
            mx.divide,
            mx.multiply,
            mx.subtract,
            mx.logical_and,
            mx.logical_or,
            mx.remainder,
            mx.maximum,
            mx.minimum,
            mx.power,
            mx.equal,
            mx.greater,
            mx.greater_equal,
            mx.less,
            mx.less_equal,
            mx.not_equal,
            mx.logaddexp,
        ]
        for op in ops:
            if mx.default_device() == mx.gpu:
                with self.assertRaises(ValueError):
                    op(a_double, b_double)
                continue
            y = op(a, b)
            y_double = op(a_double, b_double)
            # e.g. power with a negative base and non-integer exponent is NaN.
            self.assertTrue(
                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
            )

    def test_where(self):
        """``mx.where`` selects identically for float32 and float64 inputs."""
        shape = (3, 3)
        cond = mx.random.uniform(shape=shape) > 0.5
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                mx.where(cond, a_double, b_double)
            return

        y = mx.where(cond, a, b)
        y_double = mx.where(cond, a_double, b_double)
        self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

    def test_reductions(self):
        """Reductions along each listed axis agree between float32 and float64."""
        shape = (32, 32)
        a = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)

        axes = [0, 1, (0, 1)]
        ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]

        for op in ops:
            for ax in axes:
                if mx.default_device() == mx.gpu:
                    with self.assertRaises(ValueError):
                        op(a_double, axis=ax)
                    continue
                # Pass ``axis`` so every entry of ``axes`` is actually checked;
                # omitting it would repeat the full reduction on each iteration.
                y = op(a, axis=ax)
                y_double = op(a_double, axis=ax)
                self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

    def test_get_and_set_item(self):
        """Fancy-index reads and writes behave the same for float64."""
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=(2,))
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)
        # Paired index arrays select the elements (0, 0) and (2, 2).
        idx_i = mx.array([0, 2])
        idx_j = mx.array([0, 2])

        # Gather.
        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j]
        else:
            y = a[idx_i, idx_j]
            y_double = a_double[idx_i, idx_j]
            self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

        # Scatter.
        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j] = b_double
        else:
            a[idx_i, idx_j] = b
            a_double[idx_i, idx_j] = b_double
            self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))

    def test_gemm(self):
        """Matrix multiplication with float64 matches float32."""
        shape = (8, 8)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double @ b_double
            return

        y = a @ b
        y_double = a_double @ b_double
        self.assertTrue(
            mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
        )

    def test_type_promotion(self):
        """Adding an int32 array to a float64 array yields float64."""
        # ``mx`` is already imported at module level; no local import needed.
        a = mx.array([4, 8], mx.float64)
        b = mx.array([4, 8], mx.int32)
        with mx.stream(mx.cpu):
            c = a + b
            self.assertEqual(c.dtype, mx.float64)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to unittest's runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()