mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Added linspace (#181)
* linspace ops support --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
f4f6e17d45
commit
e6872a4149
@ -49,6 +49,7 @@ Operations
|
|||||||
identity
|
identity
|
||||||
less
|
less
|
||||||
less_equal
|
less_equal
|
||||||
|
linspace
|
||||||
load
|
load
|
||||||
log
|
log
|
||||||
log2
|
log2
|
||||||
|
21
mlx/ops.cpp
21
mlx/ops.cpp
@ -129,6 +129,27 @@ array arange(int stop, StreamOrDevice s /* = {} */) {
|
|||||||
return arange(0.0, static_cast<double>(stop), 1.0, int32, to_stream(s));
|
return arange(0.0, static_cast<double>(stop), 1.0, int32, to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array linspace(
|
||||||
|
double start,
|
||||||
|
double stop,
|
||||||
|
int num /* = 50 */,
|
||||||
|
Dtype dtype /* = float32 */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (num < 0) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[linspace] number of samples, " << num << ", must be non-negative.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
array sequence = arange(0, num, float32, to_stream(s));
|
||||||
|
float step = (stop - start) / (num - 1);
|
||||||
|
return astype(
|
||||||
|
add(multiply(sequence, array(step), to_stream(s)),
|
||||||
|
array(start),
|
||||||
|
to_stream(s)),
|
||||||
|
dtype,
|
||||||
|
to_stream(s));
|
||||||
|
}
|
||||||
|
|
||||||
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||||
if (dtype == a.dtype()) {
|
if (dtype == a.dtype()) {
|
||||||
return a;
|
return a;
|
||||||
|
10
mlx/ops.h
10
mlx/ops.h
@ -20,7 +20,7 @@ Stream to_stream(StreamOrDevice s);
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* A 1D array of numbers starting at `start` (optional),
|
* A 1D array of numbers starting at `start` (optional),
|
||||||
* stopping at stop, stepping by `step` (optional). **/
|
* stopping at stop, stepping by `step` (optional). */
|
||||||
array arange(
|
array arange(
|
||||||
double start,
|
double start,
|
||||||
double stop,
|
double stop,
|
||||||
@ -37,6 +37,14 @@ array arange(int start, int stop, int step, StreamOrDevice s = {});
|
|||||||
array arange(int start, int stop, StreamOrDevice s = {});
|
array arange(int start, int stop, StreamOrDevice s = {});
|
||||||
array arange(int stop, StreamOrDevice s = {});
|
array arange(int stop, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */
|
||||||
|
array linspace(
|
||||||
|
double start,
|
||||||
|
double stop,
|
||||||
|
int num = 50,
|
||||||
|
Dtype dtype = float32,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Convert an array to the given data type. */
|
/** Convert an array to the given data type. */
|
||||||
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
|
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
@ -1184,6 +1184,32 @@ void init_ops(py::module_& m) {
|
|||||||
This can lead to unexpected results for example if `start + step`
|
This can lead to unexpected results for example if `start + step`
|
||||||
is a fractional value and the `dtype` is integral.
|
is a fractional value and the `dtype` is integral.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"linspace",
|
||||||
|
[](Scalar start, Scalar stop, int num, Dtype dtype, StreamOrDevice s) {
|
||||||
|
return linspace(
|
||||||
|
scalar_to_double(start), scalar_to_double(stop), num, dtype, s);
|
||||||
|
},
|
||||||
|
"start"_a,
|
||||||
|
"stop"_a,
|
||||||
|
"num"_a = 50,
|
||||||
|
"dtype"_a = float32,
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
linspace(start, stop, num: Optional[int] = 50, dtype: Optional[Dtype] = float32, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Generate ``num`` evenly spaced numbers over interval ``[start, stop]``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start (scalar): Starting value.
|
||||||
|
stop (scalar): Stopping value.
|
||||||
|
num (int, optional): Number of samples, defaults to ``50``.
|
||||||
|
dtype (Dtype, optional): Specifies the data type of the output,
|
||||||
|
default to ``float32``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The range of values.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"take",
|
"take",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
|
@ -1491,6 +1491,27 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||||
self.assertTrue(np.array_equal(clipped, expected))
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
def test_linspace(self):
|
||||||
|
# Test default num = 50
|
||||||
|
a = mx.linspace(0, 1)
|
||||||
|
expected = mx.array(np.linspace(0, 1))
|
||||||
|
self.assertEqualArray(a, expected)
|
||||||
|
|
||||||
|
# Test int32 dtype
|
||||||
|
b = mx.linspace(0, 10, 5, mx.int64)
|
||||||
|
expected = mx.array(np.linspace(0, 10, 5, dtype=int))
|
||||||
|
self.assertEqualArray(b, expected)
|
||||||
|
|
||||||
|
# Test negative sequence with float start and stop
|
||||||
|
c = mx.linspace(-2.7, -0.7, 7)
|
||||||
|
expected = mx.array(np.linspace(-2.7, -0.7, 7))
|
||||||
|
self.assertEqualArray(c, expected)
|
||||||
|
|
||||||
|
# Test irrational step size of 1/9
|
||||||
|
d = mx.linspace(0, 1, 10)
|
||||||
|
expected = mx.array(np.linspace(0, 1, 10))
|
||||||
|
self.assertEqualArray(d, expected)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2201,3 +2201,17 @@ TEST_CASE("test clipping with only max") {
|
|||||||
auto clipped = clip(a, std::nullopt, array(4.0f));
|
auto clipped = clip(a, std::nullopt, array(4.0f));
|
||||||
CHECK(array_equal(clipped, expected).item<bool>());
|
CHECK(array_equal(clipped, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test linspace") {
|
||||||
|
auto x = linspace(0, 10, 5);
|
||||||
|
auto expected = array({0.0f, 2.5f, 5.0f, 7.5f, 10.0f}, {5});
|
||||||
|
CHECK(array_equal(x, expected).item<bool>());
|
||||||
|
|
||||||
|
x = linspace(0, 10, 5, int32);
|
||||||
|
expected = array({0, 2, 5, 7, 10}, {5});
|
||||||
|
CHECK(array_equal(x, expected).item<bool>());
|
||||||
|
|
||||||
|
x = linspace(0, 1, 0);
|
||||||
|
expected = array(std::initializer_list<float>{}, {0});
|
||||||
|
CHECK(array_equal(x, expected).item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user