Ensure linspace always contains start and stop (#1883)

This commit is contained in:
Angelos Katharopoulos
2025-02-19 13:53:20 -08:00
committed by GitHub
parent 344a29506e
commit 1a2cb72030
2 changed files with 15 additions and 7 deletions

View File

@@ -230,16 +230,16 @@ array linspace(
throw std::invalid_argument(msg.str());
}
if (num == 1) {
return astype(array({start}), dtype, to_stream(s));
return astype(array({start}), dtype, s);
}
array sequence = arange(0, num, float32, to_stream(s));
float step = (stop - start) / (num - 1);
array t = divide(arange(0, num, float32, s), array(num - 1, float32), s);
array t_bar = subtract(array(1, float32), t, s);
return astype(
add(multiply(sequence, array(step), to_stream(s)),
array(start),
to_stream(s)),
add(multiply(t_bar, array(start, float32), s),
multiply(t, array(stop, float32), s),
s),
dtype,
to_stream(s));
s);
}
array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {