mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-03 07:41:13 +08:00
shapeless slice update and broadcast when possible (#1727)
This commit is contained in:
parent
0308e9af71
commit
ebfe64b92d
@ -14,7 +14,7 @@ void slice_gpu(
|
||||
const Shape& start_indices,
|
||||
const Shape& strides,
|
||||
const Stream& s) {
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
// Calculate out strides and initial offset
|
||||
auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
|
||||
|
||||
size_t data_end = 1;
|
||||
|
@ -72,23 +72,6 @@ bool is_fusable(const Primitive& p) {
|
||||
is_noop(p);
|
||||
}
|
||||
|
||||
bool allows_shapeless(const Primitive& p) {
|
||||
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
|
||||
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
|
||||
typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) ||
|
||||
typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) ||
|
||||
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
|
||||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) ||
|
||||
typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) ||
|
||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||
typeid(p) == typeid(fast::LayerNorm) ||
|
||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||
typeid(p) == typeid(fast::ScaledDotProductAttention);
|
||||
}
|
||||
|
||||
Compiled::Compiled(
|
||||
Stream stream,
|
||||
std::vector<array> inputs,
|
||||
@ -800,24 +783,6 @@ std::vector<array> compile_replace(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void compile_validate_shapeless(const std::vector<array>& tape) {
|
||||
for (auto& t : tape) {
|
||||
if (!t.has_primitive()) {
|
||||
continue;
|
||||
}
|
||||
auto& p = t.primitive();
|
||||
if (allows_shapeless(p)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Cannot compile primitive ";
|
||||
p.print(msg);
|
||||
msg << " with shapeless enabled.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
bool skip_compile() {
|
||||
return compile_mode() == CompileMode::disabled ||
|
||||
!(compile_available_for_device(default_device()));
|
||||
@ -877,10 +842,6 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
if (compile_mode() != CompileMode::no_fuse) {
|
||||
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||
}
|
||||
|
||||
if (shapeless) {
|
||||
compile_validate_shapeless(entry.tape);
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
|
@ -740,6 +740,13 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
|
||||
return shape_ == b_other.shape_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
|
||||
if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
|
||||
throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape");
|
||||
}
|
||||
return {shape_};
|
||||
}
|
||||
|
||||
std::vector<array> Ceil::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@ -3585,63 +3592,9 @@ std::vector<array> Slice::vjp(
|
||||
const std::vector<array>&) {
|
||||
// Check inputs
|
||||
assert(primals.size() == 1);
|
||||
|
||||
std::vector<array> inds;
|
||||
std::vector<int> ind_axes;
|
||||
std::vector<array> single_inds;
|
||||
std::vector<int> single_ind_axes;
|
||||
for (int i = 0; i < start_indices_.size(); ++i) {
|
||||
auto start = start_indices_[i];
|
||||
auto end = end_indices_[i];
|
||||
auto stride = strides_[i];
|
||||
if (start == 0 && stride == 1) {
|
||||
continue;
|
||||
}
|
||||
if (stride == 1) {
|
||||
single_inds.push_back(array(start));
|
||||
single_ind_axes.push_back(i);
|
||||
} else {
|
||||
inds.push_back(arange(start, end, stride, stream()));
|
||||
ind_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose and reshape cotangents
|
||||
auto cotan = cotangents[0];
|
||||
if (!ind_axes.empty()) {
|
||||
Shape cotan_shape;
|
||||
for (auto ax : ind_axes) {
|
||||
cotan_shape.push_back(cotan.shape(ax));
|
||||
}
|
||||
std::vector<int> cotan_axes(ind_axes);
|
||||
for (int j = 0, i = 0; i < cotan.ndim(); ++i) {
|
||||
if (j < ind_axes.size() && ind_axes[j] == i) {
|
||||
cotan_shape.push_back(1);
|
||||
j++;
|
||||
} else {
|
||||
cotan_shape.push_back(cotan.shape(i));
|
||||
cotan_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
cotan =
|
||||
reshape(transpose(cotan, cotan_axes, stream()), cotan_shape, stream());
|
||||
}
|
||||
|
||||
// Make indices broadcastable
|
||||
Shape inds_shape(inds.size(), 1);
|
||||
for (int i = 0; i < inds.size(); ++i) {
|
||||
inds_shape[i] = inds[i].size();
|
||||
inds[i] = reshape(inds[i], inds_shape, stream());
|
||||
inds_shape[i] = 1;
|
||||
}
|
||||
|
||||
// Concatenate all the indices and axes
|
||||
inds.insert(inds.end(), single_inds.begin(), single_inds.end());
|
||||
ind_axes.insert(
|
||||
ind_axes.end(), single_ind_axes.begin(), single_ind_axes.end());
|
||||
|
||||
return {scatter_add(
|
||||
zeros_like(primals[0], stream()), inds, cotan, ind_axes, stream())};
|
||||
auto out = zeros_like(primals[0], stream());
|
||||
return {slice_update(
|
||||
out, cotangents[0], start_indices_, end_indices_, strides_, stream())};
|
||||
}
|
||||
|
||||
std::vector<array> Slice::jvp(
|
||||
|
@ -533,6 +533,8 @@ class Broadcast : public UnaryPrimitive {
|
||||
DEFINE_PRINT(Broadcast)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
|
||||
private:
|
||||
Shape shape_;
|
||||
|
||||
@ -1943,6 +1945,7 @@ class SliceUpdate : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_PRINT(SliceUpdate)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||
|
||||
private:
|
||||
Shape start_indices_;
|
||||
|
@ -766,7 +766,8 @@ auto mlx_slice_update(
|
||||
const ScalarOrArray& v) {
|
||||
// Can't route to slice update if not slice or tuple
|
||||
if (src.ndim() == 0 ||
|
||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj))) {
|
||||
(!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
|
||||
!nb::isinstance<nb::int_>(obj))) {
|
||||
return std::make_pair(false, src);
|
||||
}
|
||||
if (nb::isinstance<nb::tuple>(obj)) {
|
||||
@ -777,7 +778,6 @@ auto mlx_slice_update(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should be able to route to slice update
|
||||
|
||||
// Pre process tuple
|
||||
@ -797,6 +797,20 @@ auto mlx_slice_update(
|
||||
mx::Shape starts(src.ndim(), 0);
|
||||
mx::Shape stops = src.shape();
|
||||
mx::Shape strides(src.ndim(), 1);
|
||||
if (nb::isinstance<nb::int_>(obj)) {
|
||||
if (src.ndim() < 1) {
|
||||
std::ostringstream msg;
|
||||
msg << "Too many indices for array with " << src.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto idx = nb::cast<int>(obj);
|
||||
idx = idx < 0 ? idx + stops[0] : idx;
|
||||
starts[0] = idx;
|
||||
stops[0] = idx + 1;
|
||||
auto out = slice_update(
|
||||
src, up, std::move(starts), std::move(stops), std::move(strides));
|
||||
return std::make_pair(true, out);
|
||||
}
|
||||
|
||||
// If it's just a simple slice, just do a slice update and return
|
||||
if (nb::isinstance<nb::slice>(obj)) {
|
||||
|
@ -817,6 +817,19 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
fun = mx.compile(lambda a, b: a @ b, shapeless=True)
|
||||
self.assertTrue(mx.allclose(fun(a, b), a @ b))
|
||||
|
||||
def test_shapeless_compile_slice_update(self):
|
||||
def fun(x):
|
||||
x[2] = mx.array([3.0])
|
||||
return x
|
||||
|
||||
cfun = mx.compile(fun, shapeless=True)
|
||||
|
||||
a = mx.array([0.0, 1.0, 2.0, 3.0])
|
||||
self.assertTrue(mx.allclose(cfun(a), fun(a)))
|
||||
|
||||
a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0])
|
||||
self.assertTrue(mx.allclose(cfun(a), fun(a)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user