mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Compare commits
8 Commits
106568caa1
...
9d86a4d5ba
Author | SHA1 | Date | |
---|---|---|---|
![]() |
9d86a4d5ba | ||
![]() |
5adf185f86 | ||
![]() |
c9a9180584 | ||
![]() |
b3c1aaafd2 | ||
![]() |
989e8bab66 | ||
![]() |
fe0672a9d2 | ||
![]() |
cbd353bf73 | ||
![]() |
940f64fe6a |
@ -224,6 +224,13 @@ def relu6(x):
|
|||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def relu_squared(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.relu_squared(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
def softplus(x):
|
def softplus(x):
|
||||||
y = x
|
y = x
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
@ -458,6 +465,9 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "relu6":
|
elif args.benchmark == "relu6":
|
||||||
print(bench(relu6, x))
|
print(bench(relu6, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu_squared":
|
||||||
|
print(bench(relu_squared, x))
|
||||||
|
|
||||||
elif args.benchmark == "celu":
|
elif args.benchmark == "celu":
|
||||||
print(bench(celu, x))
|
print(bench(celu, x))
|
||||||
|
|
||||||
|
@ -157,6 +157,15 @@ def relu6(x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def relu_squared(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.relu(y)
|
||||||
|
y = torch.square(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softplus(x):
|
def softplus(x):
|
||||||
y = x
|
y = x
|
||||||
@ -407,6 +416,9 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "relu6":
|
elif args.benchmark == "relu6":
|
||||||
print(bench(relu6, x))
|
print(bench(relu6, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu_squared":
|
||||||
|
print(bench(relu_squared, x))
|
||||||
|
|
||||||
elif args.benchmark == "softplus":
|
elif args.benchmark == "softplus":
|
||||||
print(bench(softplus, x))
|
print(bench(softplus, x))
|
||||||
|
|
||||||
|
@ -207,6 +207,8 @@ if __name__ == "__main__":
|
|||||||
compare_filtered("elu --size 32x16x1024 --cpu")
|
compare_filtered("elu --size 32x16x1024 --cpu")
|
||||||
compare_filtered("relu6 --size 32x16x1024")
|
compare_filtered("relu6 --size 32x16x1024")
|
||||||
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("relu_squared --size 32x16x1024")
|
||||||
|
compare_filtered("relu_squared --size 32x16x1024 --cpu")
|
||||||
compare_filtered("softplus --size 32x16x1024")
|
compare_filtered("softplus --size 32x16x1024")
|
||||||
compare_filtered("softplus --size 32x16x1024 --cpu")
|
compare_filtered("softplus --size 32x16x1024 --cpu")
|
||||||
compare_filtered("celu --size 32x16x1024")
|
compare_filtered("celu --size 32x16x1024")
|
||||||
|
@ -28,6 +28,7 @@ simple functions.
|
|||||||
prelu
|
prelu
|
||||||
relu
|
relu
|
||||||
relu6
|
relu6
|
||||||
|
relu_squared
|
||||||
selu
|
selu
|
||||||
sigmoid
|
sigmoid
|
||||||
silu
|
silu
|
||||||
|
@ -51,6 +51,7 @@ Layers
|
|||||||
RMSNorm
|
RMSNorm
|
||||||
ReLU
|
ReLU
|
||||||
ReLU6
|
ReLU6
|
||||||
|
ReLUSquared
|
||||||
RNN
|
RNN
|
||||||
RoPE
|
RoPE
|
||||||
SELU
|
SELU
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "mlx/backend/cuda/allocator.h"
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
#include "mlx/backend/cuda/worker.h"
|
#include "mlx/backend/cuda/worker.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
@ -14,9 +15,11 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace cu {
|
namespace cu {
|
||||||
|
|
||||||
|
constexpr int page_size = 16384;
|
||||||
|
|
||||||
CudaAllocator::CudaAllocator()
|
CudaAllocator::CudaAllocator()
|
||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
getpagesize(),
|
page_size,
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) {
|
[this](CudaBuffer* buf) {
|
||||||
cuda_free(buf->data);
|
cuda_free(buf->data);
|
||||||
@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()
|
|||||||
|
|
||||||
Buffer CudaAllocator::malloc(size_t size) {
|
Buffer CudaAllocator::malloc(size_t size) {
|
||||||
// Find available buffer from cache.
|
// Find available buffer from cache.
|
||||||
|
auto orig_size = size;
|
||||||
std::unique_lock lock(mutex_);
|
std::unique_lock lock(mutex_);
|
||||||
|
if (size < page_size) {
|
||||||
|
size = next_power_of_2(size);
|
||||||
|
} else {
|
||||||
|
size = page_size * ((size + page_size - 1) / page_size);
|
||||||
|
}
|
||||||
|
|
||||||
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
|
||||||
if (!buf) {
|
if (!buf) {
|
||||||
// If we have a lot of memory pressure or are over the maximum cache size,
|
// If we have a lot of memory pressure or are over the maximum cache size,
|
||||||
|
@ -24,7 +24,6 @@ void copy_gpu_inplace(
|
|||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
|
||||||
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
|
||||||
return;
|
return;
|
||||||
|
@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = NDIM - 1; i >= 0; --i) {
|
for (int i = NDIM - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
c_loc += dim_idx * c_strides[i];
|
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
|
|||||||
IdxT b_loc = 0;
|
IdxT b_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc);
|
return cuda::std::make_tuple(a_loc, b_loc);
|
||||||
@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
|
|||||||
IdxT c_loc = 0;
|
IdxT c_loc = 0;
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
for (int i = ndim - 1; i >= 0; --i) {
|
||||||
int dim_idx = elem % shape[i];
|
int dim_idx = elem % shape[i];
|
||||||
a_loc += dim_idx * a_strides[i];
|
a_loc += dim_idx * IdxT(a_strides[i]);
|
||||||
b_loc += dim_idx * b_strides[i];
|
b_loc += dim_idx * IdxT(b_strides[i]);
|
||||||
c_loc += dim_idx * c_strides[i];
|
c_loc += dim_idx * IdxT(c_strides[i]);
|
||||||
elem /= shape[i];
|
elem /= shape[i];
|
||||||
}
|
}
|
||||||
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
return cuda::std::make_tuple(a_loc, b_loc, c_loc);
|
||||||
|
@ -162,11 +162,15 @@ class MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array workspace(
|
void* workspace_ptr = nullptr;
|
||||||
allocator::malloc(heuristic_.workspaceSize),
|
if (heuristic_.workspaceSize > 0) {
|
||||||
{static_cast<int>(heuristic_.workspaceSize)},
|
array workspace(
|
||||||
int8);
|
allocator::malloc(heuristic_.workspaceSize),
|
||||||
encoder.add_temporary(workspace);
|
{static_cast<int>(heuristic_.workspaceSize)},
|
||||||
|
int8);
|
||||||
|
encoder.add_temporary(workspace);
|
||||||
|
workspace_ptr = workspace.data<void>();
|
||||||
|
}
|
||||||
|
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
CHECK_CUBLAS_ERROR(cublasLtMatmul(
|
||||||
@ -183,8 +187,8 @@ class MatMul {
|
|||||||
out,
|
out,
|
||||||
out_desc_,
|
out_desc_,
|
||||||
&heuristic_.algo,
|
&heuristic_.algo,
|
||||||
workspace.data<void>(),
|
workspace_ptr,
|
||||||
workspace.nbytes(),
|
heuristic_.workspaceSize,
|
||||||
stream));
|
stream));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
b_batch_strides.back());
|
b_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
b_batch_strides.back(),
|
b_batch_strides.back(),
|
||||||
c_batch_strides.back());
|
c_batch_strides.back());
|
||||||
|
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_input_array(c);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
auto nbatch = batch_count / batch_shape.back();
|
||||||
|
if (nbatch == 1) {
|
||||||
|
matmul.run(
|
||||||
|
encoder,
|
||||||
|
out.data<int8_t>(),
|
||||||
|
a.data<int8_t>(),
|
||||||
|
b.data<int8_t>(),
|
||||||
|
c.data<int8_t>(),
|
||||||
|
alpha_,
|
||||||
|
beta_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||||
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
|
||||||
for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
|
for (size_t i = 0; i < nbatch; ++i) {
|
||||||
matmul.run(
|
matmul.run(
|
||||||
encoder,
|
encoder,
|
||||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||||
|
@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
|
|||||||
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
||||||
array out = out_;
|
array out = out_;
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += in.ndim();
|
axis += in.ndim();
|
||||||
}
|
}
|
||||||
@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
encoder.launch_kernel([&](cudaStream_t stream) {
|
encoder.launch_kernel([&](cudaStream_t stream) {
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
|
||||||
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
|
||||||
|
@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
|
|||||||
PReLU,
|
PReLU,
|
||||||
ReLU,
|
ReLU,
|
||||||
ReLU6,
|
ReLU6,
|
||||||
|
ReLUSquared,
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
SiLU,
|
SiLU,
|
||||||
Softmax,
|
Softmax,
|
||||||
@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
|
|||||||
prelu,
|
prelu,
|
||||||
relu,
|
relu,
|
||||||
relu6,
|
relu6,
|
||||||
|
relu_squared,
|
||||||
selu,
|
selu,
|
||||||
sigmoid,
|
sigmoid,
|
||||||
silu,
|
silu,
|
||||||
|
@ -71,6 +71,17 @@ def relu6(x):
|
|||||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def relu_squared(x):
|
||||||
|
r"""Applies the Rectified Linear Unit squared.
|
||||||
|
|
||||||
|
Applies :math:`\max(x, 0)^2` element wise.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2109.08668v2
|
||||||
|
"""
|
||||||
|
return relu(x).square()
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def softmax(x, axis=-1):
|
def softmax(x, axis=-1):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
@ -420,6 +431,18 @@ class ReLU6(Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(relu_squared)
|
||||||
|
class ReLUSquared(Module):
|
||||||
|
r"""Applies the Rectified Linear Unit squared.
|
||||||
|
|
||||||
|
Applies :math:`\max(x, 0)^2` element wise.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2109.08668v2
|
||||||
|
|
||||||
|
See :func:`relu_squared` for the functional equivalent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(softmax)
|
@_make_activation_module(softmax)
|
||||||
class Softmax(Module):
|
class Softmax(Module):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
|
@ -413,7 +413,7 @@ class Module(dict):
|
|||||||
f'Module does not have sub-module named "{k}".'
|
f'Module does not have sub-module named "{k}".'
|
||||||
)
|
)
|
||||||
elif isinstance(modules, list):
|
elif isinstance(modules, list):
|
||||||
for i in range(len(dst)):
|
for i in range(len(modules)):
|
||||||
current_value = dst[i]
|
current_value = dst[i]
|
||||||
new_value = modules[i]
|
new_value = modules[i]
|
||||||
if self.is_module(current_value) and self.is_module(new_value):
|
if self.is_module(current_value) and self.is_module(new_value):
|
||||||
|
@ -259,6 +259,11 @@ class TestBase(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
m = m.update_modules({"list": ["hi"]})
|
m = m.update_modules({"list": ["hi"]})
|
||||||
|
|
||||||
|
# Allow updating a strict subset
|
||||||
|
m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
|
||||||
|
m.update_modules({"layers": [{}, nn.Linear(3, 4)]})
|
||||||
|
self.assertEqual(m.layers[1].weight.shape, (4, 3))
|
||||||
|
|
||||||
|
|
||||||
class TestLayers(mlx_tests.MLXTestCase):
|
class TestLayers(mlx_tests.MLXTestCase):
|
||||||
def test_identity(self):
|
def test_identity(self):
|
||||||
@ -855,6 +860,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y.shape, (3,))
|
self.assertEqual(y.shape, (3,))
|
||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_relu_squared(self):
|
||||||
|
x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
|
||||||
|
y = nn.relu_squared(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
|
||||||
|
self.assertEqual(y.shape, (5,))
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
def test_leaky_relu(self):
|
def test_leaky_relu(self):
|
||||||
x = mx.array([1.0, -1.0, 0.0])
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
y = nn.leaky_relu(x)
|
y = nn.leaky_relu(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user