mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
More jitting (#1132)
* docs + circle min size build * jit scan, arange, softmax * add sort * jit reductions * remove print * fix deps * clean includes / nits
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Largest softmax axis size handled without the "looped" kernel variant;
// axis sizes strictly greater than this value select the looped kernel.
constexpr int SOFTMAX_LOOPED_LIMIT{4096};
|
||||
|
||||
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
if (!issubdtype(out.dtype(), floating)) {
|
||||
@@ -52,18 +55,17 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const int simd_size = 32;
|
||||
const int n_reads = SOFTMAX_N_READS;
|
||||
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
|
||||
std::string op_name = "softmax_";
|
||||
if (axis_size > looped_limit) {
|
||||
op_name += "looped_";
|
||||
}
|
||||
|
||||
std::string kernel_name = (axis_size > looped_limit) ? "looped_" : "block_";
|
||||
kernel_name += "softmax_";
|
||||
if (in.dtype() != float32 && precise_) {
|
||||
op_name += "precise_";
|
||||
kernel_name += "precise_";
|
||||
}
|
||||
op_name += type_to_name(out);
|
||||
kernel_name += type_to_name(out);
|
||||
|
||||
auto kernel = get_softmax_kernel(d, kernel_name, precise_, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
{
|
||||
auto kernel = d.get_kernel(op_name);
|
||||
|
||||
MTL::Size grid_dims, group_dims;
|
||||
if (axis_size <= looped_limit) {
|
||||
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
|
||||
|
||||
Reference in New Issue
Block a user