diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 811728307..1662dc788 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -49,6 +49,7 @@ Operations identity less less_equal + linspace load log log2 diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c9bafaac1..c25ac28bf 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -129,6 +129,27 @@ array arange(int stop, StreamOrDevice s /* = {} */) { return arange(0.0, static_cast(stop), 1.0, int32, to_stream(s)); } +array linspace( + double start, + double stop, + int num /* = 50 */, + Dtype dtype /* = float32 */, + StreamOrDevice s /* = {} */) { + if (num < 0) { + std::ostringstream msg; + msg << "[linspace] number of samples, " << num << ", must be non-negative."; + throw std::invalid_argument(msg.str()); + } + array sequence = arange(0, num, float32, to_stream(s)); + float step = (stop - start) / (num - 1); + return astype( + add(multiply(sequence, array(step), to_stream(s)), + array(start), + to_stream(s)), + dtype, + to_stream(s)); +} + array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) { if (dtype == a.dtype()) { return a; diff --git a/mlx/ops.h b/mlx/ops.h index 0b99a282f..50a0dc1eb 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -20,7 +20,7 @@ Stream to_stream(StreamOrDevice s); /** * A 1D array of numbers starting at `start` (optional), - * stopping at stop, stepping by `step` (optional). **/ + * stopping at stop, stepping by `step` (optional). */ array arange( double start, double stop, @@ -37,6 +37,14 @@ array arange(int start, int stop, int step, StreamOrDevice s = {}); array arange(int start, int stop, StreamOrDevice s = {}); array arange(int stop, StreamOrDevice s = {}); +/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */ +array linspace( + double start, + double stop, + int num = 50, + Dtype dtype = float32, + StreamOrDevice s = {}); + /** Convert an array to the given data type. */ array astype(const array& a, Dtype dtype, StreamOrDevice s = {}); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index b16b3ef0f..14b281d82 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1184,6 +1184,32 @@ void init_ops(py::module_& m) { This can lead to unexpected results for example if `start + step` is a fractional value and the `dtype` is integral. )pbdoc"); + m.def( + "linspace", + [](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) { + return linspace( + scalar_to_double(start), scalar_to_double(stop), num, dtype, s); + }, + "start"_a, + "stop"_a, + "num"_a = 50, + "dtype"_a = float32, + "stream"_a = none, + R"pbdoc( + linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array + + Generate ``num`` evenly spaced numbers over interval ``[start, stop]``. + + Args: + start (scalar): Starting value. + stop (scalar): Stopping value. + num (int, optional): Number of samples, defaults to ``50``. + dtype (Dtype, optional): Specifies the data type of the output, + default to ``float32``. + + Returns: + array: The range of values. + )pbdoc"); m.def( "take", [](const array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 5e0ac0062..3f4cef9ed 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1491,6 +1491,27 @@ class TestOps(mlx_tests.MLXTestCase): clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs)) self.assertTrue(np.array_equal(clipped, expected)) + def test_linspace(self): + # Test default num = 50 + a = mx.linspace(0, 1) + expected = mx.array(np.linspace(0, 1)) + self.assertEqualArray(a, expected) + + # Test int32 dtype + b = mx.linspace(0, 10, 5, mx.int64) + expected = mx.array(np.linspace(0, 10, 5, dtype=int)) + self.assertEqualArray(b, expected) + + # Test negative sequence with float start and stop + c = mx.linspace(-2.7, -0.7, 7) + expected = mx.array(np.linspace(-2.7, -0.7, 7)) + self.assertEqualArray(c, expected) + + # Test irrational step size of 1/9 + d = mx.linspace(0, 1, 10) + expected = mx.array(np.linspace(0, 1, 10)) + self.assertEqualArray(d, expected) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 6dccf4bc1..4e87ffc33 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2201,3 +2201,17 @@ TEST_CASE("test clipping with only max") { auto clipped = clip(a, std::nullopt, array(4.0f)); CHECK(array_equal(clipped, expected).item()); } + +TEST_CASE("test linspace") { + auto x = linspace(0, 10, 5); + auto expected = array({0.0f, 2.5f, 5.0f, 7.5f, 10.0f}, {5}); + CHECK(array_equal(x, expected).item()); + + x = linspace(0, 10, 5, int32); + expected = array({0, 2, 5, 7, 10}, {5}); + CHECK(array_equal(x, expected).item()); + + x = linspace(0, 1, 0); + expected = array(std::initializer_list{}, {0}); + CHECK(array_equal(x, expected).item()); +}