mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
contiguous op / prim (#1612)
This commit is contained in:
parent
0d5e7716ad
commit
dcca0d7477
@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy(in, out, CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
|
@ -170,6 +170,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
concatenate_gpu(inputs, out, axis_, stream());
|
||||
}
|
||||
|
||||
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
move_or_copy(in, out);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
}
|
||||
}
|
||||
|
||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ NO_CPU(Ceil)
|
||||
NO_CPU(Cholesky)
|
||||
NO_CPU(Concatenate)
|
||||
NO_CPU(Conjugate)
|
||||
NO_CPU(Contiguous)
|
||||
NO_CPU(Convolution)
|
||||
NO_CPU(Copy)
|
||||
NO_CPU(Cos)
|
||||
|
@ -40,6 +40,7 @@ NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Concatenate)
|
||||
NO_GPU(Conjugate)
|
||||
NO_GPU(Contiguous)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
NO_GPU(Cos)
|
||||
|
@ -107,22 +107,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
||||
////////////////////////////////////////////////////////
|
||||
// Check array
|
||||
|
||||
a = contiguous(a, true);
|
||||
a.eval();
|
||||
|
||||
if (a.nbytes() == 0) {
|
||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||
}
|
||||
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
a = reshape(flatten(a), a.shape());
|
||||
a.eval();
|
||||
}
|
||||
// Check once more in-case the above ops change
|
||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
||||
throw std::invalid_argument(
|
||||
"[save] can only serialize row or col contiguous arrays");
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////
|
||||
// Check file
|
||||
if (!out_stream->good() || !out_stream->is_open()) {
|
||||
|
@ -179,8 +179,9 @@ void save_safetensors(
|
||||
{
|
||||
std::vector<array> to_eval;
|
||||
to_eval.reserve(a.size());
|
||||
for (auto& [_, arr] : a) {
|
||||
to_eval.push_back(arr);
|
||||
for (auto& p : a) {
|
||||
p.second = contiguous(p.second);
|
||||
to_eval.push_back(p.second);
|
||||
}
|
||||
eval(std::move(to_eval));
|
||||
}
|
||||
@ -192,19 +193,6 @@ void save_safetensors(
|
||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||
}
|
||||
|
||||
// Try to make it row contiguous
|
||||
if (!arr.flags().row_contiguous) {
|
||||
arr = reshape(flatten(arr), arr.shape());
|
||||
arr.eval();
|
||||
}
|
||||
|
||||
// Has to be row-major now but, check one more time in case
|
||||
// any of the above change in the future
|
||||
if (!arr.flags().row_contiguous) {
|
||||
throw std::invalid_argument(
|
||||
"[save_safetensors] can only serialize row-major arrays");
|
||||
}
|
||||
|
||||
json child;
|
||||
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||
child["shape"] = arr.shape();
|
||||
|
11
mlx/ops.cpp
11
mlx/ops.cpp
@ -4700,4 +4700,15 @@ array imag(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
|
||||
}
|
||||
|
||||
array contiguous(
|
||||
const array& a,
|
||||
bool allow_col_major /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return array(
|
||||
a.shape(),
|
||||
a.dtype(),
|
||||
std::make_shared<Contiguous>(to_stream(s), allow_col_major),
|
||||
{a});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -1486,6 +1486,12 @@ array real(const array& a, StreamOrDevice s = {});
|
||||
/* The imaginary part of a complex array. */
|
||||
array imag(const array& a, StreamOrDevice s = {});
|
||||
|
||||
/* Ensure the array's underlying memory is contiguous. */
|
||||
array contiguous(
|
||||
const array& a,
|
||||
bool allow_col_major = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** @} */
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
||||
return {{conjugate(inputs[0], stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<array> Contiguous::vjp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
return {cotangents};
|
||||
}
|
||||
|
||||
std::vector<array> Contiguous::jvp(
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>&) {
|
||||
return {tangents};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
return {{contiguous(inputs[0], allow_col_major_, stream())}, axes};
|
||||
}
|
||||
|
||||
bool Contiguous::is_equivalent(const Primitive& other) const {
|
||||
const Contiguous& c_other = static_cast<const Contiguous&>(other);
|
||||
return allow_col_major_ == c_other.allow_col_major_;
|
||||
}
|
||||
|
||||
array conv_weight_backward_patches(
|
||||
const array& in,
|
||||
const array& wt,
|
||||
|
@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
||||
class Contiguous : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Contiguous(Stream stream, bool allow_col_major)
|
||||
: UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
|
||||
DEFINE_VMAP()
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(Contiguous)
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
bool allow_col_major_;
|
||||
};
|
||||
|
||||
class Convolution : public UnaryPrimitive {
|
||||
public:
|
||||
explicit Convolution(
|
||||
|
@ -3747,3 +3747,17 @@ TEST_CASE("test roll") {
|
||||
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test contiguous") {
|
||||
auto x = array({1, 2, 3});
|
||||
x = contiguous(broadcast_to(x, {2, 2, 3}));
|
||||
eval(x);
|
||||
CHECK(x.flags().row_contiguous);
|
||||
CHECK_EQ(x.strides(), decltype(x.strides()){6, 3, 1});
|
||||
|
||||
x = array({1, 2, 1, 2}, {2, 2});
|
||||
x = contiguous(transpose(x), true);
|
||||
eval(x);
|
||||
CHECK(x.flags().col_contiguous);
|
||||
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user