contiguous op / prim (#1612)

This commit is contained in:
Awni Hannun 2024-11-21 19:51:49 -08:00 committed by GitHub
parent 0d5e7716ad
commit dcca0d7477
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 104 additions and 25 deletions

View File

@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
}
}
// CPU evaluation: alias the input buffer when it already satisfies the
// requested layout, otherwise materialize a dense copy.
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& input = inputs[0];
  const bool layout_ok = input.flags().row_contiguous ||
      (allow_col_major_ && input.flags().col_contiguous);
  if (!layout_ok) {
    // Strided / broadcast input: gather into freshly allocated storage.
    copy(input, out, CopyType::General);
    return;
  }
  // Already contiguous: share the buffer, no data movement.
  out.copy_shared_buffer(input);
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@ -170,6 +170,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}
// GPU evaluation: mirror of the CPU path — pass the buffer through when the
// layout already qualifies, otherwise launch a general copy kernel.
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& input = inputs[0];
  const bool layout_ok = input.flags().row_contiguous ||
      (allow_col_major_ && input.flags().col_contiguous);
  if (!layout_ok) {
    // Densify on device.
    copy_gpu(input, out, CopyType::General);
    return;
  }
  move_or_copy(input, out);
}
// GPU evaluation of Copy simply delegates to the device-agnostic eval path.
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}

View File

@ -39,6 +39,7 @@ NO_CPU(Ceil)
NO_CPU(Cholesky)
NO_CPU(Concatenate)
NO_CPU(Conjugate)
NO_CPU(Contiguous)
NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)

View File

@ -40,6 +40,7 @@ NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)
NO_GPU(Conjugate)
NO_GPU(Contiguous)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)

View File

@ -107,22 +107,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
////////////////////////////////////////////////////////
// Check array
a = contiguous(a, true);
a.eval();
if (a.nbytes() == 0) {
throw std::invalid_argument("[save] cannot serialize an empty array");
}
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
a = reshape(flatten(a), a.shape());
a.eval();
}
// Check once more in-case the above ops change
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
throw std::invalid_argument(
"[save] can only serialize row or col contiguous arrays");
}
////////////////////////////////////////////////////////
// Check file
if (!out_stream->good() || !out_stream->is_open()) {

View File

@ -179,8 +179,9 @@ void save_safetensors(
{
std::vector<array> to_eval;
to_eval.reserve(a.size());
for (auto& [_, arr] : a) {
to_eval.push_back(arr);
for (auto& p : a) {
p.second = contiguous(p.second);
to_eval.push_back(p.second);
}
eval(std::move(to_eval));
}
@ -192,19 +193,6 @@ void save_safetensors(
"[save_safetensors] cannot serialize an empty array key: " + key);
}
// Try to make it row contiguous
if (!arr.flags().row_contiguous) {
arr = reshape(flatten(arr), arr.shape());
arr.eval();
}
// Has to be row-major now but, check one more time in case
// any of the above change in the future
if (!arr.flags().row_contiguous) {
throw std::invalid_argument(
"[save_safetensors] can only serialize row-major arrays");
}
json child;
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
child["shape"] = arr.shape();

View File

@ -4700,4 +4700,15 @@ array imag(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
}
// Build a Contiguous op over `a`; the output has the same shape and dtype.
// `allow_col_major` lets a column-contiguous input pass through unchanged.
array contiguous(
    const array& a,
    bool allow_col_major /* = false */,
    StreamOrDevice s /* = {} */) {
  auto prim = std::make_shared<Contiguous>(to_stream(s), allow_col_major);
  return array(a.shape(), a.dtype(), std::move(prim), {a});
}
} // namespace mlx::core

View File

@ -1486,6 +1486,12 @@ array real(const array& a, StreamOrDevice s = {});
/* The imaginary part of a complex array. */
array imag(const array& a, StreamOrDevice s = {});
/* Ensure the array's underlying memory is contiguous. */
array contiguous(
const array& a,
bool allow_col_major = false,
StreamOrDevice s = {});
/** @} */
} // namespace mlx::core

View File

@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
return {{conjugate(inputs[0], stream())}, axes};
}
std::vector<array> Contiguous::vjp(
    const std::vector<array>&,
    const std::vector<array>& cotangents,
    const std::vector<int>&,
    const std::vector<array>&) {
  // Contiguous is the identity on values, so cotangents flow through as-is.
  return cotangents;
}
std::vector<array> Contiguous::jvp(
    const std::vector<array>&,
    const std::vector<array>& tangents,
    const std::vector<int>&) {
  // Identity on values: the forward-mode tangent is unchanged.
  return tangents;
}
std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // Batching does not interact with contiguity: apply the op to the batched
  // input and keep the mapped axes untouched.
  auto out = contiguous(inputs[0], allow_col_major_, stream());
  return {{out}, axes};
}
bool Contiguous::is_equivalent(const Primitive& other) const {
const Contiguous& c_other = static_cast<const Contiguous&>(other);
return allow_col_major_ == c_other.allow_col_major_;
}
array conv_weight_backward_patches(
const array& in,
const array& wt,

View File

@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
// Primitive that guarantees its output occupies contiguous memory. A
// row-contiguous input is always passed through without copying; when
// `allow_col_major` is set, a column-contiguous input is accepted as-is
// too. Any other layout is densified with a general copy.
class Contiguous : public UnaryPrimitive {
 public:
  explicit Contiguous(Stream stream, bool allow_col_major)
      : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(Contiguous)
  DEFINE_INPUT_OUTPUT_SHAPE()

  bool is_equivalent(const Primitive& other) const override;

 private:
  // When true, col-contiguous inputs avoid the copy in eval_cpu/eval_gpu.
  bool allow_col_major_;
};
class Convolution : public UnaryPrimitive {
public:
explicit Convolution(

View File

@ -3747,3 +3747,17 @@ TEST_CASE("test roll") {
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
.item<bool>());
}
TEST_CASE("test contiguous") {
  // A broadcast array has repeated (stride-0) axes; contiguous() must
  // produce a dense row-major result.
  auto a = contiguous(broadcast_to(array({1, 2, 3}), {2, 2, 3}));
  eval(a);
  CHECK(a.flags().row_contiguous);
  CHECK_EQ(a.strides(), decltype(a.strides()){6, 3, 1});

  // With allow_col_major=true a transposed array may be kept in column
  // order instead of being copied to row order.
  auto b = contiguous(transpose(array({1, 2, 1, 2}, {2, 2})), true);
  eval(b);
  CHECK(b.flags().col_contiguous);
  CHECK_EQ(b.strides(), decltype(b.strides()){1, 2});
}