Compare commits

...

4 Commits

Author SHA1 Message Date
Awni Hannun
caecbe876a no copy batch rope (#2595) 2025-09-15 14:23:48 -07:00
Umberto Mignozzetti
8afb6d62f2 Fix typo in average_gradients function call (#2594) 2025-09-15 11:29:21 -07:00
Awni Hannun
6ccfa603cd fix metal scan (#2591) 2025-09-15 11:01:57 -07:00
Umberto Mignozzetti
36cad99a11 Refactor code examples to use 'gelu' (#2592)
Updated code examples to use 'gelu' directly instead of 'nn.gelu'.
2025-09-15 09:47:02 -07:00
6 changed files with 116 additions and 75 deletions

View File

@@ -130,8 +130,8 @@ Now make an array, and benchmark both functions:
.. code-block:: python
x = mx.random.uniform(shape=(32, 1000, 4096))
timeit(nn.gelu, x)
timeit(mx.compile(nn.gelu), x)
timeit(gelu, x)
timeit(mx.compile(gelu), x)
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

View File

@@ -184,7 +184,7 @@ almost identical to the example above:
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = mlx.nn.average_gradients(grads) # <---- This line was added
grads = mx.nn.average_gradients(grads) # <---- This line was added
optimizer.update(model, grads)
return loss

View File

@@ -3,7 +3,12 @@
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
constant bool forward [[function_constant(1)]];
constant bool traditional [[function_constant(2)]];
constant bool hs_transpose [[function_constant(3)]];
template <typename T>
void rope_single_impl(
const device T* in,
device T* out,
@@ -46,7 +51,7 @@ void rope_single_impl(
out[index_2] = static_cast<T>(rx2);
}
template <typename T, bool traditional, bool forward>
template <typename T>
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -58,11 +63,10 @@ template <typename T, bool traditional, bool forward>
uint2 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward>
template <typename T>
[[kernel]] void rope_single_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -74,11 +78,10 @@ template <typename T, bool traditional, bool forward>
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_single_impl<T, traditional, forward>(
in, out, offset, inv_freq, scale, stride, pos, grid);
rope_single_impl<T>(in, out, offset, inv_freq, scale, stride, pos, grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
template <typename T, typename IdxT, int N = 4>
void rope_impl(
const device T* in,
device T* out,
@@ -102,23 +105,29 @@ void rope_impl(
float theta = L * inv_freq;
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
IdxT in_index_1;
if (hs_transpose) {
IdxT batch_stride = grid.y * IdxT(strides[1]);
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + strides[2];
batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
mat_idx * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]);
}
IdxT in_index_2;
IdxT out_index_1 =
pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]);
IdxT out_index_2;
if (traditional) {
out_index_1 += 2 * pos.x * IdxT(out_strides[2]);
out_index_2 = out_index_1 + 1;
in_index_1 += 2 * pos.x * IdxT(strides[2]);
in_index_2 = in_index_1 + IdxT(strides[2]);
} else {
out_index_1 += pos.x * IdxT(out_strides[2]);
out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]);
in_index_1 += pos.x * IdxT(strides[2]);
in_index_2 = in_index_1 + grid.x * IdxT(strides[2]);
}
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
// Read and write the output
@@ -135,14 +144,14 @@ void rope_impl(
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
in_index_1 += IdxT(strides[0]);
in_index_2 += IdxT(strides[0]);
out_index_1 += IdxT(out_strides[0]);
out_index_2 += IdxT(out_strides[0]);
}
}
template <typename T, bool traditional, bool forward, int N = 4>
template <typename T, typename IdxT, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -157,7 +166,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
uint3 grid [[threads_per_grid]]) {
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
float inv_freq = metal::exp2(-d * base);
rope_impl<T, traditional, forward, N>(
rope_impl<T, IdxT, N>(
in,
out,
offset,
@@ -171,7 +180,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
grid);
}
template <typename T, bool traditional, bool forward, int N = 4>
template <typename T, typename IdxT, int N = 4>
[[kernel]] void rope_freqs(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
@@ -186,7 +195,7 @@ template <typename T, bool traditional, bool forward, int N = 4>
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
float inv_freq = 1.0 / (freqs[freq_stride * pos.x]);
rope_impl<T, traditional, forward, N>(
rope_impl<T, IdxT, N>(
in,
out,
offset,
@@ -201,27 +210,20 @@ template <typename T, bool traditional, bool forward, int N = 4>
}
// clang-format off
#define instantiate_rope_g(name, type, traditional, forward) \
instantiate_kernel("rope_" #name, rope, type, traditional, forward) \
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward)
#define instantiate_rope_g(name, type) \
instantiate_kernel("rope_" #name, rope, type, int32_t) \
instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \
instantiate_kernel("rope_large_" #name, rope, type, int64_t) \
instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t)
#define instantiate_rope_s(name, type, traditional, forward) \
instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward)
#define instantiate_rope_s(name, type) \
instantiate_kernel("rope_single_" #name, rope_single, type) \
instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type)
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
#define instantiate_rope(name, type) \
instantiate_rope_s(name, type) \
instantiate_rope_g(name, type)
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
instantiate_rope(float16, half, false, true)
instantiate_rope(bfloat16, bfloat16_t, false, true)
instantiate_rope(float32, float, false, true)
instantiate_rope(vjp_traditional_float16, half, true, false)
instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on
instantiate_rope(float16, half)
instantiate_rope(bfloat16, bfloat16_t)
instantiate_rope(float32, float) // clang-format on

View File

@@ -306,6 +306,7 @@ template <
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
@@ -440,6 +441,7 @@ template <
}
// Read in SM
threadgroup_barrier(mem_flags::mem_threadgroup);
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
read_into[i] = in[index_y * stride + i];

View File

@@ -29,6 +29,7 @@ void RoPE::eval_gpu(
int T = in.shape(-2);
int D = in.shape(-1);
size_t mat_size = T * D;
bool large = in.data_size() > INT32_MAX || in.size() > INT32_MAX;
int dispatch_ndim = ndim;
while (in.shape(-dispatch_ndim) == 1 && dispatch_ndim > 3) {
@@ -40,6 +41,8 @@ void RoPE::eval_gpu(
N *= in.shape(i);
}
bool head_seq_transpose = false;
if (dims_ < D) {
donated = true;
auto ctype =
@@ -64,6 +67,17 @@ void RoPE::eval_gpu(
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (
ndim == 4 &&
// batch dim is regularly strided
in.strides()[0] == T * N * D &&
// sequence and head dimensions are transposed
in.strides()[1] == D && in.strides()[2] == N * D) {
head_seq_transpose = true;
out.set_data(allocator::malloc(out.nbytes()));
strides[0] = in.strides()[1];
strides[1] = in.strides()[2];
strides[2] = in.strides()[3];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
@@ -77,15 +91,33 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single batch, single time step, and contiguous)
bool single = in.flags().row_contiguous && B == 1 && T == 1;
// Special case for inference (single time step, contiguous, one offset)
auto& offset = inputs[1];
bool single = in.flags().row_contiguous && T == 1 && offset.size() == 1;
bool with_freqs = inputs.size() == 3;
std::ostringstream kname;
kname << "rope_" << (single ? "single_" : "")
<< ((with_freqs) ? "freqs_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
std::string kname;
concatenate(
kname,
"rope_",
single ? "single_" : "",
(with_freqs) ? "freqs_" : "",
large ? "large_" : "",
type_to_name(in));
std::string hash_name;
concatenate(
hash_name,
kname,
"_",
forward_ ? "" : "vjp_",
traditional_ ? "traditional_" : "",
head_seq_transpose ? "transpose" : "");
metal::MTLFCList func_consts = {
{&forward_, MTL::DataType::DataTypeBool, 1},
{&traditional_, MTL::DataType::DataTypeBool, 2},
{&head_seq_transpose, MTL::DataType::DataTypeBool, 3}};
auto kernel = d.get_kernel(kname, hash_name, func_consts);
auto& compute_encoder = d.get_command_encoder(s.index);
float base = std::log2(base_);
@@ -93,7 +125,7 @@ void RoPE::eval_gpu(
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_input_array(inputs[1], 2);
compute_encoder.set_input_array(offset, 2);
compute_encoder.set_bytes(scale_, 3);
MTL::Size group_dims;
@@ -107,8 +139,8 @@ void RoPE::eval_gpu(
compute_encoder.set_bytes(strides, 3, 4);
compute_encoder.set_bytes(out_strides, 3, 5);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
if (offset.ndim() > 0) {
offset_stride = offset.strides()[0];
}
compute_encoder.set_bytes(offset_stride, 6);
compute_encoder.set_bytes(N, 7);

View File

@@ -36,14 +36,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
bool contiguous = in.strides()[axis_] == 1;
std::ostringstream kname;
kname << (contiguous ? "contig_" : "strided_");
kname << "scan_";
if (reverse_) {
kname << "reverse_";
}
kname << ((inclusive_) ? "inclusive_" : "exclusive_");
std::string reduce_type;
switch (reduce_type_) {
case Scan::Sum:
@@ -62,9 +54,22 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
reduce_type = "logaddexp";
break;
}
kname << reduce_type << "_" << type_to_name(in) << "_" << type_to_name(out);
auto kernel = get_scan_kernel(
d, kname.str(), reverse_, inclusive_, reduce_type, in, out);
std::string kname;
concatenate(
kname,
contiguous ? "contig_" : "strided_",
"scan_",
reverse_ ? "reverse_" : "",
(inclusive_) ? "inclusive_" : "exclusive_",
reduce_type,
"_",
type_to_name(in),
"_",
type_to_name(out));
auto kernel =
get_scan_kernel(d, kname, reverse_, inclusive_, reduce_type, in, out);
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);