mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape * more shape * fix * fix
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
@@ -8,6 +7,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/compile_impl.h"
|
||||
#include "mlx/fast_primitives.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
@@ -73,11 +73,18 @@ bool is_fusable(const Primitive& p) {
|
||||
}
|
||||
|
||||
bool allows_shapeless(const Primitive& p) {
|
||||
return typeid(p) == typeid(Compiled) || is_unary(p) || is_binary(p) ||
|
||||
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
|
||||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
|
||||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
|
||||
typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
|
||||
return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) ||
|
||||
is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) ||
|
||||
typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) ||
|
||||
typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) ||
|
||||
typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) ||
|
||||
typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) ||
|
||||
typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) ||
|
||||
typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) ||
|
||||
typeid(p) == typeid(fast::AffineQuantize) ||
|
||||
typeid(p) == typeid(fast::LayerNorm) ||
|
||||
typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) ||
|
||||
typeid(p) == typeid(fast::ScaledDotProductAttention);
|
||||
}
|
||||
|
||||
Compiled::Compiled(
|
||||
@@ -93,23 +100,23 @@ Compiled::Compiled(
|
||||
constant_ids_(std::move(constant_ids)) {}
|
||||
|
||||
std::vector<array> Compiled::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&,
|
||||
const std::vector<array>&) {
|
||||
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
|
||||
}
|
||||
|
||||
std::vector<array> Compiled::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
const std::vector<array>&,
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&) {
|
||||
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
const std::vector<array>&,
|
||||
const std::vector<int>&) {
|
||||
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
|
||||
}
|
||||
|
||||
@@ -134,13 +141,12 @@ void Compiled::print(std::ostream& os) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int>> Compiled::output_shapes(
|
||||
const std::vector<array>& inputs) {
|
||||
std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
|
||||
size_t nd = 0;
|
||||
for (auto& in : inputs) {
|
||||
nd = std::max(nd, in.ndim());
|
||||
}
|
||||
std::vector<int> out_shape(nd, 0);
|
||||
Shape out_shape(nd, 0);
|
||||
for (auto& in : inputs) {
|
||||
auto dd = nd - in.ndim();
|
||||
for (auto i = dd; i < nd; ++i) {
|
||||
@@ -148,7 +154,7 @@ std::vector<std::vector<int>> Compiled::output_shapes(
|
||||
}
|
||||
}
|
||||
// All outputs have the same shape
|
||||
return std::vector<std::vector<int>>(outputs_.size(), out_shape);
|
||||
return std::vector<Shape>(outputs_.size(), out_shape);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
@@ -553,14 +559,12 @@ void compile_fuse(
|
||||
// - Collect inputs to the new compiled primitive
|
||||
// - Add fusable primitives to a tape in the correct order
|
||||
|
||||
std::function<void(
|
||||
const array&, int, const Stream&, const std::vector<int>&)>
|
||||
recurse;
|
||||
std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
|
||||
std::unordered_set<uintptr_t> cache;
|
||||
recurse = [&](const array& a,
|
||||
int depth,
|
||||
const Stream& s,
|
||||
const std::vector<int>& shape) {
|
||||
const Shape& shape) {
|
||||
if (cache.find(a.id()) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
@@ -667,7 +671,7 @@ void compile_fuse(
|
||||
}
|
||||
old_outputs.push_back(arr);
|
||||
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<Dtype> types;
|
||||
for (auto& o : old_outputs) {
|
||||
if (o.shape() != old_outputs.back().shape()) {
|
||||
@@ -771,7 +775,7 @@ std::vector<array> compile_replace(
|
||||
for (auto& o : trace_out) {
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Shape> shapes;
|
||||
if (shapeless) {
|
||||
shapes = a.primitive().output_shapes(real_inputs);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user