mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 18:11:15 +08:00
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape * more shape * fix * fix
This commit is contained in:
parent
95c4a2e3af
commit
d0b6cb0425
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@ -8,6 +7,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/compile.h"
|
#include "mlx/compile.h"
|
||||||
#include "mlx/compile_impl.h"
|
#include "mlx/compile_impl.h"
|
||||||
|
#include "mlx/fast_primitives.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/transforms.h"
|
#include "mlx/transforms.h"
|
||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
@ -73,11 +73,18 @@ bool is_fusable(const Primitive& p) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool allows_shapeless(const Primitive& p) {
|
bool allows_shapeless(const Primitive& p) {
|
||||||
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
|
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
|
||||||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
|
||||||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) ||
|
||||||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
|
typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) ||
|
||||||
typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
|
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
|
||||||
|
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||||
|
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||||
|
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||||
|
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||||
|
typeid(p) == typeid(fast::LayerNorm) ||
|
||||||
|
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||||
|
typeid(p) == typeid(fast::ScaledDotProductAttention);
|
||||||
}
|
}
|
||||||
|
|
||||||
Compiled::Compiled(
|
Compiled::Compiled(
|
||||||
@ -93,23 +100,23 @@ Compiled::Compiled(
|
|||||||
constant_ids_(std::move(constant_ids)) {}
|
constant_ids_(std::move(constant_ids)) {}
|
||||||
|
|
||||||
// Reverse-mode differentiation entry point for Compiled. It always throws
// std::runtime_error; none of the arguments are inspected.
std::vector<array> Compiled::vjp(
    const std::vector<array>& /* primals */,
    const std::vector<array>& /* cotangents */,
    const std::vector<int>& /* argnums */,
    const std::vector<array>& /* outputs */) {
  throw std::runtime_error("[Compiled] Cannot vjp primitive.");
}
|
||||||
|
|
||||||
// Forward-mode differentiation entry point for Compiled. It always throws
// std::runtime_error; none of the arguments are inspected.
std::vector<array> Compiled::jvp(
    const std::vector<array>& /* primals */,
    const std::vector<array>& /* tangents */,
    const std::vector<int>& /* argnums */) {
  throw std::runtime_error("[Compiled] Cannot jvp primitive.");
}
|
||||||
|
|
||||||
// Vectorization entry point for Compiled. It always throws
// std::runtime_error; none of the arguments are inspected.
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
    const std::vector<array>& /* inputs */,
    const std::vector<int>& /* axes */) {
  throw std::runtime_error("[Compiled] Cannot vmap primitive.");
}
|
||||||
|
|
||||||
@ -134,13 +141,12 @@ void Compiled::print(std::ostream& os) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int>> Compiled::output_shapes(
|
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
|
||||||
const std::vector<array>& inputs) {
|
|
||||||
size_t nd = 0;
|
size_t nd = 0;
|
||||||
for (auto& in : inputs) {
|
for (auto& in : inputs) {
|
||||||
nd = std::max(nd, in.ndim());
|
nd = std::max(nd, in.ndim());
|
||||||
}
|
}
|
||||||
std::vector<int> out_shape(nd, 0);
|
Shape out_shape(nd, 0);
|
||||||
for (auto& in : inputs) {
|
for (auto& in : inputs) {
|
||||||
auto dd = nd - in.ndim();
|
auto dd = nd - in.ndim();
|
||||||
for (auto i = dd; i < nd; ++i) {
|
for (auto i = dd; i < nd; ++i) {
|
||||||
@ -148,7 +154,7 @@ std::vector<std::vector<int>> Compiled::output_shapes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// All outputs have the same shape
|
// All outputs have the same shape
|
||||||
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
|
return std::vector<Shape>(outputs_.size(), out_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
@ -553,14 +559,12 @@ void compile_fuse(
|
|||||||
// - Collect inputs to the new compiled primitive
|
// - Collect inputs to the new compiled primitive
|
||||||
// - Add fusable primitives to a tape in the correct order
|
// - Add fusable primitives to a tape in the correct order
|
||||||
|
|
||||||
std::function<void(
|
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
|
||||||
const array&, int, const Stream&, const std::vector<int>&)>
|
|
||||||
recurse;
|
|
||||||
std::unordered_set<uintptr_t> cache;
|
std::unordered_set<uintptr_t> cache;
|
||||||
recurse = [&](const array& a,
|
recurse = [&](const array& a,
|
||||||
int depth,
|
int depth,
|
||||||
const Stream& s,
|
const Stream& s,
|
||||||
const std::vector<int>& shape) {
|
const Shape& shape) {
|
||||||
if (cache.find(a.id()) != cache.end()) {
|
if (cache.find(a.id()) != cache.end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -667,7 +671,7 @@ void compile_fuse(
|
|||||||
}
|
}
|
||||||
old_outputs.push_back(arr);
|
old_outputs.push_back(arr);
|
||||||
|
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<Shape> shapes;
|
||||||
std::vector<Dtype> types;
|
std::vector<Dtype> types;
|
||||||
for (auto& o : old_outputs) {
|
for (auto& o : old_outputs) {
|
||||||
if (o.shape() != old_outputs.back().shape()) {
|
if (o.shape() != old_outputs.back().shape()) {
|
||||||
@ -771,7 +775,7 @@ std::vector<array> compile_replace(
|
|||||||
for (auto& o : trace_out) {
|
for (auto& o : trace_out) {
|
||||||
types.push_back(o.dtype());
|
types.push_back(o.dtype());
|
||||||
}
|
}
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<Shape> shapes;
|
||||||
if (shapeless) {
|
if (shapeless) {
|
||||||
shapes = a.primitive().output_shapes(real_inputs);
|
shapes = a.primitive().output_shapes(real_inputs);
|
||||||
} else {
|
} else {
|
||||||
|
25
mlx/fast.cpp
25
mlx/fast.cpp
@ -915,6 +915,31 @@ array affine_dequantize(
|
|||||||
return fallback({w, scales, biases})[0];
|
return fallback({w, scales, biases})[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool AffineQuantize::is_equivalent(const Primitive& other) const {
|
||||||
|
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
|
||||||
|
return (
|
||||||
|
p_other.group_size_ == group_size_ && p_other.bits_ == bits_ &&
|
||||||
|
p_other.dequantize_ == dequantize_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> AffineQuantize::output_shapes(
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
auto& w = inputs[0];
|
||||||
|
if (dequantize_) {
|
||||||
|
auto out_size = w.shape(-1) * 32 / bits_;
|
||||||
|
auto out_shape = w.shape();
|
||||||
|
out_shape.back() = out_size;
|
||||||
|
return {std::move(out_shape)};
|
||||||
|
} else {
|
||||||
|
auto wq_shape = w.shape();
|
||||||
|
wq_shape.back() = w.shape(-1) * bits_ / 32;
|
||||||
|
auto sshape = w.shape();
|
||||||
|
sshape.back() = w.shape(-1) / group_size_;
|
||||||
|
auto bshape = sshape;
|
||||||
|
return {std::move(wq_shape), std::move(sshape), std::move(bshape)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string write_signature(
|
std::string write_signature(
|
||||||
std::string func_name,
|
std::string func_name,
|
||||||
const std::string& header,
|
const std::string& header,
|
||||||
|
@ -58,6 +58,7 @@ class RMSNorm : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(RMSNorm)
|
DEFINE_PRINT(RMSNorm)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -110,6 +111,7 @@ class LayerNorm : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(LayerNorm)
|
DEFINE_PRINT(LayerNorm)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -173,6 +175,7 @@ class RoPE : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(RoPE)
|
DEFINE_PRINT(RoPE)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -207,6 +210,7 @@ class ScaledDotProductAttention : public Custom {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
DEFINE_PRINT(ScaledDotProductAttention);
|
DEFINE_PRINT(ScaledDotProductAttention);
|
||||||
|
DEFINE_INPUT_OUTPUT_SHAPE()
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
@ -235,6 +239,9 @@ class AffineQuantize : public Custom {
|
|||||||
|
|
||||||
DEFINE_PRINT(AffineQuantize);
|
DEFINE_PRINT(AffineQuantize);
|
||||||
|
|
||||||
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
std::function<std::vector<array>(std::vector<array>)> fallback_;
|
||||||
int group_size_;
|
int group_size_;
|
||||||
|
@ -267,6 +267,11 @@ bool Arange::is_equivalent(const Primitive& other) const {
|
|||||||
step_ == a_other.step_);
|
step_ == a_other.step_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Arange::output_shapes(const std::vector<array>&) {
|
||||||
|
auto real_size = std::ceil((stop_ - start_) / step_);
|
||||||
|
return {{std::max(static_cast<int>(real_size), 0)}};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> ArcCos::vjp(
|
std::vector<array> ArcCos::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
@ -534,11 +539,10 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
|
|||||||
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
|
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int>> ArgReduce::output_shapes(
|
std::vector<Shape> ArgReduce::output_shapes(const std::vector<array>& inputs) {
|
||||||
const std::vector<array>& inputs) {
|
|
||||||
auto out_shape = inputs[0].shape();
|
auto out_shape = inputs[0].shape();
|
||||||
out_shape[axis_] = 1;
|
out_shape[axis_] = 1;
|
||||||
return {out_shape};
|
return {std::move(out_shape)};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ArgSort::is_equivalent(const Primitive& other) const {
|
bool ArgSort::is_equivalent(const Primitive& other) const {
|
||||||
@ -787,6 +791,23 @@ std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
|
|||||||
return {outputs, std::vector<int>(outputs.size(), ax)};
|
return {outputs, std::vector<int>(outputs.size(), ax)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Eigh::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
auto shape = inputs[0].shape();
|
||||||
|
shape.pop_back(); // Remove last dimension for eigenvalues
|
||||||
|
if (compute_eigenvectors_) {
|
||||||
|
return {
|
||||||
|
std::move(shape), inputs[0].shape()}; // Eigenvalues and eigenvectors
|
||||||
|
} else {
|
||||||
|
return {std::move(shape)}; // Only eigenvalues
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Eigh::is_equivalent(const Primitive& other) const {
|
||||||
|
auto& e_other = static_cast<const Eigh&>(other);
|
||||||
|
return uplo_ == e_other.uplo_ &&
|
||||||
|
compute_eigenvectors_ == e_other.compute_eigenvectors_;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Concatenate::vjp(
|
std::vector<array> Concatenate::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
@ -881,6 +902,15 @@ bool Concatenate::is_equivalent(const Primitive& other) const {
|
|||||||
return axis_ == c_other.axis_;
|
return axis_ == c_other.axis_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Concatenate::output_shapes(
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
auto shape = inputs[0].shape();
|
||||||
|
for (int i = 1; i < inputs.size(); ++i) {
|
||||||
|
shape[axis_] += inputs[i].shape(axis_);
|
||||||
|
}
|
||||||
|
return {std::move(shape)};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Conjugate::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
@ -1811,6 +1841,15 @@ bool Gather::is_equivalent(const Primitive& other) const {
|
|||||||
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
|
return axes_ == g_other.axes_ && slice_sizes_ == g_other.slice_sizes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Gather::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
Shape out_shape;
|
||||||
|
if (inputs.size() > 1) {
|
||||||
|
out_shape = inputs[0].shape();
|
||||||
|
}
|
||||||
|
out_shape.insert(out_shape.end(), slice_sizes_.begin(), slice_sizes_.end());
|
||||||
|
return {std::move(out_shape)};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
|
std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
@ -2184,6 +2223,12 @@ std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
|
|||||||
return {{matmul(a, b, stream())}, {0}};
|
return {{matmul(a, b, stream())}, {0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Matmul::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
auto out_shape = inputs[0].shape();
|
||||||
|
out_shape.back() = inputs[1].shape(-1);
|
||||||
|
return {std::move(out_shape)};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<array> Maximum::vjp(
|
std::vector<array> Maximum::vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const std::vector<array>& cotangents,
|
const std::vector<array>& cotangents,
|
||||||
@ -2608,6 +2653,15 @@ bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
|
|||||||
transpose_ == qm_other.transpose_;
|
transpose_ == qm_other.transpose_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> QuantizedMatmul::output_shapes(
|
||||||
|
const std::vector<array>& inputs) {
|
||||||
|
auto& w = inputs[1];
|
||||||
|
int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_;
|
||||||
|
auto out_shape = inputs[0].shape();
|
||||||
|
out_shape.back() = w_outer_dims;
|
||||||
|
return {std::move(out_shape)};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
|
std::pair<std::vector<array>, std::vector<int>> GatherQMM::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
@ -2937,13 +2991,12 @@ bool Reduce::is_equivalent(const Primitive& other) const {
|
|||||||
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
return reduce_type_ == r_other.reduce_type_ && axes_ == r_other.axes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int>> Reduce::output_shapes(
|
std::vector<Shape> Reduce::output_shapes(const std::vector<array>& inputs) {
|
||||||
const std::vector<array>& inputs) {
|
auto out_shape = inputs[0].shape();
|
||||||
std::vector<int> out_shape = inputs[0].shape();
|
|
||||||
for (auto i : axes_) {
|
for (auto i : axes_) {
|
||||||
out_shape[i] = 1;
|
out_shape[i] = 1;
|
||||||
}
|
}
|
||||||
return {out_shape};
|
return {std::move(out_shape)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> Round::vjp(
|
std::vector<array> Round::vjp(
|
||||||
@ -4209,6 +4262,15 @@ bool Transpose::is_equivalent(const Primitive& other) const {
|
|||||||
return axes_ == t_other.axes_;
|
return axes_ == t_other.axes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Shape> Transpose::output_shapes(const std::vector<array>& inputs) {
|
||||||
|
auto& in = inputs[0];
|
||||||
|
Shape shape(in.ndim(), 0);
|
||||||
|
for (int i = 0; i < axes_.size(); ++i) {
|
||||||
|
shape[i] = in.shape()[axes_[i]];
|
||||||
|
}
|
||||||
|
return {std::move(shape)};
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
|
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
|
@ -36,10 +36,10 @@
|
|||||||
return true; \
|
return true; \
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expands, inside a primitive class body, to an `output_shapes` override
// that reports a single output with the same shape as inputs[0].
#define DEFINE_INPUT_OUTPUT_SHAPE()                      \
  std::vector<Shape> output_shapes(                      \
      const std::vector<array>& inputs) override {       \
    return {inputs[0].shape()};                          \
  }
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -110,8 +110,7 @@ class Primitive {
|
|||||||
|
|
||||||
/** Get the output shapes of the primitive. This is not required to be
|
/** Get the output shapes of the primitive. This is not required to be
|
||||||
* implemented by derived classes, in which case it will throw. */
|
* implemented by derived classes, in which case it will throw. */
|
||||||
virtual std::vector<std::vector<int>> output_shapes(
|
virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
|
||||||
const std::vector<array>& inputs);
|
|
||||||
|
|
||||||
virtual ~Primitive() = default;
|
virtual ~Primitive() = default;
|
||||||
Primitive(const Primitive& other) = delete;
|
Primitive(const Primitive& other) = delete;
|
||||||
@ -220,6 +219,7 @@ class Arange : public UnaryPrimitive {
|
|||||||
|
|
||||||
DEFINE_PRINT(Arange)
|
DEFINE_PRINT(Arange)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
double start_;
|
double start_;
|
||||||
@ -386,8 +386,7 @@ class ArgReduce : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(ArgReduce)
|
DEFINE_PRINT(ArgReduce)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<std::vector<int>> output_shapes(
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
const std::vector<array>& inputs) override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ReduceType reduce_type_;
|
ReduceType reduce_type_;
|
||||||
@ -437,11 +436,7 @@ class AsType : public UnaryPrimitive {
|
|||||||
|
|
||||||
class AsStrided : public UnaryPrimitive {
|
class AsStrided : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit AsStrided(
|
explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
|
||||||
Stream stream,
|
|
||||||
std::vector<int> shape,
|
|
||||||
std::vector<size_t> strides,
|
|
||||||
size_t offset)
|
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
shape_(std::move(shape)),
|
shape_(std::move(shape)),
|
||||||
strides_(std::move(strides)),
|
strides_(std::move(strides)),
|
||||||
@ -455,8 +450,8 @@ class AsStrided : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> shape_;
|
Shape shape_;
|
||||||
std::vector<size_t> strides_;
|
Strides strides_;
|
||||||
size_t offset_;
|
size_t offset_;
|
||||||
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -527,7 +522,7 @@ class GatherMM : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Broadcast : public UnaryPrimitive {
|
class Broadcast : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Broadcast(Stream stream, const std::vector<int>& shape)
|
explicit Broadcast(Stream stream, const Shape& shape)
|
||||||
: UnaryPrimitive(stream), shape_(shape) {}
|
: UnaryPrimitive(stream), shape_(shape) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -539,7 +534,7 @@ class Broadcast : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> shape_;
|
Shape shape_;
|
||||||
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
};
|
};
|
||||||
@ -586,8 +581,7 @@ class Compiled : public Primitive {
|
|||||||
|
|
||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
std::vector<std::vector<int>> output_shapes(
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
const std::vector<array>& inputs) override;
|
|
||||||
void print(std::ostream& os) override;
|
void print(std::ostream& os) override;
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
@ -616,6 +610,7 @@ class Concatenate : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Concatenate)
|
DEFINE_PRINT(Concatenate)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int axis_;
|
int axis_;
|
||||||
@ -853,8 +848,7 @@ class DivMod : public Primitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(DivMod)
|
DEFINE_PRINT(DivMod)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
std::vector<std::vector<int>> output_shapes(
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||||
const std::vector<array>& inputs) override {
|
|
||||||
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
return std::vector{inputs[0].shape(), inputs[0].shape()};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1063,6 +1057,7 @@ class Gather : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Gather)
|
DEFINE_PRINT(Gather)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1339,6 +1334,7 @@ class Matmul : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Matmul)
|
DEFINE_PRINT(Matmul)
|
||||||
DEFINE_DEFAULT_IS_EQUIVALENT()
|
DEFINE_DEFAULT_IS_EQUIVALENT()
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Maximum : public UnaryPrimitive {
|
class Maximum : public UnaryPrimitive {
|
||||||
@ -1444,8 +1440,7 @@ class NumberOfElements : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(NumberOfElements)
|
DEFINE_PRINT(NumberOfElements)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<std::vector<int>> output_shapes(
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
|
||||||
const std::vector<array>& inputs) override {
|
|
||||||
return {{}};
|
return {{}};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1542,6 +1537,7 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(QuantizedMatmul)
|
DEFINE_PRINT(QuantizedMatmul)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int group_size_;
|
int group_size_;
|
||||||
@ -1577,7 +1573,7 @@ class GatherQMM : public UnaryPrimitive {
|
|||||||
|
|
||||||
class RandomBits : public UnaryPrimitive {
|
class RandomBits : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
|
explicit RandomBits(Stream stream, const Shape& shape, int width)
|
||||||
: UnaryPrimitive(stream), shape_(shape), width_(width) {}
|
: UnaryPrimitive(stream), shape_(shape), width_(width) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1588,7 +1584,7 @@ class RandomBits : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> shape_;
|
Shape shape_;
|
||||||
int width_;
|
int width_;
|
||||||
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
@ -1610,7 +1606,7 @@ class Real : public UnaryPrimitive {
|
|||||||
|
|
||||||
class Reshape : public UnaryPrimitive {
|
class Reshape : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit Reshape(Stream stream, const std::vector<int>& shape)
|
explicit Reshape(Stream stream, const Shape& shape)
|
||||||
: UnaryPrimitive(stream), shape_(shape) {}
|
: UnaryPrimitive(stream), shape_(shape) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1622,16 +1618,16 @@ class Reshape : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> shape_;
|
Shape shape_;
|
||||||
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
void eval(const std::vector<array>& inputs, array& out);
|
||||||
|
|
||||||
std::pair<bool, std::vector<size_t>> prepare_reshape(
|
static std::pair<bool, Strides> prepare_reshape(
|
||||||
const array& in,
|
const array& in,
|
||||||
const array& out);
|
const array& out);
|
||||||
void shared_buffer_reshape(
|
static void shared_buffer_reshape(
|
||||||
const array& in,
|
const array& in,
|
||||||
const std::vector<size_t>& out_strides,
|
const Strides& out_strides,
|
||||||
array& out);
|
array& out);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1656,8 +1652,7 @@ class Reduce : public UnaryPrimitive {
|
|||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
std::vector<std::vector<int>> output_shapes(
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
const std::vector<array>& inputs) override;
|
|
||||||
|
|
||||||
void print(std::ostream& os) override {
|
void print(std::ostream& os) override {
|
||||||
switch (reduce_type_) {
|
switch (reduce_type_) {
|
||||||
@ -2141,6 +2136,7 @@ class Transpose : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_PRINT(Transpose)
|
DEFINE_PRINT(Transpose)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<int> axes_;
|
std::vector<int> axes_;
|
||||||
@ -2230,24 +2226,9 @@ class Eigh : public Primitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_PRINT(Eigh)
|
DEFINE_PRINT(Eigh)
|
||||||
|
|
||||||
std::vector<std::vector<int>> output_shapes(
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
const std::vector<array>& inputs) override {
|
|
||||||
auto shape = inputs[0].shape();
|
|
||||||
shape.pop_back(); // Remove last dimension for eigenvalues
|
|
||||||
if (compute_eigenvectors_) {
|
|
||||||
return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
|
|
||||||
} else {
|
|
||||||
return {shape}; // Only eigenvalues
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_equivalent(const Primitive& other) const override {
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
if (auto* p = dynamic_cast<const Eigh*>(&other)) {
|
|
||||||
return uplo_ == p->uplo_ &&
|
|
||||||
compute_eigenvectors_ == p->compute_eigenvectors_;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
|
||||||
|
Loading…
Reference in New Issue
Block a user