Compare commits

...

8 Commits

Author SHA1 Message Date
John Mai
9d86a4d5ba
Merge b3c1aaafd2 into 5adf185f86 2025-06-22 14:39:55 +08:00
Angelos Katharopoulos
5adf185f86
Fix update_modules() when providing a subset (#2308) 2025-06-20 17:19:46 -07:00
Awni Hannun
c9a9180584
Cuda perf tuning (#2307)
* perf tuning

* fix adding inputs arrays in matmul / srot

* format

* fix
2025-06-20 14:50:57 -07:00
John Mai
b3c1aaafd2 update: format code 2025-06-15 17:35:33 +08:00
John Mai
989e8bab66 feat: Add benchmarking for ReLUSquared activation function 2025-06-15 17:34:10 +08:00
John Mai
fe0672a9d2 docs: Update documentation to include ReLUSquared activation function 2025-06-15 17:33:58 +08:00
John Mai
cbd353bf73 test: Add unit test for ReLUSquared activation function 2025-06-15 17:07:33 +08:00
John Mai
940f64fe6a feat: Add ReLUSquared activation function 2025-06-15 17:07:22 +08:00
14 changed files with 127 additions and 25 deletions

View File

@ -224,6 +224,13 @@ def relu6(x):
mx.eval(y)
def relu_squared(x):
y = x
for i in range(100):
y = nn.relu_squared(y)
mx.eval(y)
def softplus(x):
y = x
for i in range(100):
@ -458,6 +465,9 @@ if __name__ == "__main__":
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "relu_squared":
print(bench(relu_squared, x))
elif args.benchmark == "celu":
print(bench(celu, x))

View File

@ -157,6 +157,15 @@ def relu6(x):
sync_if_needed(x)
@torch.no_grad()
def relu_squared(x):
y = x
for i in range(100):
y = torch.nn.functional.relu(y)
y = torch.square(y)
sync_if_needed(x)
@torch.no_grad()
def softplus(x):
y = x
@ -407,6 +416,9 @@ if __name__ == "__main__":
elif args.benchmark == "relu6":
print(bench(relu6, x))
elif args.benchmark == "relu_squared":
print(bench(relu_squared, x))
elif args.benchmark == "softplus":
print(bench(softplus, x))

View File

@ -207,6 +207,8 @@ if __name__ == "__main__":
compare_filtered("elu --size 32x16x1024 --cpu")
compare_filtered("relu6 --size 32x16x1024")
compare_filtered("relu6 --size 32x16x1024 --cpu")
compare_filtered("relu_squared --size 32x16x1024")
compare_filtered("relu_squared --size 32x16x1024 --cpu")
compare_filtered("softplus --size 32x16x1024")
compare_filtered("softplus --size 32x16x1024 --cpu")
compare_filtered("celu --size 32x16x1024")

View File

@ -28,6 +28,7 @@ simple functions.
prelu
relu
relu6
relu_squared
selu
sigmoid
silu

View File

@ -51,6 +51,7 @@ Layers
RMSNorm
ReLU
ReLU6
ReLUSquared
RNN
RoPE
SELU

View File

@ -3,6 +3,7 @@
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/utils.h"
#include "mlx/backend/cuda/worker.h"
#include "mlx/utils.h"
#include <cuda_runtime.h>
#include <fmt/format.h>
@ -14,9 +15,11 @@ namespace mlx::core {
namespace cu {
constexpr int page_size = 16384;
CudaAllocator::CudaAllocator()
: buffer_cache_(
getpagesize(),
page_size,
[](CudaBuffer* buf) { return buf->size; },
[this](CudaBuffer* buf) {
cuda_free(buf->data);
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
Buffer CudaAllocator::malloc(size_t size) {
// Find available buffer from cache.
auto orig_size = size;
std::unique_lock lock(mutex_);
if (size < page_size) {
size = next_power_of_2(size);
} else {
size = page_size * ((size + page_size - 1) / page_size);
}
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size,

View File

@ -24,7 +24,6 @@ void copy_gpu_inplace(
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
return;

View File

@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
#pragma unroll
for (int i = NDIM - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
IdxT b_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc);
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
IdxT c_loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
int dim_idx = elem % shape[i];
a_loc += dim_idx * a_strides[i];
b_loc += dim_idx * b_strides[i];
c_loc += dim_idx * c_strides[i];
a_loc += dim_idx * IdxT(a_strides[i]);
b_loc += dim_idx * IdxT(b_strides[i]);
c_loc += dim_idx * IdxT(c_strides[i]);
elem /= shape[i];
}
return cuda::std::make_tuple(a_loc, b_loc, c_loc);

View File

@ -162,11 +162,15 @@ class MatMul {
}
}
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
void* workspace_ptr = nullptr;
if (heuristic_.workspaceSize > 0) {
array workspace(
allocator::malloc(heuristic_.workspaceSize),
{static_cast<int>(heuristic_.workspaceSize)},
int8);
encoder.add_temporary(workspace);
workspace_ptr = workspace.data<void>();
}
encoder.launch_kernel([&](cudaStream_t stream) {
CHECK_CUBLAS_ERROR(cublasLtMatmul(
@ -183,8 +187,8 @@ class MatMul {
out,
out_desc_,
&heuristic_.algo,
workspace.data<void>(),
workspace.nbytes(),
workspace_ptr,
heuristic_.workspaceSize,
stream));
});
}
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(),
b_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
b_batch_strides.back(),
c_batch_strides.back());
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_input_array(c);
encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(
encoder,
out.data<int8_t>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
alpha_,
beta_);
return;
}
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
for (size_t i = 0; i < nbatch; ++i) {
matmul.run(
encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,

View File

@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);
if (axis < 0) {
axis += in.ndim();
}
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
in.flags());
}
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {

View File

@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
PReLU,
ReLU,
ReLU6,
ReLUSquared,
Sigmoid,
SiLU,
Softmax,
@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
prelu,
relu,
relu6,
relu_squared,
selu,
sigmoid,
silu,

View File

@ -71,6 +71,17 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def relu_squared(x):
r"""Applies the Rectified Linear Unit squared.
Applies :math:`\max(x, 0)^2` element wise.
Reference: https://arxiv.org/abs/2109.08668v2
"""
return relu(x).square()
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@ -420,6 +431,18 @@ class ReLU6(Module):
"""
@_make_activation_module(relu_squared)
class ReLUSquared(Module):
r"""Applies the Rectified Linear Unit squared.
Applies :math:`\max(x, 0)^2` element wise.
Reference: https://arxiv.org/abs/2109.08668v2
See :func:`relu_squared` for the functional equivalent.
"""
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.

View File

@ -413,7 +413,7 @@ class Module(dict):
f'Module does not have sub-module named "{k}".'
)
elif isinstance(modules, list):
for i in range(len(dst)):
for i in range(len(modules)):
current_value = dst[i]
new_value = modules[i]
if self.is_module(current_value) and self.is_module(new_value):

View File

@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
m = m.update_modules({"list": ["hi"]})
# Allow updating a strict subset
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
self.assertEqual(m.layers[1].weight.shape, (4, 3))
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):
@ -855,6 +860,13 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_relu_squared(self):
x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
y = nn.relu_squared(x)
self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32)
def test_leaky_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.leaky_relu(x)