Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-07-25 03:31:17 +08:00

Compare commits: 7170e5f40b ... 6c901ccbc9 (7 commits)

Commits:
  6c901ccbc9
  b3d7b85376
  b3c1aaafd2
  989e8bab66
  fe0672a9d2
  cbd353bf73
  940f64fe6a

@@ -224,6 +224,13 @@ def relu6(x):
     mx.eval(y)
 
 
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
+
+
 def softplus(x):
     y = x
     for i in range(100):

@@ -458,6 +465,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "celu":
         print(bench(celu, x))
 

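Both hunks above extend the MLX side of the benchmark suite: the op is applied 100 times in a chain and only evaluated at the end, because MLX builds a lazy graph and nothing runs until `mx.eval`. A minimal standalone sketch of the same measurement pattern (the `bench` helper and the argument parsing belong to the benchmark harness and are not shown here; the timing code below is an illustrative assumption):

import time

import mlx.core as mx
import mlx.nn as nn


def time_relu_squared(size=(32, 16, 1024), repeats=100):
    x = mx.random.uniform(shape=size)
    mx.eval(x)  # materialize the input before timing
    tic = time.perf_counter()
    y = x
    for _ in range(repeats):
        y = nn.relu_squared(y)
    mx.eval(y)  # forces the whole lazy chain to actually run
    return time.perf_counter() - tic


print(f"relu_squared: {time_relu_squared():.4f} s")
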
@@ -157,6 +157,15 @@ def relu6(x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softplus(x):
     y = x

@@ -407,6 +416,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "softplus":
         print(bench(softplus, x))
 

@@ -207,6 +207,8 @@ if __name__ == "__main__":
     compare_filtered("elu --size 32x16x1024 --cpu")
     compare_filtered("relu6 --size 32x16x1024")
     compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("relu_squared --size 32x16x1024")
+    compare_filtered("relu_squared --size 32x16x1024 --cpu")
     compare_filtered("softplus --size 32x16x1024")
     compare_filtered("softplus --size 32x16x1024 --cpu")
     compare_filtered("celu --size 32x16x1024")

@@ -28,6 +28,7 @@ simple functions.
    prelu
    relu
    relu6
+   relu_squared
    selu
    sigmoid
    silu

@@ -51,6 +51,7 @@ Layers
    RMSNorm
    ReLU
    ReLU6
+   ReLUSquared
    RNN
    RoPE
    SELU

@ -37,7 +37,8 @@ void check_cu_error(const char* name, CUresult err) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return the location of the CUDA toolkit.
|
// Return the location of the CUDA toolkit.
|
||||||
const char* cuda_home() {
|
const std::string& cuda_home() {
|
||||||
|
static std::string home = []() -> std::string {
|
||||||
const char* home = std::getenv("CUDA_HOME");
|
const char* home = std::getenv("CUDA_HOME");
|
||||||
if (home) {
|
if (home) {
|
||||||
return home;
|
return home;
|
||||||
@@ -54,19 +55,28 @@ const char* cuda_home() {
 #endif
     throw std::runtime_error(
         "Environment variable CUDA_HOME or CUDA_PATH is not set.");
+  }();
+  return home;
 }
 
 // Get the cache directory for storing compiled results.
-bool get_ptx_cache_dir(std::filesystem::path* result) {
-  auto path = std::filesystem::temp_directory_path() / "mlx" / "ptx";
-  if (!std::filesystem::is_directory(path)) {
+const std::filesystem::path& ptx_cache_dir() {
+  static std::filesystem::path cache = []() -> std::filesystem::path {
+    std::filesystem::path cache;
+    if (auto c = std::getenv("MLX_PTX_CACHE"); c) {
+      cache = c;
+    } else {
+      cache = std::filesystem::temp_directory_path() / "mlx" / "ptx";
+    }
+    if (!std::filesystem::exists(cache)) {
       std::error_code error;
-      if (!std::filesystem::create_directories(path, error)) {
-        return false;
+      if (!std::filesystem::create_directories(cache, error)) {
+        return std::filesystem::path();
       }
     }
-  *result = path;
-  return true;
+    return cache;
+  }();
+  return cache;
 }
 
 // Try to read the cached |ptx| and |ptx_kernels| from |cache_dir|.

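With this change the cache location is resolved once, in a function-local static initialized by a lambda (the same memoization idiom applied to `cuda_home()` above): `MLX_PTX_CACHE` overrides the default of `<temp>/mlx/ptx`, and if the directory cannot be created an empty path is returned, which the read/write helpers below treat as "caching disabled". A sketch of overriding the location from Python, assuming the variable is set before the first CUDA kernel is JIT-compiled in the process (the path below is illustrative):

import os

# Must be set before MLX compiles its first CUDA kernel in this process.
os.environ["MLX_PTX_CACHE"] = "/var/tmp/my-mlx-ptx"  # illustrative path

import mlx.core as mx

x = mx.random.uniform(shape=(1024,))
mx.eval(x * 2)  # the first JIT compile populates the cache directory
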
@@ -75,6 +85,10 @@ bool read_cached_ptx(
     const std::string& module_name,
     std::vector<char>* ptx,
     std::vector<std::pair<std::string, std::string>>* ptx_kernels) {
+  if (cache_dir.empty()) {
+    return false;
+  }
+
   auto ptx_path = cache_dir / (module_name + ".ptx");
   std::error_code error;
   auto ptx_size = std::filesystem::file_size(ptx_path, error);

@@ -105,6 +119,10 @@ void write_cached_ptx(
     const std::string& module_name,
     const std::vector<char>& ptx,
     const std::vector<std::pair<std::string, std::string>>& ptx_kernels) {
+  if (cache_dir.empty()) {
+    return;
+  }
+
   std::ofstream ptx_file(cache_dir / (module_name + ".ptx"), std::ios::binary);
   if (!ptx.empty()) {
     ptx_file.write(&ptx.front(), ptx.size());

@@ -184,11 +202,9 @@ JitModule::JitModule(
     const std::string& module_name,
     const KernelBuilder& builder) {
   // Check cache.
-  std::filesystem::path cache_dir;
   std::vector<char> ptx;
   std::vector<std::pair<std::string, std::string>> ptx_kernels;
-  if (!get_ptx_cache_dir(&cache_dir) ||
-      !read_cached_ptx(cache_dir, module_name, &ptx, &ptx_kernels)) {
+  if (!read_cached_ptx(ptx_cache_dir(), module_name, &ptx, &ptx_kernels)) {
     // Create program.
     auto [source_code, kernel_names] = builder();
     nvrtcProgram prog;

@ -246,7 +262,7 @@ JitModule::JitModule(
|
|||||||
} else {
|
} else {
|
||||||
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
CHECK_NVRTC_ERROR(nvrtcGetPTX(prog, ptx.data()));
|
||||||
}
|
}
|
||||||
write_cached_ptx(cache_dir, module_name, ptx, ptx_kernels);
|
write_cached_ptx(ptx_cache_dir(), module_name, ptx, ptx_kernels);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load module.
|
// Load module.
|
||||||
|
@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,

@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,

@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.

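The functional form is a composition of existing primitives, so `nn.relu_squared(x)` equals `mx.square(mx.maximum(x, 0))`, and `shapeless=True` lets the compiled graph be reused across input shapes. A quick check of the semantics:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(nn.relu_squared(x))             # [0, 0, 0, 0.25, 4]
print(mx.square(mx.maximum(x, 0.0)))  # identical result
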
@ -420,6 +431,18 @@ class ReLU6(Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(relu_squared)
|
||||||
|
class ReLUSquared(Module):
|
||||||
|
r"""Applies the Rectified Linear Unit squared.
|
||||||
|
|
||||||
|
Applies :math:`\max(x, 0)^2` element wise.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2109.08668v2
|
||||||
|
|
||||||
|
See :func:`relu_squared` for the functional equivalent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(softmax)
|
@_make_activation_module(softmax)
|
||||||
class Softmax(Module):
|
class Softmax(Module):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
|
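`_make_activation_module` turns the function into a parameter-free `Module`, so `ReLUSquared` can be used anywhere a layer is expected. A small usage sketch (the layer sizes are illustrative):

import mlx.core as mx
import mlx.nn as nn

mlp = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLUSquared(),  # same as applying nn.relu_squared element-wise
    nn.Linear(32, 1),
)
y = mlp(mx.random.uniform(shape=(4, 16)))
print(y.shape)  # (4, 1)
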
@@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_relu_squared(self):
+        x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
+        y = nn.relu_squared(x)
+        self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)