mirror of https://github.com/ml-explore/mlx.git
parent 81def6ac76
commit ea9090bbc4
@@ -156,6 +156,7 @@ Operations
tril
triu
var
view
where
zeros
zeros_like
@@ -5,7 +5,6 @@
#else
#include <cblas.h>
#endif

#include <cstring>

#include "mlx/array.h"
@@ -590,4 +590,36 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
  }
}

void View::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  auto ibytes = size_of(in.dtype());
  auto obytes = size_of(out.dtype());
  // Conditions under which the input buffer can be shared without a copy
  // (disjunction):
  // - type size is the same
  // - type size is smaller and the last axis is contiguous
  // - the entire array is row contiguous
  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
      in.flags().row_contiguous) {
    auto strides = in.strides();
    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
      strides[i] *= ibytes;
      strides[i] /= obytes;
    }
    out.copy_shared_buffer(
        in, strides, in.flags(), in.data_size() * ibytes / obytes);
  } else {
    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
    copy_inplace(in, tmp, CopyType::General);

    auto flags = out.flags();
    flags.contiguous = true;
    flags.row_contiguous = true;
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
    out.move_shared_buffer(tmp, out.strides(), flags, out.size());
  }
}

} // namespace mlx::core
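The zero-copy branch above only rescales strides: every axis except the last is multiplied by the input element size and divided by the output element size, while the last-axis stride stays 1 and the shape change is handled by the op itself. A minimal Python sketch of that arithmetic (the helper name is illustrative, not part of MLX):

def view_strides(in_strides, ibytes, obytes):
    # Rescale all but the last axis by the byte-size ratio; strides are in elements.
    out = list(in_strides)
    for i in range(len(out) - 1):
        out[i] = out[i] * ibytes // obytes
    return out

# A contiguous (4, 2, 4) int64 array has element strides (8, 4, 1); viewed as
# int32 the last axis doubles to 8 and the outer strides double as well.
print(view_strides([8, 4, 1], ibytes=8, obytes=4))  # [16, 8, 1]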
@@ -422,4 +422,35 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
      "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}

void View::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  auto ibytes = size_of(in.dtype());
  auto obytes = size_of(out.dtype());
  // Conditions under which the input buffer can be shared without a copy
  // (disjunction):
  // - type size is the same
  // - type size is smaller and the last axis is contiguous
  // - the entire array is row contiguous
  if (ibytes == obytes || (obytes < ibytes && in.strides().back() == 1) ||
      in.flags().row_contiguous) {
    auto strides = in.strides();
    for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
      strides[i] *= ibytes;
      strides[i] /= obytes;
    }
    out.copy_shared_buffer(
        in, strides, in.flags(), in.data_size() * ibytes / obytes);
  } else {
    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
    copy_gpu_inplace(in, tmp, CopyType::General, stream());

    auto flags = out.flags();
    flags.contiguous = true;
    flags.row_contiguous = true;
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
    out.move_shared_buffer(tmp, out.strides(), flags, out.size());
  }
}

} // namespace mlx::core
@@ -106,5 +106,6 @@ NO_CPU(Tan)
NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Inverse)
NO_CPU(View)

} // namespace mlx::core
@@ -108,6 +108,7 @@ NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU(View)

namespace fast {
NO_GPU_MULTI(LayerNorm)
mlx/ops.cpp
@@ -4378,4 +4378,32 @@ array operator>>(const array& a, const array& b) {
  return right_shift(a, b);
}

array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
  if (a.dtype() == dtype) {
    return a;
  }
  auto out_shape = a.shape();
  auto ibytes = size_of(a.dtype());
  auto obytes = size_of(dtype);
  if (a.ndim() == 0 && ibytes != obytes) {
    throw std::invalid_argument(
        "[view] Changing the type of a scalar is only allowed"
        " for types with the same size.");
  } else {
    if (ibytes < obytes) {
      if (out_shape.back() % (obytes / ibytes) != 0) {
        throw std::invalid_argument(
            "[view] When viewing as a larger dtype, the size in bytes of the last"
            " axis must be a multiple of the requested type size.");
      }
      out_shape.back() /= (obytes / ibytes);
    } else {
      // Type size ratios are always integers
      out_shape.back() *= (ibytes / obytes);
    }
  }
  return array(
      out_shape, dtype, std::make_shared<View>(to_stream(s), dtype), {a});
}

} // namespace mlx::core
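The shape rule implemented above: viewing as a smaller dtype multiplies the last axis by the size ratio, viewing as a larger dtype divides it (and requires the last axis to divide evenly), and scalars may only change to a same-size dtype. A usage sketch, assuming an MLX build that includes this commit:

import mlx.core as mx

a = mx.zeros((4, 2, 4), dtype=mx.int32)   # 4 bytes per element
print(mx.view(a, mx.int16).shape)          # (4, 2, 8): smaller dtype, last axis grows
print(mx.view(a, mx.int64).shape)          # (4, 2, 2): larger dtype, last axis shrinks

b = mx.zeros((3,), dtype=mx.int32)
# mx.view(b, mx.int64)  # raises: 3 * 4 bytes is not a multiple of 8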
@@ -1300,6 +1300,7 @@ array operator<<(const array& a, const array& b);
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);

array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** @} */

} // namespace mlx::core
@@ -3930,4 +3930,21 @@ std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
  return {{linalg::inv(a, stream())}, {ax}};
}

std::pair<std::vector<array>, std::vector<int>> View::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 1);
  assert(axes.size() == 1);
  return {{view(inputs[0], dtype_, stream())}, axes};
}

void View::print(std::ostream& os) {
  os << "View" << dtype_;
}

bool View::is_equivalent(const Primitive& other) const {
  const View& a_other = static_cast<const View&>(other);
  return (dtype_ == a_other.dtype_);
}

} // namespace mlx::core
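Because view only changes the trailing axis, the vmap rule can simply reapply the op to the batched array and pass the axes through unchanged. A small sketch, assuming mx.vmap's default in_axes=0 and a build that includes this commit:

import mlx.core as mx

a = mx.zeros((3, 4), dtype=mx.int32)
f = mx.vmap(lambda x: x.view(mx.int16))  # maps over axis 0 by default
print(f(a).shape)  # (3, 8): each (4,) int32 row viewed as (8,) int16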
@@ -2021,6 +2021,22 @@ class Uniform : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class View : public UnaryPrimitive {
 public:
  explicit View(Stream stream, Dtype dtype)
      : UnaryPrimitive(stream), dtype_(dtype) {};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  void print(std::ostream& os) override;
  bool is_equivalent(const Primitive& other) const override;

 private:
  Dtype dtype_;
};

class Transpose : public UnaryPrimitive {
 public:
  explicit Transpose(Stream stream, const std::vector<int>& axes)
@@ -1276,5 +1276,14 @@ void init_array(nb::module_& m) {
          },
          nb::kw_only(),
          "stream"_a = nb::none(),
          "See :func:`conj`.");
          "See :func:`conj`.")
      .def(
          "view",
          [](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
            return view(to_array(a), dtype, s);
          },
          "dtype"_a,
          nb::kw_only(),
          "stream"_a = nb::none(),
          "See :func:`view`.");
}
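The .def above exposes view as an array method; it forwards to the same C++ view() op as the module-level binding added further below, so the two spellings are interchangeable. A minimal sketch, assuming a build with this commit:

import mlx.core as mx

a = mx.zeros((2, 4), dtype=mx.int32)
print(a.view(mx.int16).shape, mx.view(a, mx.int16).shape)  # both (2, 8)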
@@ -3159,8 +3159,6 @@ void init_ops(nb::module_& m) {
      R"pbdoc(
        1D convolution over an input with several channels

        Note: Only the default ``groups=1`` is currently supported.

        Args:
            input (array): input array of shape (``N``, ``H``, ``C_in``)
            weight (array): weight array of shape (``C_out``, ``H``, ``C_in``)
@@ -3219,8 +3217,6 @@ void init_ops(nb::module_& m) {
      R"pbdoc(
        2D convolution over an input with several channels

        Note: Only the default ``groups=1`` is currently supported.

        Args:
            input (array): input array of shape ``(N, H, W, C_in)``
            weight (array): weight array of shape ``(C_out, H, W, C_in)``
@@ -3390,11 +3386,6 @@ void init_ops(nb::module_& m) {
      R"pbdoc(
        General convolution over an input with several channels

        .. note::

           * Only 1d and 2d convolutions are supported at the moment
           * the default ``groups=1`` is currently supported.

        Args:
            input (array): Input array of shape ``(N, ..., C_in)``
            weight (array): Weight array of shape ``(C_out, ..., C_in)``
@@ -4356,4 +4347,32 @@ void init_ops(nb::module_& m) {
        Returns:
            array: The bitwise right shift ``a >> b``.
      )pbdoc");
  m.def(
      "view",
      [](const ScalarOrArray& a, const Dtype& dtype, StreamOrDevice s) {
        return view(to_array(a), dtype, s);
      },
      nb::arg(),
      "dtype"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def view(a: Union[scalar, array], dtype: Dtype, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        View the array as a different type.

        The output shape changes along the last axis if the input array's
        type and the input ``dtype`` do not have the same size.

        Note: the view op does not imply that the input and output arrays share
        their underlying data. The view only guarantees that the binary
        representation of each element (or group of elements) is the same.

        Args:
            a (array): Input array or scalar.
            dtype (Dtype): The data type to change to.

        Returns:
            array: The array with the new type.
      )pbdoc");
}
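The docstring's point about binary representation can be checked by viewing float32 values as uint32 to read their IEEE-754 bit patterns (usage sketch, assuming the binding above is available as mx.view):

import mlx.core as mx

x = mx.array([1.0, -2.0], dtype=mx.float32)
print(mx.view(x, mx.uint32))  # [1065353216, 3221225472], i.e. 0x3F800000 and 0xC0000000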
@@ -2333,6 +2333,30 @@ class TestOps(mlx_tests.MLXTestCase):
        out_np = a.conj()
        self.assertTrue(np.array_equal(np.array(out_mlx), out_np))

    def test_view(self):
        a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)
        a_np = np.array(a)

        for t in ["bool_", "int16", "float32", "int64"]:
            out = a.view(getattr(mx, t))
            expected = a_np.view(getattr(np, t))
            self.assertTrue(np.array_equal(out, expected, equal_nan=True))

        # Irregular strides
        a = mx.random.randint(shape=(2, 4), low=-100, high=100)
        a = mx.broadcast_to(a, shape=(4, 2, 4))

        for t in ["bool_", "int16", "float32", "int64"]:
            out = a.view(getattr(mx, t))
            a_out = out.view(mx.int32)
            self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))

        a = mx.random.randint(shape=(4, 4), low=-100, high=100).T
        for t in ["bool_", "int16", "float32", "int64"]:
            out = a.view(getattr(mx, t))
            a_out = out.view(mx.int32)
            self.assertTrue(mx.array_equal(a_out, a, equal_nan=True))


if __name__ == "__main__":
    unittest.main()
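The "irregular strides" cases exercise both branches of the eval kernels: a broadcast array still has a contiguous last axis, so viewing to a smaller dtype can share the buffer, while viewing a transposed or broadcast array to a larger dtype falls back to the copy branch; the round trip through view(mx.int32) checks that either way the values survive reinterpretation.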
@@ -3473,3 +3473,15 @@ TEST_CASE("test trace") {
  auto out4 = trace(in, 0, 1, 2, float32);
  CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
}

TEST_CASE("test view") {
  auto in = array(3);
  CHECK_THROWS(view(in, int64));

  in = array({1, 2, 3});
  CHECK_THROWS(view(in, int64));

  in = array({1, 2, 3, 4}, int64);
  auto out = view(in, int32);
  CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>());
}
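The expected values in this test assume a little-endian host (true of the Apple-silicon and x86 machines MLX targets): each int64 splits into a low int32 word followed by a high word. The same behavior can be reproduced with NumPy:

import numpy as np

a = np.array([1, 2, 3, 4], dtype=np.int64)
print(a.view(np.int32))  # [1 0 2 0 3 0 4 0] on a little-endian machine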