mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
* Implementation of mlx.random.multivariate_normal (#502) * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Updated typo in docstring * Restricted multivariate_normal to float32 * Generic mean and variance shapes * Review edits * Update mlx/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Update python/src/random.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * Test for ndim of mean and cov * nits * smaller size for test * fix broadcasted sampling --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -101,6 +102,96 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp))
|
||||
self.assertTrue(mx.all(a < mx.inf))
|
||||
|
||||
def test_multivariate_normal(self):
|
||||
key = mx.random.key(0)
|
||||
mean = mx.array([0, 0])
|
||||
cov = mx.array([[1, 0], [0, 1]])
|
||||
|
||||
a = mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
|
||||
self.assertEqual(a.shape, (2,))
|
||||
|
||||
## Check dtypes
|
||||
for t in [mx.float32]:
|
||||
a = mx.random.multivariate_normal(
|
||||
mean, cov, dtype=t, key=key, stream=mx.cpu
|
||||
)
|
||||
self.assertEqual(a.dtype, t)
|
||||
for t in [
|
||||
mx.int8,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.uint8,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.bfloat16,
|
||||
]:
|
||||
with self.assertRaises(ValueError):
|
||||
mx.random.multivariate_normal(
|
||||
mean, cov, dtype=t, key=key, stream=mx.cpu
|
||||
)
|
||||
|
||||
## Check incompatible shapes
|
||||
with self.assertRaises(ValueError):
|
||||
mean = mx.zeros((2, 2))
|
||||
cov = mx.zeros((2, 2))
|
||||
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mean = mx.zeros((2))
|
||||
cov = mx.zeros((2, 2, 2))
|
||||
mx.random.multivariate_normal(mean, cov, shape=(3,), key=key, stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mean = mx.zeros((3,))
|
||||
cov = mx.zeros((2, 2))
|
||||
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mean = mx.zeros((2,))
|
||||
cov = mx.zeros((2, 3))
|
||||
mx.random.multivariate_normal(mean, cov, key=key, stream=mx.cpu)
|
||||
|
||||
## Different shape of mean and cov
|
||||
mean = mx.array([[0, 7], [1, 2], [3, 4]])
|
||||
cov = mx.array([[1, 0.5], [0.5, 1]])
|
||||
a = mx.random.multivariate_normal(mean, cov, shape=(4, 3), stream=mx.cpu)
|
||||
self.assertEqual(a.shape, (4, 3, 2))
|
||||
|
||||
## Check correcteness of the mean and covariance
|
||||
n_test = int(1e5)
|
||||
|
||||
def check_jointly_gaussian(data, mean, cov):
|
||||
empirical_mean = mx.mean(data, axis=0)
|
||||
empirical_cov = (
|
||||
(data - empirical_mean).T @ (data - empirical_mean) / data.shape[0]
|
||||
)
|
||||
N = data.shape[1]
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
empirical_mean, mean, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
empirical_cov, cov, rtol=0.0, atol=10 * N**2 / math.sqrt(n_test)
|
||||
)
|
||||
)
|
||||
|
||||
mean = mx.array([4.0, 7.0])
|
||||
cov = mx.array([[2, 0.5], [0.5, 1]])
|
||||
data = mx.random.multivariate_normal(
|
||||
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
|
||||
)
|
||||
check_jointly_gaussian(data, mean, cov)
|
||||
|
||||
mean = mx.arange(3)
|
||||
cov = mx.array([[1, -1, 0.5], [-1, 1, -0.5], [0.5, -0.5, 1]])
|
||||
data = mx.random.multivariate_normal(
|
||||
mean, cov, shape=(n_test,), key=key, stream=mx.cpu
|
||||
)
|
||||
check_jointly_gaussian(data, mean, cov)
|
||||
|
||||
def test_randint(self):
|
||||
a = mx.random.randint(0, 1, [])
|
||||
self.assertEqual(a.shape, ())
|
||||
|
Reference in New Issue
Block a user