From dfdb284e163c6580dc00a3c78ed58052e3389121 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Fri, 5 Jan 2024 18:37:46 +0100 Subject: [PATCH] make behaviour of dtype arguments consistent and compliant to numpy (#379) All functions that take an optional dtype should * have a default dtype visible in the generated docs (accomplished via `"dtype"_a = std::optional{float32}`) * behave identically when `dtype=None` or no dtype is passed This is important when passing kw args down from a numpy function like: ``` def f(x, dtype=None): mx.random.uniform(dtype=dtype) # ... ``` NumPy functions behave like this. It also fixes a minor bug in `tri`: #378 Closes #378 --- python/src/ops.cpp | 42 +++++++++++++++++++++++------------- python/src/random.cpp | 43 +++++++++++++++++++++++++------------ python/tests/test_ops.py | 2 ++ python/tests/test_random.py | 17 +++++++++++++++ 4 files changed, 75 insertions(+), 29 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 49a17a5c8..15c8bf69d 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1212,14 +1212,22 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "linspace", - [](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) { + [](Scalar start, + Scalar stop, + int num, + std::optional dtype, + StreamOrDevice s) { return linspace( - scalar_to_double(start), scalar_to_double(stop), num, dtype, s); + scalar_to_double(start), + scalar_to_double(stop), + num, + dtype.value_or(float32), + s); }, "start"_a, "stop"_a, "num"_a = 50, - "dtype"_a = float32, + "dtype"_a = std::optional{float32}, "stream"_a = none, R"pbdoc( linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array @@ -1356,11 +1364,11 @@ void init_ops(py::module_& m) { } }, "shape"_a, - "dtype"_a = std::nullopt, + "dtype"_a = std::optional{float32}, py::kw_only(), "stream"_a = none, R"pbdoc( - zeros(shape: Union[int, 
List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + zeros(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array Construct an array of zeros. @@ -1403,11 +1411,11 @@ void init_ops(py::module_& m) { } }, "shape"_a, - "dtype"_a = std::nullopt, + "dtype"_a = std::optional{float32}, py::kw_only(), "stream"_a = none, R"pbdoc( - ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + ones(shape: Union[int, List[int]], dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array Construct an array of ones. @@ -1449,11 +1457,11 @@ void init_ops(py::module_& m) { "n"_a, "m"_a = py::none(), "k"_a = 0, - "dtype"_a = std::nullopt, + "dtype"_a = std::optional{float32}, py::kw_only(), "stream"_a = none, R"pbdoc( - eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + eye(n: int, m: Optional[int] = None, k: int = 0, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array Create an identity matrix or a general diagonal matrix. @@ -1473,11 +1481,11 @@ void init_ops(py::module_& m) { return identity(n, dtype.value_or(float32), s); }, "n"_a, - "dtype"_a = std::nullopt, + "dtype"_a = std::optional{float32}, py::kw_only(), "stream"_a = none, R"pbdoc( - identity(n: int, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + identity(n: int, dtype: Optional[Dtype] = float32, *, stream: Union[None, Stream, Device] = None) -> array Create a square identity matrix. 
@@ -1491,13 +1499,17 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "tri", - [](int n, std::optional m, int k, Dtype dtype, StreamOrDevice s) { - return tri(n, m.value_or(n), k, float32, s); + [](int n, + std::optional m, + int k, + std::optional type, + StreamOrDevice s) { + return tri(n, m.value_or(n), k, type.value_or(float32), s); }, "n"_a, "m"_a = none, "k"_a = 0, - "dtype"_a = float32, + "dtype"_a = std::optional{float32}, py::kw_only(), "stream"_a = none, R"pbdoc( @@ -2432,7 +2444,7 @@ void init_ops(py::module_& m) { array (array): Input array. repeats (int): The number of repetitions for each element. axis (int, optional): The axis in which to repeat the array along. If - unspecified it uses the flattened array of the input and repeats + unspecified it uses the flattened array of the input and repeats along axis 0. stream (Stream, optional): Stream or device. Defaults to ``None``. diff --git a/python/src/random.cpp b/python/src/random.cpp index f648a2714..6e9f38d97 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -61,15 +61,21 @@ void init_random(py::module_& parent_module) { [](const ScalarOrArray& low, const ScalarOrArray& high, const std::vector& shape, - Dtype type, + std::optional type, const std::optional& key, StreamOrDevice s) { - return uniform(to_array(low), to_array(high), shape, type, key, s); + return uniform( + to_array(low), + to_array(high), + shape, + type.value_or(float32), + key, + s); }, "low"_a = 0, "high"_a = 1, "shape"_a = std::vector{}, - "dtype"_a = float32, + "dtype"_a = std::optional{float32}, "key"_a = none, "stream"_a = none, R"pbdoc( @@ -92,12 +98,14 @@ void init_random(py::module_& parent_module) { m.def( "normal", [](const std::vector& shape, - Dtype type, + std::optional type, const std::optional& key, - StreamOrDevice s) { return normal(shape, type, key, s); }, + StreamOrDevice s) { + return normal(shape, type.value_or(float32), key, s); + }, "shape"_a = std::vector{}, - "dtype"_a = float32, + 
"dtype"_a = std::optional{float32}, "key"_a = none, "stream"_a = none, R"pbdoc( @@ -116,10 +124,11 @@ void init_random(py::module_& parent_module) { [](const ScalarOrArray& low, const ScalarOrArray& high, const std::vector& shape, - Dtype type, + std::optional type, const std::optional& key, StreamOrDevice s) { - return randint(to_array(low), to_array(high), shape, type, key, s); + return randint( + to_array(low), to_array(high), shape, type.value_or(int32), key, s); }, "low"_a, "high"_a, @@ -183,21 +192,22 @@ void init_random(py::module_& parent_module) { [](const ScalarOrArray& lower_, const ScalarOrArray& upper_, const std::optional> shape_, - Dtype dtype, + std::optional type, const std::optional& key, StreamOrDevice s) { auto lower = to_array(lower_); auto upper = to_array(upper_); + auto t = type.value_or(float32); if (shape_.has_value()) { - return truncated_normal(lower, upper, shape_.value(), dtype, key, s); + return truncated_normal(lower, upper, shape_.value(), t, key, s); } else { - return truncated_normal(lower, upper, dtype, key, s); + return truncated_normal(lower, upper, t, key, s); } }, "lower"_a, "upper"_a, "shape"_a = none, - "dtype"_a = float32, + "dtype"_a = std::optional{float32}, "key"_a = none, "stream"_a = none, R"pbdoc( @@ -221,9 +231,14 @@ void init_random(py::module_& parent_module) { )pbdoc"); m.def( "gumbel", - &gumbel, + [](const std::vector& shape, + std::optional type, + const std::optional& key, + StreamOrDevice s) { + return gumbel(shape, type.value_or(float32), key, s); + }, "shape"_a = std::vector{}, - "dtype"_a = float32, + "dtype"_a = std::optional{float32}, "stream"_a = none, "key"_a = none, R"pbdoc( diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index d291ca31e..777a23cbe 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -325,6 +325,8 @@ class TestOps(mlx_tests.MLXTestCase): for shape in [[4], [4, 4], [2, 10]]: for diag in [-1, 0, 1, -2]: self.assertCmpNumpy(shape, mx.tri, np.tri, 
k=diag) + self.assertEqual(mx.tri(1, 1).dtype, mx.float32) + self.assertEqual(mx.tri(1, 1, dtype=mx.bfloat16).dtype, mx.bfloat16) def test_tril(self): for diag in [-1, 0, 1, -2]: diff --git a/python/tests/test_random.py b/python/tests/test_random.py index aa01339f4..c4ca7f62a 100644 --- a/python/tests/test_random.py +++ b/python/tests/test_random.py @@ -61,6 +61,8 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16) self.assertEqual(a.dtype, mx.bfloat16) + self.assertEqual(mx.random.uniform().dtype, mx.random.uniform(dtype=None).dtype) + def test_normal(self): key = mx.random.key(0) a = mx.random.normal(key=key) @@ -78,6 +80,8 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.normal(dtype=t) self.assertEqual(a.dtype, t) + self.assertEqual(mx.random.normal().dtype, mx.random.normal(dtype=None).dtype) + def test_randint(self): a = mx.random.randint(0, 1, []) self.assertEqual(a.shape, []) @@ -109,6 +113,10 @@ class TestRandom(mlx_tests.MLXTestCase): a = mx.random.randint(10, -10, [1000, 1000]) self.assertTrue(mx.all(a == 10).item()) + self.assertEqual( + mx.random.randint(0, 1).dtype, mx.random.randint(0, 1, dtype=None).dtype + ) + def test_bernoulli(self): a = mx.random.bernoulli() self.assertEqual(a.shape, []) @@ -159,6 +167,11 @@ class TestRandom(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): mx.random.truncated_normal(lower, higher) # Bad shape + self.assertEqual( + mx.random.truncated_normal(0, 1).dtype, + mx.random.truncated_normal(0, 1, dtype=None).dtype, + ) + def test_gumbel(self): samples = mx.random.gumbel(shape=(100, 100)) self.assertEqual(samples.shape, [100, 100]) @@ -168,6 +181,10 @@ class TestRandom(mlx_tests.MLXTestCase): # so this test is pretty conservative self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2) + self.assertEqual( + mx.random.gumbel((1, 1)).dtype, mx.random.gumbel((1, 1), dtype=None).dtype + ) + def test_categorical(self): logits = mx.zeros((10, 
20)) self.assertEqual(mx.random.categorical(logits, -1).shape, [10])