Compare commits

...

7 Commits

Author        SHA1        Message                                                                 Date
John Mai      7170e5f40b  Merge b3c1aaafd2 into cad5c0241c                                        2025-06-17 21:14:13 +02:00
Awni Hannun   cad5c0241c  [CUDA] synch properly waits for all tasks to finish and clear (#2303)  2025-06-17 12:03:25 -07:00
                          * cuda synch properly waits for all tasks to finish and clear
                          * fix copy
John Mai      b3c1aaafd2  update: format code                                                     2025-06-15 17:35:33 +08:00
John Mai      989e8bab66  feat: Add benchmarking for ReLUSquared activation function              2025-06-15 17:34:10 +08:00
John Mai      fe0672a9d2  docs: Update documentation to include ReLUSquared activation function   2025-06-15 17:33:58 +08:00
John Mai      cbd353bf73  test: Add unit test for ReLUSquared activation function                 2025-06-15 17:07:33 +08:00
John Mai      940f64fe6a  feat: Add ReLUSquared activation function                               2025-06-15 17:07:22 +08:00
15 changed files with 85 additions and 8 deletions

View File

@@ -224,6 +224,13 @@ def relu6(x):
     mx.eval(y)
 
 
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
+
+
 def softplus(x):
     y = x
     for i in range(100):

@@ -458,6 +465,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
     elif args.benchmark == "celu":
         print(bench(celu, x))

View File

@@ -157,6 +157,15 @@ def relu6(x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softplus(x):
     y = x

@@ -407,6 +416,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
     elif args.benchmark == "softplus":
         print(bench(softplus, x))

View File

@@ -207,6 +207,8 @@ if __name__ == "__main__":
     compare_filtered("elu --size 32x16x1024 --cpu")
     compare_filtered("relu6 --size 32x16x1024")
     compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("relu_squared --size 32x16x1024")
+    compare_filtered("relu_squared --size 32x16x1024 --cpu")
     compare_filtered("softplus --size 32x16x1024")
     compare_filtered("softplus --size 32x16x1024 --cpu")
     compare_filtered("celu --size 32x16x1024")

View File

@@ -28,6 +28,7 @@ simple functions.
    prelu
    relu
    relu6
+   relu_squared
    selu
    sigmoid
    silu

View File

@@ -51,6 +51,7 @@ Layers
    RMSNorm
    ReLU
    ReLU6
+   ReLUSquared
    RNN
    RoPE
    SELU

View File

@@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
   cudaFree(buf);
 }

View File

@@ -63,25 +63,30 @@ void copy_general(
   MLX_SWITCH_BOOL(large, LARGE, {
     using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
     int ndim = shape.size();
+    size_t data_size = 1;
+    for (auto& s : shape)
+      data_size *= s;
     if (ndim <= 3) {
       MLX_SWITCH_1_2_3(ndim, NDIM, {
         auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
-        auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        auto [num_blocks, block_dims] =
+            get_launch_args(kernel, data_size, shape, out.strides(), large);
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
-            out.size(),
+            data_size,
             const_param<NDIM>(shape),
             const_param<NDIM>(strides_in),
             const_param<NDIM>(strides_out));
       });
     } else { // ndim >= 4
       auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+      auto [num_blocks, block_dims] =
+          get_launch_args(kernel, data_size, shape, out.strides(), large);
       kernel<<<num_blocks, block_dims, 0, stream>>>(
           in_ptr,
           out_ptr,
-          out.size(),
+          data_size,
           const_param(shape),
           const_param(strides_in),
           const_param(strides_out),
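
For intuition, here is what these copy kernels compute, sketched on the host. This is an illustrative Python sketch, not MLX code: it shows why the launch is now sized by data_size, the product of the copy shape's dimensions, rather than by out.size().

    import math

    def copy_general_sketch(src, dst, shape, strides_in, strides_out):
        # The kernel visits one logical element per thread; the grid is
        # sized from the copy shape, not from the output buffer's size.
        data_size = math.prod(shape)
        for flat in range(data_size):
            rem, off_in, off_out = flat, 0, 0
            # Decompose the flat index into per-dimension coordinates
            # (row-major), then apply both stride vectors.
            for dim in reversed(range(len(shape))):
                coord = rem % shape[dim]
                rem //= shape[dim]
                off_in += coord * strides_in[dim]
                off_out += coord * strides_out[dim]
            dst[off_out] = src[off_in]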

View File

@@ -6,6 +6,7 @@
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
+#include <future>
 
 namespace mlx::core {

@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
   worker_.commit(stream_.last_cuda_stream());
 }
 
+void CommandEncoder::synchronize() {
+  stream().synchronize();
+  auto p = std::make_shared<std::promise<void>>();
+  std::future<void> f = p->get_future();
+  add_completed_handler([p = std::move(p)]() { p->set_value(); });
+  worker_.end_batch();
+  worker_.commit();
+  f.wait();
+}
+
 Device& device(mlx::core::Device device) {
   static std::unordered_map<int, Device> devices;
   auto it = devices.find(device.index);
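
The promise/future idiom above is what makes synchronize() wait for the host-side completion handlers, not just the CUDA stream: a sentinel handler fulfills the promise, so once the future resolves, every handler queued before it has already run. A minimal Python illustration of the same pattern, using concurrent.futures in place of MLX's worker (names are illustrative):

    from concurrent.futures import Future

    class Encoder:
        def __init__(self):
            self.handlers = []  # completion handlers drained by a worker

        def add_completed_handler(self, fn):
            self.handlers.append(fn)

        def commit(self):
            # Stand-in for the worker thread draining its queue once the
            # GPU work for the batch has finished.
            for fn in self.handlers:
                fn()
            self.handlers.clear()

        def synchronize(self):
            # Enqueue a sentinel that completes a future, then block on it:
            # when it resolves, all earlier handlers have run.
            f = Future()
            self.add_completed_handler(lambda: f.set_result(None))
            self.commit()
            f.result()

    enc = Encoder()
    enc.add_completed_handler(lambda: print("buffer released"))
    enc.synchronize()  # prints, then returns once all handlers have run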

View File

@@ -123,6 +123,9 @@ class CommandEncoder {
     return has_gpu_work_;
   }
 
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
  private:
   Device& device_;
   DeviceStream& stream_;

View File

@@ -62,7 +62,7 @@ void finalize(Stream s) {
 void synchronize(Stream s) {
   nvtx3::scoped_range r("gpu::synchronize");
-  cu::get_stream(s).synchronize();
+  cu::get_command_encoder(s).synchronize();
 }
 
 } // namespace mlx::core::gpu

View File

@@ -80,7 +80,9 @@ void Worker::thread_fn() {
       }
       worker_tasks_.erase(worker_tasks_.begin(), end);
     }
-    for (auto& task : tasks) {
+    // Make sure tasks are cleared before the next wait
+    for (int i = 0; i < tasks.size(); ++i) {
+      auto task = std::move(tasks[i]);
       task();
     }
     worker_event_.wait(batch + 1);
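
The rewritten loop moves each task out of the vector so it is destroyed as soon as it has run, releasing whatever resources its closure owns before the worker blocks on the next batch; with `for (auto& task : tasks)` the closures stayed alive until the whole vector left scope. A Python sketch of the same idea (illustrative only; CPython's refcounting plays the role of the C++ destructor):

    def drain(tasks):
        # Run and release each task before blocking again: clearing the
        # list slot and the local drops the closure (and whatever buffers
        # it keeps alive) right after the call, not at end of scope.
        for i in range(len(tasks)):
            task, tasks[i] = tasks[i], None
            task()
            del task

    drain([lambda: print("task 1"), lambda: print("task 2")])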

View File

@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,

@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,

View File

@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.

@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.

View File

@@ -6,7 +6,6 @@ cuda_skip = {
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestMemory.test_memory_info",
     "TestLayers.test_group_norm",
     "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",

View File

@@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_relu_squared(self):
+        x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
+        y = nn.relu_squared(x)
+        self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)