Fp64 on the CPU (#1843)

* add fp64 data type
* clean build
* update docs
* fix bug
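For orientation before the diff, here is a minimal sketch (not part of the commit) of the behavior the new tests exercise: float64 arrays work when the computation runs on the CPU stream, while the same operations targeting the GPU raise a ValueError.

import mlx.core as mx

x = mx.random.normal(shape=(3, 3))

# Casting to fp64 on the CPU stream succeeds.
x_double = x.astype(mx.float64, stream=mx.cpu)

# Ops on fp64 arrays likewise need the CPU stream.
y = mx.exp(x_double, stream=mx.cpu)

# Targeting the GPU with fp64 raises, as the tests assert:
# x.astype(mx.float64, stream=mx.gpu)  # ValueError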
178 python/tests/test_double.py (new file)
@@ -0,0 +1,178 @@
# Copyright © 2024 Apple Inc.

import math
import os
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np


class TestDouble(mlx_tests.MLXTestCase):
    def test_unary_ops(self):
        shape = (3, 3)
        x = mx.random.normal(shape=shape)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                x.astype(mx.float64)

        x_double = x.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.abs,
            mx.arccos,
            mx.arccosh,
            mx.arcsin,
            mx.arcsinh,
            mx.arctan,
            mx.arctanh,
            mx.ceil,
            mx.erf,
            mx.erfinv,
            mx.exp,
            mx.expm1,
            mx.floor,
            mx.log,
            mx.logical_not,
            mx.negative,
            mx.round,
            mx.sin,
            mx.sinh,
            mx.sqrt,
            mx.rsqrt,
            mx.tan,
            mx.tanh,
        ]
        for op in ops:
            if mx.default_device() == mx.gpu:
                with self.assertRaises(ValueError):
                    op(x_double)
                continue
            y = op(x)
            y_double = op(x_double)
            self.assertTrue(
                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
            )

    def test_binary_ops(self):
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.add,
            mx.arctan2,
            mx.divide,
            mx.multiply,
            mx.subtract,
            mx.logical_and,
            mx.logical_or,
            mx.remainder,
            mx.maximum,
            mx.minimum,
            mx.power,
            mx.equal,
            mx.greater,
            mx.greater_equal,
            mx.less,
            mx.less_equal,
            mx.not_equal,
            mx.logaddexp,
        ]
        for op in ops:
            if mx.default_device() == mx.gpu:
                with self.assertRaises(ValueError):
                    op(a_double, b_double)
                continue
            y = op(a, b)
            y_double = op(a_double, b_double)
            self.assertTrue(
                mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
            )

    def test_where(self):
        shape = (3, 3)
        cond = mx.random.uniform(shape=shape) > 0.5
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                mx.where(cond, a_double, b_double)
            return
        y = mx.where(cond, a, b)
        y_double = mx.where(cond, a_double, b_double)
        self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

    def test_reductions(self):
        shape = (32, 32)
        a = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)

        axes = [0, 1, (0, 1)]
        ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]

        for op in ops:
            for ax in axes:
                if mx.default_device() == mx.gpu:
                    with self.assertRaises(ValueError):
                        op(a_double, axis=ax)
                    continue
                y = op(a)
                y_double = op(a_double)
                self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

    def test_get_and_set_item(self):
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=(2,))
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)
        idx_i = mx.array([0, 2])
        idx_j = mx.array([0, 2])

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j]
        else:
            y = a[idx_i, idx_j]
            y_double = a_double[idx_i, idx_j]
            self.assertTrue(mx.allclose(y, y_double.astype(mx.float32, mx.cpu)))

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j] = b_double
        else:
            a[idx_i, idx_j] = b
            a_double[idx_i, idx_j] = b_double
            self.assertTrue(mx.allclose(a, a_double.astype(mx.float32, mx.cpu)))

    def test_gemm(self):
        shape = (8, 8)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double @ b_double
            return
        y = a @ b
        y_double = a_double @ b_double
        self.assertTrue(
            mx.allclose(y, y_double.astype(mx.float32, mx.cpu), equal_nan=True)
        )


if __name__ == "__main__":
    unittest.main()
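The file is runnable standalone via unittest.main(). A minimal invocation from an MLX checkout, assuming mlx_tests.py sits next to this file as in the repo layout (the DEVICE variable follows the convention of MLX's other test files and is an assumption here):

cd python/tests
DEVICE=cpu python test_double.py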