Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)

Compare commits: 7 commits (8f04b2f6ac ... 7170e5f40b)

- 7170e5f40b
- cad5c0241c
- b3c1aaafd2
- 989e8bab66
- fe0672a9d2
- cbd353bf73
- 940f64fe6a

Taken together, these commits add a relu_squared activation to mlx.nn (function, ReLUSquared module, docs entries, benchmarks, and a test) and make several CUDA backend changes: general copies size their kernel launches by the copied shape, CommandEncoder gains a synchronize() that waits for queued kernels and completion handlers, gpu::synchronize routes through it, the worker clears tasks before the next wait, and TestMemory.test_memory_info is removed from the CUDA skip list.
@@ -224,6 +224,13 @@ def relu6(x):
     mx.eval(y)
 
 
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
+
+
 def softplus(x):
     y = x
     for i in range(100):

@@ -458,6 +465,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "celu":
         print(bench(celu, x))
 
@@ -157,6 +157,15 @@ def relu6(x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softplus(x):
     y = x

@@ -407,6 +416,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "softplus":
         print(bench(softplus, x))
 
@@ -207,6 +207,8 @@ if __name__ == "__main__":
         compare_filtered("elu --size 32x16x1024 --cpu")
         compare_filtered("relu6 --size 32x16x1024")
         compare_filtered("relu6 --size 32x16x1024 --cpu")
+        compare_filtered("relu_squared --size 32x16x1024")
+        compare_filtered("relu_squared --size 32x16x1024 --cpu")
         compare_filtered("softplus --size 32x16x1024")
         compare_filtered("softplus --size 32x16x1024 --cpu")
         compare_filtered("celu --size 32x16x1024")
@@ -28,6 +28,7 @@ simple functions.
    prelu
    relu
    relu6
+   relu_squared
    selu
    sigmoid
    silu
@@ -51,6 +51,7 @@ Layers
    RMSNorm
    ReLU
    ReLU6
+   ReLUSquared
    RNN
    RoPE
    SELU
@@ -106,7 +106,6 @@ void CudaAllocator::cuda_free(void* buf) {
       return;
     }
   }
 
   cudaFree(buf);
 }
@@ -63,25 +63,30 @@ void copy_general(
   MLX_SWITCH_BOOL(large, LARGE, {
     using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
     int ndim = shape.size();
+    size_t data_size = 1;
+    for (auto& s : shape)
+      data_size *= s;
     if (ndim <= 3) {
       MLX_SWITCH_1_2_3(ndim, NDIM, {
         auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
-        auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        auto [num_blocks, block_dims] =
+            get_launch_args(kernel, data_size, shape, out.strides(), large);
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
-            out.size(),
+            data_size,
             const_param<NDIM>(shape),
             const_param<NDIM>(strides_in),
             const_param<NDIM>(strides_out));
       });
     } else { // ndim >= 4
       auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-      auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+      auto [num_blocks, block_dims] =
+          get_launch_args(kernel, data_size, shape, out.strides(), large);
       kernel<<<num_blocks, block_dims, 0, stream>>>(
           in_ptr,
           out_ptr,
-          out.size(),
+          data_size,
           const_param(shape),
           const_param(strides_in),
           const_param(strides_out),
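The replacement of out.size() with data_size above sizes the launch by the element count implied by shape, which can differ from the output array's total size, presumably when the copy addresses only part of the output. A plain-Python restatement of that count (a sketch, not MLX code; the helper name is made up here):

```python
import math

def element_count(shape):
    # product of the dimension sizes, i.e. the number of elements the
    # copy actually touches
    return math.prod(shape)

assert element_count([32, 16, 1024]) == 32 * 16 * 1024
```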
@@ -6,6 +6,7 @@
 
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
+#include <future>
 
 namespace mlx::core {
 
@@ -107,6 +108,16 @@ void CommandEncoder::commit() {
   worker_.commit(stream_.last_cuda_stream());
 }
 
+void CommandEncoder::synchronize() {
+  stream().synchronize();
+  auto p = std::make_shared<std::promise<void>>();
+  std::future<void> f = p->get_future();
+  add_completed_handler([p = std::move(p)]() { p->set_value(); });
+  worker_.end_batch();
+  worker_.commit();
+  f.wait();
+}
+
 Device& device(mlx::core::Device device) {
   static std::unordered_map<int, Device> devices;
   auto it = devices.find(device.index);
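The synchronize() added above blocks by queueing a completion handler that fulfills a promise and then waiting on the matching future, so the caller resumes only after everything already committed (and its handlers) has run. A minimal plain-Python sketch of the same pattern, assuming nothing about the MLX API (the worker thread and queue here are stand-ins):

```python
import queue
import threading
from concurrent.futures import Future

tasks = queue.Queue()

def worker():
    # drain tasks in order; None is the shutdown sentinel
    while True:
        task = tasks.get()
        if task is None:
            break
        task()

t = threading.Thread(target=worker)
t.start()

tasks.put(lambda: print("previously queued work"))

# "synchronize": queue a handler that resolves a future, then wait on it;
# the handler can only run after everything queued before it has finished
f = Future()
tasks.put(lambda: f.set_result(None))
f.result()

tasks.put(None)
t.join()
```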
@@ -123,6 +123,9 @@ class CommandEncoder {
     return has_gpu_work_;
   }
 
+  // Wait until kernels and completion handlers are finished
+  void synchronize();
+
  private:
   Device& device_;
   DeviceStream& stream_;
@@ -62,7 +62,7 @@ void finalize(Stream s) {
 
 void synchronize(Stream s) {
   nvtx3::scoped_range r("gpu::synchronize");
-  cu::get_stream(s).synchronize();
+  cu::get_command_encoder(s).synchronize();
 }
 
 } // namespace mlx::core::gpu
@@ -80,7 +80,9 @@ void Worker::thread_fn() {
       }
       worker_tasks_.erase(worker_tasks_.begin(), end);
     }
-    for (auto& task : tasks) {
+    // Make sure tasks are cleared before the next wait
+    for (int i = 0; i < tasks.size(); ++i) {
+      auto task = std::move(tasks[i]);
       task();
     }
     worker_event_.wait(batch + 1);
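Moving each task out of the batch before invoking it means whatever the task captured is released as soon as the call returns, instead of surviving until after the next wait. A rough Python analogue (a sketch only; Python has no move semantics, so dropping the list's reference plays the same role):

```python
def run_and_clear(tasks):
    for i in range(len(tasks)):
        task = tasks[i]
        tasks[i] = None  # the batch no longer keeps the task (or its captures) alive
        task()
    tasks.clear()
```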
@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,
@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,
@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
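Assuming this diff is applied and relu_squared is re-exported from mlx.nn (as the import hunk above suggests), usage would look roughly like this sketch; the expected values follow the test added further down:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
y = nn.relu_squared(x)  # max(x, 0) ** 2
print(y)                # expected: array([0, 0, 1, 4, 9], dtype=float32)
```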
@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.
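The module form wraps the same function via _make_activation_module; a brief usage sketch under the same assumption that ReLUSquared is exported from mlx.nn:

```python
import mlx.core as mx
import mlx.nn as nn

act = nn.ReLUSquared()
print(act(mx.array([-2.0, 0.5])))  # expected: array([0, 0.25], dtype=float32)
```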
@@ -6,7 +6,6 @@ cuda_skip = {
     "TestEinsum.test_ellipses",
     "TestEinsum.test_opt_einsum_test_cases",
     "TestLoad.test_load_f8_e4m3",
-    "TestMemory.test_memory_info",
     "TestLayers.test_group_norm",
     "TestLayers.test_pooling",
     "TestLayers.test_quantized_embedding",
@@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_relu_squared(self):
+        x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
+        y = nn.relu_squared(x)
+        self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)