mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
contiguous op / prim (#1612)
This commit is contained in:
parent
0d5e7716ad
commit
dcca0d7477
@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.flags().row_contiguous ||
|
||||||
|
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
copy(in, out, CopyType::General);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
void Cos::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
const auto& in = inputs[0];
|
const auto& in = inputs[0];
|
||||||
|
@ -170,6 +170,17 @@ void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
concatenate_gpu(inputs, out, axis_, stream());
|
concatenate_gpu(inputs, out, axis_, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
auto& in = inputs[0];
|
||||||
|
if (in.flags().row_contiguous ||
|
||||||
|
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||||
|
move_or_copy(in, out);
|
||||||
|
} else {
|
||||||
|
copy_gpu(in, out, CopyType::General);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,7 @@ NO_CPU(Ceil)
|
|||||||
NO_CPU(Cholesky)
|
NO_CPU(Cholesky)
|
||||||
NO_CPU(Concatenate)
|
NO_CPU(Concatenate)
|
||||||
NO_CPU(Conjugate)
|
NO_CPU(Conjugate)
|
||||||
|
NO_CPU(Contiguous)
|
||||||
NO_CPU(Convolution)
|
NO_CPU(Convolution)
|
||||||
NO_CPU(Copy)
|
NO_CPU(Copy)
|
||||||
NO_CPU(Cos)
|
NO_CPU(Cos)
|
||||||
|
@ -40,6 +40,7 @@ NO_GPU(Ceil)
|
|||||||
NO_GPU_MULTI(Compiled)
|
NO_GPU_MULTI(Compiled)
|
||||||
NO_GPU(Concatenate)
|
NO_GPU(Concatenate)
|
||||||
NO_GPU(Conjugate)
|
NO_GPU(Conjugate)
|
||||||
|
NO_GPU(Contiguous)
|
||||||
NO_GPU(Convolution)
|
NO_GPU(Convolution)
|
||||||
NO_GPU(Copy)
|
NO_GPU(Copy)
|
||||||
NO_GPU(Cos)
|
NO_GPU(Cos)
|
||||||
|
@ -107,22 +107,13 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
|
|||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check array
|
// Check array
|
||||||
|
|
||||||
|
a = contiguous(a, true);
|
||||||
a.eval();
|
a.eval();
|
||||||
|
|
||||||
if (a.nbytes() == 0) {
|
if (a.nbytes() == 0) {
|
||||||
throw std::invalid_argument("[save] cannot serialize an empty array");
|
throw std::invalid_argument("[save] cannot serialize an empty array");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
|
||||||
a = reshape(flatten(a), a.shape());
|
|
||||||
a.eval();
|
|
||||||
}
|
|
||||||
// Check once more in-case the above ops change
|
|
||||||
if (!(a.flags().row_contiguous || a.flags().col_contiguous)) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[save] can only serialize row or col contiguous arrays");
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////
|
||||||
// Check file
|
// Check file
|
||||||
if (!out_stream->good() || !out_stream->is_open()) {
|
if (!out_stream->good() || !out_stream->is_open()) {
|
||||||
|
@ -179,8 +179,9 @@ void save_safetensors(
|
|||||||
{
|
{
|
||||||
std::vector<array> to_eval;
|
std::vector<array> to_eval;
|
||||||
to_eval.reserve(a.size());
|
to_eval.reserve(a.size());
|
||||||
for (auto& [_, arr] : a) {
|
for (auto& p : a) {
|
||||||
to_eval.push_back(arr);
|
p.second = contiguous(p.second);
|
||||||
|
to_eval.push_back(p.second);
|
||||||
}
|
}
|
||||||
eval(std::move(to_eval));
|
eval(std::move(to_eval));
|
||||||
}
|
}
|
||||||
@ -192,19 +193,6 @@ void save_safetensors(
|
|||||||
"[save_safetensors] cannot serialize an empty array key: " + key);
|
"[save_safetensors] cannot serialize an empty array key: " + key);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to make it row contiguous
|
|
||||||
if (!arr.flags().row_contiguous) {
|
|
||||||
arr = reshape(flatten(arr), arr.shape());
|
|
||||||
arr.eval();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Has to be row-major now but, check one more time in case
|
|
||||||
// any of the above change in the future
|
|
||||||
if (!arr.flags().row_contiguous) {
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[save_safetensors] can only serialize row-major arrays");
|
|
||||||
}
|
|
||||||
|
|
||||||
json child;
|
json child;
|
||||||
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
child["dtype"] = dtype_to_safetensor_str(arr.dtype());
|
||||||
child["shape"] = arr.shape();
|
child["shape"] = arr.shape();
|
||||||
|
11
mlx/ops.cpp
11
mlx/ops.cpp
@ -4700,4 +4700,15 @@ array imag(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
|
return array(a.shape(), float32, std::make_shared<Imag>(to_stream(s)), {a});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array contiguous(
|
||||||
|
const array& a,
|
||||||
|
bool allow_col_major /* = false */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
return array(
|
||||||
|
a.shape(),
|
||||||
|
a.dtype(),
|
||||||
|
std::make_shared<Contiguous>(to_stream(s), allow_col_major),
|
||||||
|
{a});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -1486,6 +1486,12 @@ array real(const array& a, StreamOrDevice s = {});
|
|||||||
/* The imaginary part of a complex array. */
|
/* The imaginary part of a complex array. */
|
||||||
array imag(const array& a, StreamOrDevice s = {});
|
array imag(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/* Ensure the array's underlying memory is contiguous. */
|
||||||
|
array contiguous(
|
||||||
|
const array& a,
|
||||||
|
bool allow_col_major = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** @} */
|
/** @} */
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -889,6 +889,32 @@ std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
|||||||
return {{conjugate(inputs[0], stream())}, axes};
|
return {{conjugate(inputs[0], stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<array> Contiguous::vjp(
|
||||||
|
const std::vector<array>&,
|
||||||
|
const std::vector<array>& cotangents,
|
||||||
|
const std::vector<int>&,
|
||||||
|
const std::vector<array>&) {
|
||||||
|
return {cotangents};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> Contiguous::jvp(
|
||||||
|
const std::vector<array>&,
|
||||||
|
const std::vector<array>& tangents,
|
||||||
|
const std::vector<int>&) {
|
||||||
|
return {tangents};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Contiguous::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
return {{contiguous(inputs[0], allow_col_major_, stream())}, axes};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Contiguous::is_equivalent(const Primitive& other) const {
|
||||||
|
const Contiguous& c_other = static_cast<const Contiguous&>(other);
|
||||||
|
return allow_col_major_ == c_other.allow_col_major_;
|
||||||
|
}
|
||||||
|
|
||||||
array conv_weight_backward_patches(
|
array conv_weight_backward_patches(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& wt,
|
const array& wt,
|
||||||
|
@ -639,6 +639,25 @@ class Conjugate : public UnaryPrimitive {
|
|||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Contiguous : public UnaryPrimitive {
|
||||||
|
public:
|
||||||
|
explicit Contiguous(Stream stream, bool allow_col_major)
|
||||||
|
: UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
|
||||||
|
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
|
|
||||||
|
DEFINE_VMAP()
|
||||||
|
DEFINE_GRADS()
|
||||||
|
DEFINE_PRINT(Contiguous)
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool allow_col_major_;
|
||||||
|
};
|
||||||
|
|
||||||
class Convolution : public UnaryPrimitive {
|
class Convolution : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Convolution(
|
explicit Convolution(
|
||||||
|
@ -3747,3 +3747,17 @@ TEST_CASE("test roll") {
|
|||||||
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test contiguous") {
|
||||||
|
auto x = array({1, 2, 3});
|
||||||
|
x = contiguous(broadcast_to(x, {2, 2, 3}));
|
||||||
|
eval(x);
|
||||||
|
CHECK(x.flags().row_contiguous);
|
||||||
|
CHECK_EQ(x.strides(), decltype(x.strides()){6, 3, 1});
|
||||||
|
|
||||||
|
x = array({1, 2, 1, 2}, {2, 2});
|
||||||
|
x = contiguous(transpose(x), true);
|
||||||
|
eval(x);
|
||||||
|
CHECK(x.flags().col_contiguous);
|
||||||
|
CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user