contiguous op / prim (#1612)

This commit is contained in:
Awni Hannun 2024-11-21 19:51:49 -08:00 committed by GitHub
parent 0d5e7716ad
commit dcca0d7477
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 104 additions and 25 deletions

View File

@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  // When the input already has an acceptable layout (row-contiguous, or
  // col-contiguous when allowed), alias its buffer instead of copying.
  const bool layout_ok = in.flags().row_contiguous ||
      (allow_col_major_ && in.flags().col_contiguous);
  if (layout_ok) {
    out.copy_shared_buffer(in);
  } else {
    // Otherwise materialize a dense row-major copy on the CPU.
    copy(in, out, CopyType::General);
  }
}
void Cos::eval(const std::vector<array>& inputs, array& out) { void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];

View File

@ -170,6 +170,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream()); concatenate_gpu(inputs, out, axis_, stream());
} }
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  // No kernel launch is needed if the layout is already acceptable:
  // row-contiguous always qualifies, col-contiguous only when allowed.
  const bool no_copy_needed = in.flags().row_contiguous ||
      (allow_col_major_ && in.flags().col_contiguous);
  if (no_copy_needed) {
    move_or_copy(in, out);
  } else {
    // Fall back to a general GPU copy to densify the data.
    copy_gpu(in, out, CopyType::General);
  }
}
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) { void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out); eval(inputs, out);
} }

View File

@ -39,6 +39,7 @@ NO_CPU(Ceil)
NO_CPU(Cholesky) NO_CPU(Cholesky)
NO_CPU(Concatenate) NO_CPU(Concatenate)
NO_CPU(Conjugate) NO_CPU(Conjugate)
NO_CPU(Contiguous)
NO_CPU(Convolution) NO_CPU(Convolution)
NO_CPU(Copy) NO_CPU(Copy)
NO_CPU(Cos) NO_CPU(Cos)

View File

@ -40,6 +40,7 @@ NO_GPU(Ceil)
NO_GPU_MULTI(Compiled) NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate) NO_GPU(Concatenate)
NO_GPU(Conjugate) NO_GPU(Conjugate)
NO_GPU(Contiguous)
NO_GPU(Convolution) NO_GPU(Convolution)
NO_GPU(Copy) NO_GPU(Copy)
NO_GPU(Cos) NO_GPU(Cos)

View File

@ -107,22 +107,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check array // Check array
a = contiguous(a, true);
a.eval(); a.eval();
if (a.nbytes() == 0) { if (a.nbytes() == 0) {
throw std::invalid_argument("[save] cannot serialize an empty array"); throw std::invalid_argument("[save] cannot serialize an empty array");
} }
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
a = reshape(flatten(a), a.shape());
a.eval();
}
// Check once more in-case the above ops change
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
throw std::invalid_argument(
"[save] can only serialize row or col contiguous arrays");
}
//////////////////////////////////////////////////////// ////////////////////////////////////////////////////////
// Check file // Check file
if (!out_stream->good() || !out_stream->is_open()) { if (!out_stream->good() || !out_stream->is_open()) {

View File

@ -179,8 +179,9 @@ void save_safetensors(
{ {
std::vector<array> to_eval; std::vector<array> to_eval;
to_eval.reserve(a.size()); to_eval.reserve(a.size());
for (auto& [_, arr] : a) { for (auto& p : a) {
to_eval.push_back(arr); p.second = contiguous(p.second);
to_eval.push_back(p.second);
} }
eval(std::move(to_eval)); eval(std::move(to_eval));
} }
@ -192,19 +193,6 @@ void save_safetensors(
"[save_safetensors] cannot serialize an empty array key: " + key); "[save_safetensors] cannot serialize an empty array key: " + key);
} }
// Try to make it row contiguous
if (!arr.flags().row_contiguous) {
arr = reshape(flatten(arr), arr.shape());
arr.eval();
}
// Has to be row-major now but, check one more time in case
// any of the above change in the future
if (!arr.flags().row_contiguous) {
throw std::invalid_argument(
"[save_safetensors] can only serialize row-major arrays");
}
json child; json child;
child["dtype"] = dtype_to_safetensor_str(arr.dtype()); child["dtype"] = dtype_to_safetensor_str(arr.dtype());
child["shape"] = arr.shape(); child["shape"] = arr.shape();

