mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape * more shape * fix * fix
This commit is contained in:
@@ -267,6 +267,11 @@ bool Arange::is_equivalent(const Primitive& other) const {
|
||||
step_ == a_other.step_);
|
||||
}
|
||||
|
||||
std::vector<Shape> Arange::output_shapes(const std::vector<array>&) {
|
||||
auto real_size = std::ceil((stop_ - start_) / step_);
|
||||
return {{std::max(static_cast<int>(real_size), 0)}};
|
||||
}
|
||||
|
||||
std::vector<array> ArcCos::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@@ -534,11 +539,10 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
||||
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> ArgReduce::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<Shape> ArgReduce::output_shapes(const std::vector<array>& inputs) {
|
||||
auto out_shape = inputs[0].shape();
|
||||
out_shape[axis_] = 1;
|
||||
return {out_shape};
|
||||
return {std::move(out_shape)};
|
||||
}
|
||||
|
||||
bool ArgSort::is_equivalent(const Primitive& other) const {
|
||||
@@ -787,6 +791,23 @@ std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
|
||||
return {outputs, std::vector<int>(outputs.size(), ax)};
|
||||
}
|
||||
|
||||
std::vector<Shape> Eigh::output_shapes(const std::vector<array>& inputs) {
|
||||
auto shape = inputs[0].shape();
|
||||
shape.pop_back(); // Remove last dimension for eigenvalues
|
||||
if (compute_eigenvectors_) {
|
||||
return {
|
||||
std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
|
||||
} else {
|
||||
return {std::move(shape)}; // Only eigenvalues
|
||||
}
|
||||
}
|
||||
|
||||
bool Eigh::is_equivalent(const Primitive& other) const {
|
||||
auto& e_other = static_cast<const Eigh&>(other);
|
||||
return uplo_ == e_other.uplo_ &&
|
||||
compute_eigenvectors_ == e_other.compute_eigenvectors_;
|
||||
}
|
||||
|
||||
std::vector<array> Concatenate::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@@ -881,6 +902,15 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
|
||||
return axis_ == c_other.axis_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Concatenate::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
auto shape = inputs[0].shape();
|
||||
for (int i = 1; i < inputs.size(); ++i) {
|
||||
shape[axis_] += inputs[i].shape(axis_);
|
||||
}
|
||||
return {std::move(shape)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -1811,6 +1841,15 @@ bool Gather::is_equivalent(const Primitive& other) const {
|
||||
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
||||
Shape out_shape;
|
||||
if (inputs.size() > 1) {
|
||||
out_shape = inputs[0].shape();
|
||||
}
|
||||
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
|
||||
return {std::move(out_shape)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -2184,6 +2223,12 @@ std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
|
||||
return {{matmul(a, b, stream())}, {0}};
|
||||
}
|
||||
|
||||
std::vector<Shape> Matmul::output_shapes(const std::vector<array>& inputs) {
|
||||
auto out_shape = inputs[0].shape();
|
||||
out_shape.back() = inputs[1].shape(-1);
|
||||
return {std::move(out_shape)};
|
||||
}
|
||||
|
||||
std::vector<array> Maximum::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
@@ -2608,6 +2653,15 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
|
||||
transpose_ == qm_other.transpose_;
|
||||
}
|
||||
|
||||
std::vector<Shape> QuantizedMatmul::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
auto& w = inputs[1];
|
||||
int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_;
|
||||
auto out_shape = inputs[0].shape();
|
||||
out_shape.back() = w_outer_dims;
|
||||
return {std::move(out_shape)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
@@ -2937,13 +2991,12 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
||||
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> Reduce::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<int> out_shape = inputs[0].shape();
|
||||
std::vector<Shape> Reduce::output_shapes(const std::vector<array>& inputs) {
|
||||
auto out_shape = inputs[0].shape();
|
||||
for (auto i : axes_) {
|
||||
out_shape[i] = 1;
|
||||
}
|
||||
return {out_shape};
|
||||
return {std::move(out_shape)};
|
||||
}
|
||||
|
||||
std::vector<array> Round::vjp(
|
||||
@@ -4209,6 +4262,15 @@ bool Transpose::is_equivalent(const Primitive& other) const {
|
||||
return axes_ == t_other.axes_;
|
||||
}
|
||||
|
||||
std::vector<Shape> Transpose::output_shapes(const std::vector<array>& inputs) {
|
||||
auto& in = inputs[0];
|
||||
Shape shape(in.ndim(), 0);
|
||||
for (int i = 0; i < axes_.size(); ++i) {
|
||||
shape[i] = in.shape()[axes_[i]];
|
||||
}
|
||||
return {std::move(shape)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
|
||||
Reference in New Issue
Block a user