Compare commits

...

7 Commits

Author SHA1 Message Date
John Mai
6c901ccbc9
Merge b3c1aaafd2 into b3d7b85376 2025-06-18 22:56:20 +08:00
Angelos Katharopoulos
b3d7b85376
Make ptx cache settable by environment variable (#2304) 2025-06-17 23:55:56 -07:00
John Mai
b3c1aaafd2 update: format code 2025-06-15 17:35:33 +08:00
John Mai
989e8bab66 feat: Add benchmarking for ReLUSquared activation function 2025-06-15 17:34:10 +08:00
John Mai
fe0672a9d2 docs: Update documentation to include ReLUSquared activation function 2025-06-15 17:33:58 +08:00
John Mai
cbd353bf73 test: Add unit test for ReLUSquared activation function 2025-06-15 17:07:33 +08:00
John Mai
940f64fe6a feat: Add ReLUSquared activation function 2025-06-15 17:07:22 +08:00
9 changed files with 102 additions and 28 deletions

View File

@ -224,6 +224,13 @@ def relu6(x):
mx.eval(y) mx.eval(y)
def relu_squared(x):
y = x
for i in range(100):
y = nn.relu_squared(y)
mx.eval(y)
def softplus(x): def softplus(x):
y = x y = x
for i in range(100): for i in range(100):
@ -458,6 +465,9 @@ if __name__ == "__main__":
elif args.benchmark == "relu6": elif args.benchmark == "relu6":
print(bench(relu6, x)) print(bench(relu6, x))
elif args.benchmark == "relu_squared":
print(bench(relu_squared, x))
elif args.benchmark == "celu": elif args.benchmark == "celu":
print(bench(celu, x)) print(bench(celu, x))

View File

@ -157,6 +157,15 @@ def relu6(x):
sync_if_needed(x) sync_if_needed(x)
@torch.no_grad()
def relu_squared(x):
y = x
for i in range(100):
y = torch.nn.functional.relu(y)
y = torch.square(y)
sync_if_needed(x)
@torch.no_grad() @torch.no_grad()
def softplus(x): def softplus(x):
y = x y = x
@ -407,6 +416,9 @@ if __name__ == "__main__":
elif args.benchmark == "relu6": elif args.benchmark == "relu6":
print(bench(relu6, x)) print(bench(relu6, x))
elif args.benchmark == "relu_squared":
print(bench(relu_squared, x))
elif args.benchmark == "softplus": elif args.benchmark == "softplus":
print(bench(softplus, x)) print(bench(softplus, x))

View File

@ -207,6 +207,8 @@ if __name__ == "__main__":
compare_filtered("elu --size 32x16x1024 --cpu") compare_filtered("elu --size 32x16x1024 --cpu")
compare_filtered("relu6 --size 32x16x1024") compare_filtered("relu6 --size 32x16x1024")
compare_filtered("relu6 --size 32x16x1024 --cpu") compare_filtered("relu6 --size 32x16x1024 --cpu")
compare_filtered("relu_squared --size 32x16x1024")
compare_filtered("relu_squared --size 32x16x1024 --cpu")
compare_filtered("softplus --size 32x16x1024") compare_filtered("softplus --size 32x16x1024")
compare_filtered("softplus --size 32x16x1024 --cpu") compare_filtered("softplus --size 32x16x1024 --cpu")
compare_filtered("celu --size 32x16x1024") compare_filtered("celu --size 32x16x1024")

View File

@ -28,6 +28,7 @@ simple functions.
prelu prelu
relu relu
relu6 relu6
relu_squared
selu selu
sigmoid sigmoid
silu silu

View File

@ -51,6 +51,7 @@ Layers
RMSNorm RMSNorm
ReLU ReLU
ReLU6 ReLU6
ReLUSquared
RNN RNN
RoPE RoPE
SELU SELU

View File

@ -37,7 +37,8 @@ void check_cu_error(const char* name, CUresult err) {
} }
// Return the location of the CUDA toolkit. // Return the location of the CUDA toolkit.
const char* cuda_home() { const std::string& cuda_home() {
static std::string home = []() -> std::string {
const char* home = std::getenv("CUDA_HOME"); const char* home = std::getenv("CUDA_HOME");
if (home) { if (home) {
return home; return home;
@ -54,19 +55,28 @@ const char* cuda_home() {
#endif #endif
throw std::runtime_error( throw std::runtime_error(
"Environment variable CUDA_HOME or CUDA_PATH is not set."); "Environment variable CUDA_HOME or CUDA_PATH is not set.");
}();
return home;
} }
// Get the cache directory for storing compiled results. // Get the cache directory for storing compiled results.
bool get_ptx_cache_dir(std::filesystem::path* result) { const std::filesystem::path& ptx_cache_dir() {
auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx"; static std::filesystem::path cache = []() -> std::filesystem::path {
if (!std::filesystem::is_directory(path)) { std::filesystem::path cache;
if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
cache = c;
} else {
cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
}
if (!std::filesystem::exists(cache)) {
std::error_code error; std::error_code error;
if (!std::filesystem::create_directories(path, error)) { if (!std::filesystem::create_directories(cache, error)) {
return false; return std::filesystem::path();
} }
} }
*result = path; return cache;
return true; }();
return cache;
} }
// Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|. // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.
@ -75,6 +85,10 @@ bool read_cached_ptx(
const std::string& module_name, const std::string& module_name,
std::vector<char>* ptx, std::vector<char>* ptx,
std::vector<std::pair<std::string, std::string>>* ptx_kernels) { std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
if (cache_dir.empty()) {
return false;
}
auto ptx_path = cache_dir / (module_name + ".ptx"); auto ptx_path = cache_dir / (module_name + ".ptx");
std::error_code error; std::error_code error;
auto ptx_size = std::filesystem::file_size(ptx_path, error); auto ptx_size = std::filesystem::file_size(ptx_path, error);
@ -105,6 +119,10 @@ void write_cached_ptx(
const std::string& module_name, const std::string& module_name,
const std::vector<char>& ptx, const std::vector<char>& ptx,
const std::vector<std::pair<std::string, std::string>>& ptx_kernels) { const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
if (cache_dir.empty()) {
return;
}
std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary); std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
if (!ptx.empty()) { if (!ptx.empty()) {
ptx_file.write(&ptx.front(), ptx.size()); ptx_file.write(&ptx.front(), ptx.size());
@ -184,11 +202,9 @@ JitModule::JitModule(
const std::string& module_name, const std::string& module_name,
const KernelBuilder& builder) { const KernelBuilder& builder) {
// Check cache. // Check cache.
std::filesystem::path cache_dir;
std::vector<char> ptx; std::vector<char> ptx;
std::vector<std::pair<std::string, std::string>> ptx_kernels; std::vector<std::pair<std::string, std::string>> ptx_kernels;
if (!get_ptx_cache_dir(&cache_dir) || if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
!read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
// Create program. // Create program.
auto [source_code, kernel_names] = builder(); auto [source_code, kernel_names] = builder();
nvrtcProgram prog; nvrtcProgram prog;
@ -246,7 +262,7 @@ JitModule::JitModule(
} else { } else {
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data())); CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
} }
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels); write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
} }
// Load module. // Load module.

