diff --git a/docs/src/python/random.rst b/docs/src/python/random.rst index d08d5a7df..5d98304bb 100644 --- a/docs/src/python/random.rst +++ b/docs/src/python/random.rst @@ -44,3 +44,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG. split truncated_normal uniform + laplace diff --git a/mlx/random.cpp b/mlx/random.cpp index 45ce85763..590ca375e 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -102,6 +102,19 @@ T above_minus_one() { return f; } +// Get the next representable value above -1.0 for half precision +// use std::nextafter as default case. +array above_minus_one_with_default(Dtype dtype) { + switch (dtype) { + case float16: + return array(above_minus_one(), dtype); + case bfloat16: + return array(above_minus_one(), dtype); + default: + return array(std::nextafter(-1.0f, 0.0f), dtype); + } +} + array uniform( const array& low, const array& high, @@ -171,17 +184,7 @@ array normal( const std::optional& key /*= nullopt */, StreamOrDevice s /* = {} */) { auto stream = to_stream(s); - auto get_low = [&dtype]() { - switch (dtype) { - case float16: - return array(above_minus_one(), dtype); - case bfloat16: - return array(above_minus_one(), dtype); - default: - return array(std::nextafter(-1.0f, 0.0f), dtype); - } - }; - auto low = get_low(); + auto low = above_minus_one_with_default(dtype); auto high = array(1.0f, dtype); auto samples = uniform(low, high, shape, dtype, key, stream); samples = @@ -428,4 +431,30 @@ array categorical( return categorical_impl(logits, axis, shape, key, s); } +array laplace( + const std::vector& shape, + Dtype dtype, + const float loc /* = 0.0 */, + const float scale /* = 1.0 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + auto stream = to_stream(s); + auto low = above_minus_one_with_default(dtype); + auto high = array(1.0f, dtype); + auto samples = uniform(low, high, shape, dtype, key, stream); + // Use inverse CDF to generate Laplacian noise + samples = multiply( + sign(samples), + log1p(multiply(array(-1.0f, dtype), abs(samples))), + stream); + + if (scale != 1.0) { + samples = multiply(array(scale, dtype), samples, stream); + } + if (loc != 0.0) { + samples = add(array(loc, dtype), samples, stream); + } + return samples; +} + } // namespace mlx::core::random diff --git a/mlx/random.h b/mlx/random.h index fb1a76bc0..ad030c7e3 100644 --- a/mlx/random.h +++ b/mlx/random.h @@ -224,4 +224,34 @@ array categorical( const std::optional& key = std::nullopt, StreamOrDevice s = {}); +/** Generate samples from the laplace distribution. */ +array laplace( + const std::vector& shape, + Dtype dtype, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array laplace( + const std::vector& shape, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, float32, loc, scale, key, s); +} +inline array laplace( + const std::vector& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, dtype, 0.0, 1.0, key, s); +} +inline array laplace( + const std::vector& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, float32, 0.0, 1.0, key, s); +} + } // namespace mlx::core::random diff --git a/python/src/random.cpp b/python/src/random.cpp index 3f082e15d..1dd4ef56b 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -419,6 +419,38 @@ void init_random(nb::module_& parent_module) { Returns: array: The ``shape``-sized output array with type ``uint32``. )pbdoc"); + m.def( + "laplace", + [](const std::vector& shape, + std::optional type, + float loc, + float scale, + const std::optional& key_, + StreamOrDevice s) { + auto key = key_ ? key_.value() : default_key().next(); + return laplace(shape, type.value_or(float32), loc, scale, key, s); + }, + "shape"_a = std::vector{}, + "dtype"_a.none() = float32, + "loc"_a = 0.0, + "scale"_a = 1.0, + "key"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def laplace(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: float = 0.0, scale: float = 1.0, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Sample numbers from a Laplace distribution. + + Args: + shape (list(int), optional): Shape of the output. Default is ``()``. + dtype (Dtype, optional): Type of the output. Default is ``float32``. + loc (float, optional): Mean of the distribution. Default is ``0.0``. + scale (float, optional): The scale "b" of the Laplace distribution. Default is ``1.0``. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The output array of random values. + )pbdoc"); // Register static Python object cleanup before the interpreter exits auto atexit = nb::module_::import_("atexit"); atexit.attr("register")(nb::cpp_function([]() { default_key().release(); })); diff --git a/python/tests/test_random.py b/python/tests/test_random.py index 4ddef837b..b6f632491 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -64,43 +64,50 @@ class TestRandom(mlx_tests.MLXTestCase): self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype) - def test_normal(self): - key = mx.random.key(0) - a = mx.random.normal(key=key) - self.assertEqual(a.shape, ()) - self.assertEqual(a.dtype, mx.float32) + def test_normal_and_laplace(self): + # Same tests for normal and laplace. + for distribution_sampler in [mx.random.normal, mx.random.laplace]: + key = mx.random.key(0) + a = distribution_sampler(key=key) + self.assertEqual(a.shape, ()) + self.assertEqual(a.dtype, mx.float32) - b = mx.random.normal(key=key) - self.assertEqual(a.item(), b.item()) + b = distribution_sampler(key=key) + self.assertEqual(a.item(), b.item()) - a = mx.random.normal(shape=(2, 3)) - self.assertEqual(a.shape, (2, 3)) + a = distribution_sampler(shape=(2, 3)) + self.assertEqual(a.shape, (2, 3)) - ## Generate in float16 or bfloat16 - for t in [mx.float16, mx.bfloat16]: - a = mx.random.normal(dtype=t) - self.assertEqual(a.dtype, t) + ## Generate in float16 or bfloat16 + for t in [mx.float16, mx.bfloat16]: + a = distribution_sampler(dtype=t) + self.assertEqual(a.dtype, t) - # Generate with a given mean and standard deviation - loc = 1.0 - scale = 2.0 + # Generate with a given mean and standard deviation + loc = 1.0 + scale = 2.0 - a = mx.random.normal(shape=(3, 2), loc=loc, scale=scale, key=key) - b = scale * mx.random.normal(shape=(3, 2), key=key) + loc - self.assertTrue(mx.allclose(a, b)) + a = distribution_sampler(shape=(3, 2), loc=loc, scale=scale, key=key) + b = scale * distribution_sampler(shape=(3, 2), key=key) + loc + self.assertTrue(mx.allclose(a, b)) - a = mx.random.normal( - shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key - ) - b = scale * mx.random.normal(shape=(3, 2), dtype=mx.float16, key=key) + loc - self.assertTrue(mx.allclose(a, b)) + a = distribution_sampler( + shape=(3, 2), loc=loc, scale=scale, dtype=mx.float16, key=key + ) + b = ( + scale * distribution_sampler(shape=(3, 2), dtype=mx.float16, key=key) + + loc + ) + self.assertTrue(mx.allclose(a, b)) - self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype) + self.assertEqual( + distribution_sampler().dtype, distribution_sampler(dtype=None).dtype + ) - # Test not getting -inf or inf with half precison - for hp in [mx.float16, mx.bfloat16]: - a = abs(mx.random.normal(shape=(10000,), loc=0, scale=1, dtype=hp)) - self.assertTrue(mx.all(a < mx.inf)) + # Test not getting -inf or inf with half precison + for hp in [mx.float16, mx.bfloat16]: + a = abs(distribution_sampler(shape=(10000,), loc=0, scale=1, dtype=hp)) + self.assertTrue(mx.all(a < mx.inf)) def test_multivariate_normal(self): key = mx.random.key(0) diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 42259e065..a449fed83 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -640,3 +640,74 @@ TEST_CASE("test categorical") { CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector{5, 3, 7}); CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector{4, 3, 7}); } + +TEST_CASE("test laplace") { + // Test shapes, types, and sizes + { + auto x = random::laplace({}); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), float32); + + // Non float type throws + CHECK_THROWS_AS(random::laplace({}, int32), std::invalid_argument); + + // Check wrong key type or shape + auto key = array({0, 0}); + CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); + key = array({0, 0}, {1, 2}); + CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); + key = array({0u, 0u, 0u}, {3, 1}); + CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); + key = array({0u, 0u}, {2, 1}); + CHECK_THROWS_AS(random::laplace({}, key), std::invalid_argument); + } + + { + constexpr float inf = std::numeric_limits::infinity(); + auto key = random::key(128291); + auto out = random::laplace({1000000}, key); + float sample_mean = mean(out).item(); + float sample_variance = var(out).item(); + + CHECK(all(less(abs(out), array(inf))).item()); + CHECK(abs(sample_mean) < 0.1); + + // Chebyshev's inequality. + for (int k = 1; k <= 5; ++k) { + float prob_above = + mean(greater_equal(out, array(k * std::sqrt(sample_variance)))) + .item(); + float bound = 1 / std::pow(k, 2); + CHECK(prob_above < bound); + } + + // Expected variance for Laplace distribution is 2*scale^2. + float expected_variance = 2.0; + CHECK(std::abs(sample_variance - expected_variance) < 0.01); + + // Expected kurtosis of Laplace distribution is 3. + array fourth_pows = power(out - sample_mean, {4}); + float sample_kurtosis = + mean(fourth_pows).item() / std::pow(sample_variance, 2) - 3; + float expected_kurtosis = 3.0; + CHECK(std::abs(sample_kurtosis - expected_kurtosis) < 0.1); + } + + { + constexpr float inf = std::numeric_limits::infinity(); + auto key = random::key(128291); + auto out = random::laplace({10000}, float16, key); + CHECK_EQ(out.dtype(), float16); + CHECK(all(less(abs(out), array(inf))).item()); + CHECK(abs(float(mean(out).item())) < 0.1); + } + + { + constexpr float inf = std::numeric_limits::infinity(); + auto key = random::key(128291); + auto out = random::laplace({10000}, bfloat16, key); + CHECK_EQ(out.dtype(), bfloat16); + CHECK(all(less(abs(out), array(inf))).item()); + CHECK(abs(float(mean(out).item())) < 0.1); + } +}