mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Ensure linspace always contains start and stop (#1883)
This commit is contained in:

committed by
GitHub

parent
344a29506e
commit
1a2cb72030
14
mlx/ops.cpp
14
mlx/ops.cpp
@@ -230,16 +230,16 @@ array linspace(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (num == 1) {
|
||||
return astype(array({start}), dtype, to_stream(s));
|
||||
return astype(array({start}), dtype, s);
|
||||
}
|
||||
array sequence = arange(0, num, float32, to_stream(s));
|
||||
float step = (stop - start) / (num - 1);
|
||||
array t = divide(arange(0, num, float32, s), array(num - 1, float32), s);
|
||||
array t_bar = subtract(array(1, float32), t, s);
|
||||
return astype(
|
||||
add(multiply(sequence, array(step), to_stream(s)),
|
||||
array(start),
|
||||
to_stream(s)),
|
||||
add(multiply(t_bar, array(start, float32), s),
|
||||
multiply(t, array(stop, float32), s),
|
||||
s),
|
||||
dtype,
|
||||
to_stream(s));
|
||||
s);
|
||||
}
|
||||
|
||||
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
|
Reference in New Issue
Block a user