View File

@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
PReLU, PReLU,
ReLU, ReLU,
ReLU6, ReLU6,
ReLUSquared,
Sigmoid, Sigmoid,
SiLU, SiLU,
Softmax, Softmax,
@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
prelu, prelu,
relu, relu,
relu6, relu6,
relu_squared,
selu, selu,
sigmoid, sigmoid,
silu, silu,

View File

@ -71,6 +71,17 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0) return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def relu_squared(x):
r"""Applies the Rectified Linear Unit squared.
Applies :math:`\max(x, 0)^2` element wise.
Reference: https://arxiv.org/abs/2109.08668v2
"""
return relu(x).square()
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def softmax(x, axis=-1): def softmax(x, axis=-1):
r"""Applies the Softmax function. r"""Applies the Softmax function.
@ -420,6 +431,18 @@ class ReLU6(Module):
""" """
@_make_activation_module(relu_squared)
class ReLUSquared(Module):
r"""Applies the Rectified Linear Unit squared.
Applies :math:`\max(x, 0)^2` element wise.
Reference: https://arxiv.org/abs/2109.08668v2
See :func:`relu_squared` for the functional equivalent.
"""
@_make_activation_module(softmax) @_make_activation_module(softmax)
class Softmax(Module): class Softmax(Module):
r"""Applies the Softmax function. r"""Applies the Softmax function.

View File

@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, (3,)) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_relu_squared(self):
x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
y = nn.relu_squared(x)
self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32)
def test_leaky_relu(self): def test_leaky_relu(self):
x = mx.array([1.0, -1.0, 0.0]) x = mx.array([1.0, -1.0, 0.0])
y = nn.leaky_relu(x) y = nn.leaky_relu(x)