mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Added linspace (#181)
* linspace ops support --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
21
mlx/ops.cpp
21
mlx/ops.cpp
@@ -129,6 +129,27 @@ array arange(int stop, StreamOrDevice s /* = {} */) {
|
||||
return arange(0.0, static_cast<double>(stop), 1.0, int32, to_stream(s));
|
||||
}
|
||||
|
||||
array linspace(
|
||||
double start,
|
||||
double stop,
|
||||
int num /* = 50 */,
|
||||
Dtype dtype /* = float32 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (num < 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linspace] number of samples, " << num << ", must be non-negative.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
array sequence = arange(0, num, float32, to_stream(s));
|
||||
float step = (stop - start) / (num - 1);
|
||||
return astype(
|
||||
add(multiply(sequence, array(step), to_stream(s)),
|
||||
array(start),
|
||||
to_stream(s)),
|
||||
dtype,
|
||||
to_stream(s));
|
||||
}
|
||||
|
||||
array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (dtype == a.dtype()) {
|
||||
return a;
|
||||
|
10
mlx/ops.h
10
mlx/ops.h
@@ -20,7 +20,7 @@ Stream to_stream(StreamOrDevice s);
|
||||
|
||||
/**
|
||||
* A 1D array of numbers starting at `start` (optional),
|
||||
* stopping at stop, stepping by `step` (optional). **/
|
||||
* stopping at stop, stepping by `step` (optional). */
|
||||
array arange(
|
||||
double start,
|
||||
double stop,
|
||||
@@ -37,6 +37,14 @@ array arange(int start, int stop, int step, StreamOrDevice s = {});
|
||||
array arange(int start, int stop, StreamOrDevice s = {});
|
||||
array arange(int stop, StreamOrDevice s = {});
|
||||
|
||||
/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */
|
||||
array linspace(
|
||||
double start,
|
||||
double stop,
|
||||
int num = 50,
|
||||
Dtype dtype = float32,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Convert an array to the given data type. */
|
||||
array astype(const array& a, Dtype dtype, StreamOrDevice s = {});
|
||||
|
||||
|
Reference in New Issue
Block a user