View File

@ -4700,4 +4700,15 @@ array imag(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a}); return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
} }
array contiguous(
    const array& a,
    bool allow_col_major /* = false */,
    StreamOrDevice s /* = {} */) {
  // Shape and dtype are preserved; only the memory layout may change.
  auto prim = std::make_shared<Contiguous>(to_stream(s), allow_col_major);
  return array(a.shape(), a.dtype(), std::move(prim), {a});
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -1486,6 +1486,12 @@ array real(const array& a, StreamOrDevice s = {});
/* The imaginary part of a complex array. */ /* The imaginary part of a complex array. */
array imag(const array& a, StreamOrDevice s = {}); array imag(const array& a, StreamOrDevice s = {});
/* Ensure the array's underlying memory is contiguous. A row-contiguous
 * result is produced unless `allow_col_major` is true, in which case a
 * column-contiguous input may be accepted as-is. */
array contiguous(
const array& a,
bool allow_col_major = false,
StreamOrDevice s = {});
/** @} */ /** @} */
} // namespace mlx::core } // namespace mlx::core

View File

@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
return {{conjugate(inputs[0], stream())}, axes}; return {{conjugate(inputs[0], stream())}, axes};
} }
std::vector<array> Contiguous::vjp(
    const std::vector<array>&,
    const std::vector<array>& cotangents,
    const std::vector<int>&,
    const std::vector<array>&) {
  // Contiguous only changes memory layout, so the gradient passes
  // through unchanged (identity VJP).
  return cotangents;
}
std::vector<array> Contiguous::jvp(
    const std::vector<array>&,
    const std::vector<array>& tangents,
    const std::vector<int>&) {
  // Identity JVP: the op is a layout-only transformation.
  return tangents;
}
std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // Batched version is the same op applied to the batched input; the
  // mapped axes are unchanged.
  auto out = contiguous(inputs[0], allow_col_major_, stream());
  return {{std::move(out)}, axes};
}
bool Contiguous::is_equivalent(const Primitive& other) const {
const Contiguous& c_other = static_cast<const Contiguous&>(other);
return allow_col_major_ == c_other.allow_col_major_;
}
array conv_weight_backward_patches( array conv_weight_backward_patches(
const array& in, const array& in,
const array& wt, const array& wt,

View File

@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
// Primitive that materializes its single input into contiguous memory.
// With `allow_col_major` set, a column-contiguous input is also accepted
// without copying; otherwise only row-contiguous inputs avoid a copy
// (see eval_cpu/eval_gpu).
class Contiguous : public UnaryPrimitive {
public:
explicit Contiguous(Stream stream, bool allow_col_major)
: UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Contiguous)
DEFINE_INPUT_OUTPUT_SHAPE()
bool is_equivalent(const Primitive& other) const override;
private:
// Whether a col-contiguous input counts as already contiguous.
bool allow_col_major_;
};
class Convolution : public UnaryPrimitive { class Convolution : public UnaryPrimitive {
public: public:
explicit Convolution( explicit Convolution(

View File

@ -3747,3 +3747,17 @@ TEST_CASE("test roll") {
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5})) CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
.item<bool>()); .item<bool>());
} }
// Exercises the `contiguous` op:
//  1) a broadcast (zero-stride) input must be densified row-contiguous;
//  2) with allow_col_major=true a transposed input may remain
//     column-contiguous (no copy to row-major required).
TEST_CASE("test contiguous") {
auto x = array({1, 2, 3});
// Broadcasting yields non-contiguous strides; contiguous() must copy.
x = contiguous(broadcast_to(x, {2, 2, 3}));
eval(x);
CHECK(x.flags().row_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){6, 3, 1});
x = array({1, 2, 1, 2}, {2, 2});
// Transpose of a row-major 2x2 is col-contiguous; allowed as-is here.
x = contiguous(transpose(x), true);
eval(x);
CHECK(x.flags().col_contiguous);
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
}