mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
179 lines
5.1 KiB
Python
179 lines
5.1 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import math
|
|
import os
|
|
import unittest
|
|
|
|
import mlx.core as mx
|
|
import mlx_tests
|
|
import numpy as np
|
|
|
|
|
|
class TestDouble(mlx_tests.MLXTestCase):
    """Check float64 operations against their float32 counterparts.

    Every test builds float32 inputs, casts them to float64 on the CPU
    stream, and then takes one of two paths:

    * when the default device is the GPU, the float64 operation is expected
      to raise ``ValueError``;
    * otherwise the float64 result, cast back to float32, is expected to be
      close to the float32 result.
    """

    @staticmethod
    def _op_label(op):
        """Return a readable label for ``op`` to use in sub-test reports."""
        return getattr(op, "__name__", repr(op))

    def _assert_matches_float32(self, expected, actual_double, **allclose_kwargs):
        """Assert that a float64 result agrees with the float32 reference.

        ``actual_double`` is cast to float32 on the CPU stream and compared
        with ``expected`` via ``mx.allclose``. Extra keyword arguments
        (``rtol``, ``atol``, ``equal_nan``) are forwarded to ``mx.allclose``.
        """
        actual = actual_double.astype(mx.float32, mx.cpu)
        self.assertTrue(mx.allclose(expected, actual, **allclose_kwargs))

    def test_unary_ops(self):
        """Element-wise unary ops on float64 input."""
        shape = (3, 3)
        x = mx.random.normal(shape=shape)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                x.astype(mx.float64)

        x_double = x.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.abs,
            mx.arccos,
            mx.arccosh,
            mx.arcsin,
            mx.arcsinh,
            mx.arctan,
            mx.arctanh,
            mx.ceil,
            mx.erf,
            mx.erfinv,
            mx.exp,
            mx.expm1,
            mx.floor,
            mx.log,
            mx.logical_not,
            mx.negative,
            mx.round,
            mx.sin,
            mx.sinh,
            mx.sqrt,
            mx.rsqrt,
            mx.tan,
            mx.tanh,
        ]
        for op in ops:
            with self.subTest(op=self._op_label(op)):
                if mx.default_device() == mx.gpu:
                    with self.assertRaises(ValueError):
                        op(x_double)
                    continue
                y = op(x)
                y_double = op(x_double)
                # Inputs are normally distributed, so ops with a restricted
                # domain (log, sqrt, arccos, ...) yield NaN in both precisions.
                self._assert_matches_float32(y, y_double, equal_nan=True)

    def test_binary_ops(self):
        """Element-wise binary ops on float64 inputs."""
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        ops = [
            mx.add,
            mx.arctan2,
            mx.divide,
            mx.multiply,
            mx.subtract,
            mx.logical_and,
            mx.logical_or,
            mx.remainder,
            mx.maximum,
            mx.minimum,
            mx.power,
            mx.equal,
            mx.greater,
            mx.greater_equal,
            mx.less,
            mx.less_equal,
            mx.not_equal,
            mx.logaddexp,
        ]
        for op in ops:
            with self.subTest(op=self._op_label(op)):
                if mx.default_device() == mx.gpu:
                    with self.assertRaises(ValueError):
                        op(a_double, b_double)
                    continue
                y = op(a, b)
                y_double = op(a_double, b_double)
                self._assert_matches_float32(y, y_double, equal_nan=True)

    def test_where(self):
        """``mx.where`` selecting between two float64 arrays."""
        shape = (3, 3)
        cond = mx.random.uniform(shape=shape) > 0.5
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                mx.where(cond, a_double, b_double)
            return
        y = mx.where(cond, a, b)
        y_double = mx.where(cond, a_double, b_double)
        self._assert_matches_float32(y, y_double)

    def test_reductions(self):
        """Reductions over each axis (and over both axes) of a float64 array."""
        shape = (32, 32)
        a = mx.random.normal(shape=shape)
        a_double = a.astype(mx.float64, stream=mx.cpu)

        axes = [0, 1, (0, 1)]
        ops = [mx.sum, mx.prod, mx.min, mx.max, mx.any, mx.all]

        for op in ops:
            for ax in axes:
                with self.subTest(op=self._op_label(op), axis=ax):
                    if mx.default_device() == mx.gpu:
                        with self.assertRaises(ValueError):
                            op(a_double, axis=ax)
                        continue
                    # Reduce over the axis under test; previously the axis was
                    # dropped here, so only the full reduction was ever checked.
                    y = op(a, axis=ax)
                    y_double = op(a_double, axis=ax)
                    # A float32 sum of zero-mean values can nearly cancel, and
                    # its rounding error may then exceed the default absolute
                    # tolerance, so compare with slightly looser bounds.
                    self._assert_matches_float32(y, y_double, rtol=1e-4, atol=1e-5)

    def test_get_and_set_item(self):
        """Array-indexed read and write on a float64 array."""
        shape = (3, 3)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=(2,))
        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)
        idx_i = mx.array([0, 2])
        idx_j = mx.array([0, 2])

        # Read through array indices.
        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j]
        else:
            y = a[idx_i, idx_j]
            y_double = a_double[idx_i, idx_j]
            self._assert_matches_float32(y, y_double)

        # Write through array indices.
        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double[idx_i, idx_j] = b_double
        else:
            a[idx_i, idx_j] = b
            a_double[idx_i, idx_j] = b_double
            self._assert_matches_float32(a, a_double)

    def test_gemm(self):
        """Matrix multiplication of two float64 matrices."""
        shape = (8, 8)
        a = mx.random.normal(shape=shape)
        b = mx.random.normal(shape=shape)

        a_double = a.astype(mx.float64, stream=mx.cpu)
        b_double = b.astype(mx.float64, stream=mx.cpu)

        if mx.default_device() == mx.gpu:
            with self.assertRaises(ValueError):
                a_double @ b_double
            return
        y = a @ b
        y_double = a_double @ b_double
        self._assert_matches_float32(y, y_double, equal_nan=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test suite when this file is executed as a script.
    unittest.main()
|