mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Shapeless support for zeros/ones_like (#2726)
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
Some checks failed
Nightly Build / build_linux_release (3.10) (push) Has been cancelled
Nightly Build / build_linux_release (3.14) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.10) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.11) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.12) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.13) (push) Has been cancelled
Nightly Build / build_linux_with_tests (3.14) (push) Has been cancelled
Nightly Build / build_mac_release (3.10) (push) Has been cancelled
Nightly Build / build_mac_release (3.13) (push) Has been cancelled
Nightly Build / build_cuda_with_tests (push) Has been cancelled
Nightly Build / build_cuda_release (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (aarch64) (push) Has been cancelled
Nightly Build / Linux Fedora CPP Build (x86_64) (push) Has been cancelled
* shapeless support for zeros/ones_like * Improvements * fix access after moved
This commit is contained in:
32
mlx/ops.cpp
32
mlx/ops.cpp
@@ -280,16 +280,19 @@ array copy(array a, StreamOrDevice s /* = {} */) {
|
||||
{std::move(a)});
|
||||
}
|
||||
|
||||
array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return array(
|
||||
vals.shape(),
|
||||
dtype,
|
||||
std::make_shared<Full>(to_stream(s)),
|
||||
{astype(vals, dtype, s)});
|
||||
}
|
||||
|
||||
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
|
||||
throw std::invalid_argument("[full] Negative dimensions not allowed.");
|
||||
}
|
||||
auto copied_shape = shape; // |shape| will be moved
|
||||
return array(
|
||||
std::move(copied_shape),
|
||||
dtype,
|
||||
std::make_shared<Full>(to_stream(s)),
|
||||
{broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
|
||||
return full_impl(broadcast_to(vals, std::move(shape), s), dtype, s);
|
||||
}
|
||||
|
||||
array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
|
||||
@@ -297,12 +300,25 @@ array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
|
||||
return full(std::move(shape), std::move(vals), dtype, to_stream(s));
|
||||
}
|
||||
|
||||
array full_like(
|
||||
const array& a,
|
||||
array vals,
|
||||
Dtype dtype,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto inputs = broadcast_arrays({a, std::move(vals)}, s);
|
||||
return full_impl(std::move(inputs[1]), dtype, s);
|
||||
}
|
||||
|
||||
array full_like(const array& a, array vals, StreamOrDevice s /* = {} */) {
|
||||
return full_like(a, std::move(vals), a.dtype(), to_stream(s));
|
||||
}
|
||||
|
||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
return full(shape, array(0, dtype), to_stream(s));
|
||||
}
|
||||
|
||||
array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return zeros(a.shape(), a.dtype(), to_stream(s));
|
||||
return full_like(a, 0, a.dtype(), to_stream(s));
|
||||
}
|
||||
|
||||
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
@@ -310,7 +326,7 @@ array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
array ones_like(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return ones(a.shape(), a.dtype(), to_stream(s));
|
||||
return full_like(a, 1, a.dtype(), to_stream(s));
|
||||
}
|
||||
|
||||
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s /* = {} */) {
|
||||
|
||||
11
mlx/ops.h
11
mlx/ops.h
@@ -69,6 +69,17 @@ array full(Shape shape, T val, StreamOrDevice s = {}) {
|
||||
return full(std::move(shape), array(val), to_stream(s));
|
||||
}
|
||||
|
||||
array full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {});
|
||||
array full_like(const array& a, array vals, StreamOrDevice s = {});
|
||||
template <typename T>
|
||||
array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) {
|
||||
return full_like(a, array(val, dtype), dtype, to_stream(s));
|
||||
}
|
||||
template <typename T>
|
||||
array full_like(const array& a, T val, StreamOrDevice s = {}) {
|
||||
return full_like(a, array(val, a.dtype()), to_stream(s));
|
||||
}
|
||||
|
||||
/** Fill an array of the given shape with zeros. */
|
||||
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
||||
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
|
||||
|
||||
@@ -1146,6 +1146,7 @@ class Full : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_NAME(Full)
|
||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
};
|
||||
|
||||
class Gather : public UnaryPrimitive {
|
||||
|
||||
Reference in New Issue
Block a user