More jitting (#1132)

* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
This commit is contained in:
Awni Hannun
2024-05-23 16:23:44 -07:00
committed by GitHub
parent 9401507336
commit 0189ab6ab6
41 changed files with 2377 additions and 1846 deletions

View File

@@ -1,15 +1,18 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
// Threshold on the softmax axis size used by Softmax::eval_gpu: when
// axis_size exceeds this value the "looped_" kernel variant is selected;
// at or below it the non-looped dispatch path is taken.
constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!issubdtype(out.dtype(), floating)) {
@@ -52,18 +55,17 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
const int simd_size = 32;
const int n_reads = SOFTMAX_N_READS;
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
std::string op_name = "softmax_";
if (axis_size > looped_limit) {
op_name += "looped_";
}
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
kernel_name += "softmax_";
if (in.dtype() != float32 && precise_) {
op_name += "precise_";
kernel_name += "precise_";
}
op_name += type_to_name(out);
kernel_name += type_to_name(out);
auto kernel = get_softmax_kernel(d, kernel_name, precise_, out);
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;