mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Compare commits
7 Commits
0ac9e9f691
...
f26c7a2e4f
Author | SHA1 | Date | |
---|---|---|---|
![]() |
f26c7a2e4f | ||
![]() |
c552ff2451 | ||
![]() |
b3c1aaafd2 | ||
![]() |
989e8bab66 | ||
![]() |
fe0672a9d2 | ||
![]() |
cbd353bf73 | ||
![]() |
940f64fe6a |
@ -224,6 +224,13 @@ def relu6(x):
|
|||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def relu_squared(x):
    """Benchmark body: chain 100 ``nn.relu_squared`` calls, then evaluate.

    Starts from ``x``, feeds each result into the next call, and finishes
    by passing the final value to ``mx.eval``.
    """
    out = x
    for _ in range(100):
        out = nn.relu_squared(out)
    mx.eval(out)
|
||||||
|
|
||||||
|
|
||||||
def softplus(x):
|
def softplus(x):
|
||||||
y = x
|
y = x
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
@ -458,6 +465,9 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "relu6":
|
elif args.benchmark == "relu6":
|
||||||
print(bench(relu6, x))
|
print(bench(relu6, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu_squared":
|
||||||
|
print(bench(relu_squared, x))
|
||||||
|
|
||||||
elif args.benchmark == "celu":
|
elif args.benchmark == "celu":
|
||||||
print(bench(celu, x))
|
print(bench(celu, x))
|
||||||
|
|
||||||
|
@ -157,6 +157,15 @@ def relu6(x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def relu_squared(x):
    """Benchmark body: apply ReLU followed by squaring, 100 times in a row.

    Each iteration passes the previous result through
    ``torch.nn.functional.relu`` and then ``torch.square``. Afterwards
    ``sync_if_needed`` is called with the original input ``x``.
    """
    out = x
    for _ in range(100):
        out = torch.square(torch.nn.functional.relu(out))
    sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softplus(x):
|
def softplus(x):
|
||||||
y = x
|
y = x
|
||||||
@ -407,6 +416,9 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "relu6":
|
elif args.benchmark == "relu6":
|
||||||
print(bench(relu6, x))
|
print(bench(relu6, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu_squared":
|
||||||
|
print(bench(relu_squared, x))
|
||||||
|
|
||||||
elif args.benchmark == "softplus":
|
elif args.benchmark == "softplus":
|
||||||
print(bench(softplus, x))
|
print(bench(softplus, x))
|
||||||
|
|
||||||
|
@ -207,6 +207,8 @@ if __name__ == "__main__":
|
|||||||
compare_filtered("elu --size 32x16x1024 --cpu")
|
compare_filtered("elu --size 32x16x1024 --cpu")
|
||||||
compare_filtered("relu6 --size 32x16x1024")
|
compare_filtered("relu6 --size 32x16x1024")
|
||||||
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("relu_squared --size 32x16x1024")
|
||||||
|
compare_filtered("relu_squared --size 32x16x1024 --cpu")
|
||||||
compare_filtered("softplus --size 32x16x1024")
|
compare_filtered("softplus --size 32x16x1024")
|
||||||
compare_filtered("softplus --size 32x16x1024 --cpu")
|
compare_filtered("softplus --size 32x16x1024 --cpu")
|
||||||
compare_filtered("celu --size 32x16x1024")
|
compare_filtered("celu --size 32x16x1024")
|
||||||
|
@ -28,6 +28,7 @@ simple functions.
|
|||||||
prelu
|
prelu
|
||||||
relu
|
relu
|
||||||
relu6
|
relu6
|
||||||
|
relu_squared
|
||||||
selu
|
selu
|
||||||
sigmoid
|
sigmoid
|
||||||
silu
|
silu
|
||||||
|
@ -51,6 +51,7 @@ Layers
|
|||||||
RMSNorm
|
RMSNorm
|
||||||
ReLU
|
ReLU
|
||||||
ReLU6
|
ReLU6
|
||||||
|
ReLUSquared
|
||||||
RNN
|
RNN
|
||||||
RoPE
|
RoPE
|
||||||
SELU
|
SELU
|
||||||
|
@ -107,6 +107,16 @@ same array:
|
|||||||
>>> a
|
>>> a
|
||||||
array([1, 2, 0], dtype=int32)
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
|
||||||
|
Note that, unlike NumPy, updates to the same location are nondeterministic:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> a[[0, 0]] = mx.array([4, 5])
|
||||||
|
|
||||||
|
The first element of ``a`` could be ``4`` or ``5``.
|
||||||
|
|
||||||
Transformations of functions which use in-place updates are allowed and work as
|
Transformations of functions which use in-place updates are allowed and work as
|
||||||
expected. For example:
|
expected. For example:
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ void binary_op_gpu_inplace(
|
|||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(a_strides),
|
const_param<NDIM>(a_strides),
|
||||||
const_param<NDIM>(b_strides));
|
const_param<NDIM>(b_strides));
|
||||||
@ -178,7 +178,7 @@ void binary_op_gpu_inplace(
|
|||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
out.data<OutType>(),
|
out.data<OutType>(),
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(a_strides),
|
const_param(a_strides),
|
||||||
const_param(b_strides),
|
const_param(b_strides),
|
||||||
@ -196,8 +196,8 @@ void binary_op_gpu_inplace(
|
|||||||
} else if (bopt == BinaryOpType::VectorVector) {
|
} else if (bopt == BinaryOpType::VectorVector) {
|
||||||
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] =
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
get_launch_args(kernel, out, LARGE);
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
a.data<InType>(),
|
a.data<InType>(),
|
||||||
b.data<InType>(),
|
b.data<InType>(),
|
||||||
@ -264,7 +264,6 @@ BINARY_GPU(Add)
|
|||||||
BINARY_GPU(ArcTan2)
|
BINARY_GPU(ArcTan2)
|
||||||
BINARY_GPU(Divide)
|
BINARY_GPU(Divide)
|
||||||
BINARY_GPU(Remainder)
|
BINARY_GPU(Remainder)
|
||||||
BINARY_GPU(Equal)
|
|
||||||
BINARY_GPU(Greater)
|
BINARY_GPU(Greater)
|
||||||
BINARY_GPU(GreaterEqual)
|
BINARY_GPU(GreaterEqual)
|
||||||
BINARY_GPU(Less)
|
BINARY_GPU(Less)
|
||||||
@ -279,6 +278,17 @@ BINARY_GPU(NotEqual)
|
|||||||
BINARY_GPU(Power)
|
BINARY_GPU(Power)
|
||||||
BINARY_GPU(Subtract)
|
BINARY_GPU(Subtract)
|
||||||
|
|
||||||
|
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Equal::eval_gpu");
|
||||||
|
auto& s = out.primitive().stream();
|
||||||
|
auto op = get_primitive_string(this);
|
||||||
|
if (equal_nan_) {
|
||||||
|
binary_op_gpu<cu::NaNEqual>(inputs, out, op, s);
|
||||||
|
} else {
|
||||||
|
binary_op_gpu<cu::Equal>(inputs, out, op, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
nvtx3::scoped_range r("BitwiseBinary::eval_gpu");
|
||||||
auto& s = out.primitive().stream();
|
auto& s = out.primitive().stream();
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void copy_gpu_inplace(
|
void copy_gpu_inplace(
|
||||||
const array& in_,
|
const array& in,
|
||||||
array& out,
|
array& out,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const Strides& strides_in,
|
const Strides& strides_in,
|
||||||
@ -20,7 +20,6 @@ void copy_gpu_inplace(
|
|||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const array& in = in_.data_shared_ptr() ? in_ : out;
|
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
encoder.set_input_array(in);
|
encoder.set_input_array(in);
|
||||||
|
@ -10,20 +10,13 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
|
||||||
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
|
||||||
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
|
||||||
using InType = cuda_type_t<CTYPE_IN>; \
|
using InType = cuda_type_t<CTYPE_IN>; \
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>; \
|
using OutType = cuda_type_t<CTYPE_OUT>; \
|
||||||
if constexpr (cu::CastOp<InType, OutType>::is_castable) { \
|
__VA_ARGS__; \
|
||||||
__VA_ARGS__; \
|
}); \
|
||||||
} else { \
|
|
||||||
throw std::runtime_error(fmt::format( \
|
|
||||||
"Can not copy data from dtype {} to {}.", \
|
|
||||||
dtype_to_string(out.dtype()), \
|
|
||||||
dtype_to_string(in.dtype()))); \
|
|
||||||
} \
|
|
||||||
}); \
|
|
||||||
})
|
})
|
||||||
|
|
||||||
void copy_contiguous(
|
void copy_contiguous(
|
||||||
|
@ -43,7 +43,8 @@ void copy_contiguous(
|
|||||||
if (ctype == CopyType::Vector) {
|
if (ctype == CopyType::Vector) {
|
||||||
kernel = cu::copy_v<InType, OutType, IdxT>;
|
kernel = cu::copy_v<InType, OutType, IdxT>;
|
||||||
}
|
}
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in.data<InType>() + in_offset,
|
in.data<InType>() + in_offset,
|
||||||
out.data<OutType>() + out_offset,
|
out.data<OutType>() + out_offset,
|
||||||
|
@ -59,9 +59,9 @@ void copy_general(
|
|||||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -70,7 +70,7 @@ void copy_general(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out));
|
const_param<NDIM>(strides_out));
|
||||||
@ -81,7 +81,7 @@ void copy_general(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
@ -65,9 +65,9 @@ void copy_general_dynamic(
|
|||||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -76,7 +76,7 @@ void copy_general_dynamic(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in),
|
const_param<NDIM>(strides_in),
|
||||||
const_param<NDIM>(strides_out),
|
const_param<NDIM>(strides_out),
|
||||||
@ -89,7 +89,7 @@ void copy_general_dynamic(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
const_param(strides_out),
|
const_param(strides_out),
|
||||||
|
@ -54,9 +54,9 @@ void copy_general_input(
|
|||||||
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
|
||||||
const InType* in_ptr = in.data<InType>() + offset_in;
|
const InType* in_ptr = in.data<InType>() + offset_in;
|
||||||
OutType* out_ptr = out.data<OutType>() + offset_out;
|
OutType* out_ptr = out.data<OutType>() + offset_out;
|
||||||
bool large = in.data_size() > UINT32_MAX || out.data_size() > UINT32_MAX;
|
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
|
||||||
MLX_SWITCH_BOOL(large, LARGE, {
|
MLX_SWITCH_BOOL(large, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
if (ndim <= 3) {
|
if (ndim <= 3) {
|
||||||
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
MLX_SWITCH_1_2_3(ndim, NDIM, {
|
||||||
@ -65,7 +65,7 @@ void copy_general_input(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(strides_in));
|
const_param<NDIM>(strides_in));
|
||||||
});
|
});
|
||||||
@ -75,7 +75,7 @@ void copy_general_input(
|
|||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param(shape),
|
const_param(shape),
|
||||||
const_param(strides_in),
|
const_param(strides_in),
|
||||||
ndim);
|
ndim);
|
||||||
|
@ -45,6 +45,18 @@ struct CastOp<
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename SrcT, typename DstT>
|
||||||
|
struct CastOp<
|
||||||
|
SrcT,
|
||||||
|
DstT,
|
||||||
|
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
|
||||||
|
static constexpr bool is_castable = true;
|
||||||
|
|
||||||
|
__device__ SrcT operator()(SrcT x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Return an iterator that casts the value to DstT using CastOp.
|
// Return an iterator that casts the value to DstT using CastOp.
|
||||||
template <typename DstT, typename Iterator>
|
template <typename DstT, typename Iterator>
|
||||||
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
__host__ __device__ auto make_cast_iterator(Iterator it) {
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
|
#include <math_constants.h>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
struct Abs {
|
struct Abs {
|
||||||
@ -183,21 +185,38 @@ struct Imag {
|
|||||||
struct Log {
|
struct Log {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return log(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto r = log(cuCrealf(Abs{}(x)));
|
||||||
|
auto i = atan2f(cuCimagf(x), cuCrealf(x));
|
||||||
|
return {r, i};
|
||||||
|
} else {
|
||||||
|
return log(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log2 {
|
struct Log2 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return log2(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
|
||||||
|
} else {
|
||||||
|
return log2(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Log10 {
|
struct Log10 {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
return log10(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
auto y = Log{}(x);
|
||||||
|
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
|
||||||
|
return y;
|
||||||
|
} else {
|
||||||
|
return log10(x);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -102,6 +102,11 @@ inline constexpr bool is_floating_v =
|
|||||||
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
|
||||||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
|
||||||
|
|
||||||
|
// Type traits for detecting complex or real floating point numbers.
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_inexact_v =
|
||||||
|
is_floating_v<T> || cuda::std::is_same_v<T, complex64_t>;
|
||||||
|
|
||||||
// Utility to copy data from vector to array in host.
|
// Utility to copy data from vector to array in host.
|
||||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||||
@ -136,17 +141,19 @@ inline uint max_occupancy_block_dim(T kernel) {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
inline std::tuple<dim3, uint> get_launch_args(
|
inline std::tuple<dim3, uint> get_launch_args(
|
||||||
T kernel,
|
T kernel,
|
||||||
const array& arr,
|
size_t size,
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides,
|
||||||
bool large,
|
bool large,
|
||||||
int work_per_thread = 1) {
|
int work_per_thread = 1) {
|
||||||
size_t nthreads = cuda::ceil_div(arr.size(), work_per_thread);
|
size_t nthreads = cuda::ceil_div(size, work_per_thread);
|
||||||
uint block_dim = max_occupancy_block_dim(kernel);
|
uint block_dim = max_occupancy_block_dim(kernel);
|
||||||
if (block_dim > nthreads) {
|
if (block_dim > nthreads) {
|
||||||
block_dim = nthreads;
|
block_dim = nthreads;
|
||||||
}
|
}
|
||||||
dim3 num_blocks;
|
dim3 num_blocks;
|
||||||
if (large) {
|
if (large) {
|
||||||
num_blocks = get_2d_grid_dims(arr.shape(), arr.strides(), work_per_thread);
|
num_blocks = get_2d_grid_dims(shape, strides, work_per_thread);
|
||||||
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
num_blocks.x = cuda::ceil_div(num_blocks.x, block_dim);
|
||||||
} else {
|
} else {
|
||||||
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
num_blocks.x = cuda::ceil_div(nthreads, block_dim);
|
||||||
@ -154,4 +161,14 @@ inline std::tuple<dim3, uint> get_launch_args(
|
|||||||
return std::make_tuple(num_blocks, block_dim);
|
return std::make_tuple(num_blocks, block_dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::tuple<dim3, uint> get_launch_args(
|
||||||
|
T kernel,
|
||||||
|
const array& arr,
|
||||||
|
bool large,
|
||||||
|
int work_per_thread = 1) {
|
||||||
|
return get_launch_args(
|
||||||
|
kernel, arr.size(), arr.shape(), arr.strides(), large, work_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -116,7 +116,7 @@ void ternary_op_gpu_inplace(
|
|||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
c.data<DType>(),
|
c.data<DType>(),
|
||||||
out.data<DType>(),
|
out.data<DType>(),
|
||||||
out.data_size(),
|
out.size(),
|
||||||
const_param<NDIM>(shape),
|
const_param<NDIM>(shape),
|
||||||
const_param<NDIM>(a_strides),
|
const_param<NDIM>(a_strides),
|
||||||
const_param<NDIM>(b_strides),
|
const_param<NDIM>(b_strides),
|
||||||
@ -142,7 +142,8 @@ void ternary_op_gpu_inplace(
|
|||||||
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
|
||||||
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
|
||||||
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
auto kernel = cu::ternary_v<Op, DType, IdxT>;
|
||||||
auto [num_blocks, block_dims] = get_launch_args(kernel, out, LARGE);
|
auto [num_blocks, block_dims] = get_launch_args(
|
||||||
|
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
|
||||||
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
kernel<<<num_blocks, block_dims, 0, stream>>>(
|
||||||
a.data<bool>(),
|
a.data<bool>(),
|
||||||
b.data<DType>(),
|
b.data<DType>(),
|
||||||
|
@ -28,11 +28,14 @@ constexpr bool supports_unary_op() {
|
|||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, ArcTanh> ||
|
||||||
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
std::is_same_v<Op, Erf> || std::is_same_v<Op, ErfInv> ||
|
||||||
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
std::is_same_v<Op, Expm1> || std::is_same_v<Op, Log1p> ||
|
||||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
std::is_same_v<Op, Sigmoid> || std::is_same_v<Op, Sqrt> ||
|
||||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Sigmoid> ||
|
std::is_same_v<Op, Rsqrt>) {
|
||||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Rsqrt>) {
|
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
return std::is_same_v<In, Out> && is_floating_v<In>;
|
||||||
}
|
}
|
||||||
|
if (std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
||||||
|
std::is_same_v<Op, Log10>) {
|
||||||
|
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
||||||
|
}
|
||||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
if (std::is_same_v<Op, BitwiseInvert>) {
|
||||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
||||||
!std::is_same_v<In, bool>;
|
!std::is_same_v<In, bool>;
|
||||||
@ -91,7 +94,7 @@ void unary_op_gpu_inplace(
|
|||||||
} else {
|
} else {
|
||||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
auto [shape, strides] = collapse_contiguous_dims(in);
|
||||||
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
auto [in_begin, in_end] = cu::make_general_iterators<int64_t>(
|
||||||
in_ptr, in.data_size(), shape, strides);
|
in_ptr, in.size(), shape, strides);
|
||||||
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
thrust::transform(policy, in_begin, in_end, out_ptr, Op());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
|
|||||||
PReLU,
|
PReLU,
|
||||||
ReLU,
|
ReLU,
|
||||||
ReLU6,
|
ReLU6,
|
||||||
|
ReLUSquared,
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
SiLU,
|
SiLU,
|
||||||
Softmax,
|
Softmax,
|
||||||
@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
|
|||||||
prelu,
|
prelu,
|
||||||
relu,
|
relu,
|
||||||
relu6,
|
relu6,
|
||||||
|
relu_squared,
|
||||||
selu,
|
selu,
|
||||||
sigmoid,
|
sigmoid,
|
||||||
silu,
|
silu,
|
||||||
|
@ -71,6 +71,17 @@ def relu6(x):
|
|||||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
def relu_squared(x):
    r"""Squared Rectified Linear Unit.

    Computes :math:`\max(x, 0)^2` element wise by squaring the output of
    :func:`relu`.

    Reference: https://arxiv.org/abs/2109.08668v2
    """
    rectified = relu(x)
    return rectified.square()
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def softmax(x, axis=-1):
|
def softmax(x, axis=-1):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
@ -420,6 +431,18 @@ class ReLU6(Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(relu_squared)
class ReLUSquared(Module):
    r"""Module form of the squared Rectified Linear Unit.

    Applies :math:`\max(x, 0)^2` element wise.

    Reference: https://arxiv.org/abs/2109.08668v2

    See :func:`relu_squared` for the functional equivalent.
    """
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(softmax)
|
@_make_activation_module(softmax)
|
||||||
class Softmax(Module):
|
class Softmax(Module):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
cuda_skip = {
|
cuda_skip = {
|
||||||
"TestArray.test_api",
|
"TestArray.test_api",
|
||||||
"TestArray.test_setitem",
|
|
||||||
"TestAutograd.test_cumprod_grad",
|
"TestAutograd.test_cumprod_grad",
|
||||||
"TestAutograd.test_slice_grads",
|
"TestAutograd.test_slice_grads",
|
||||||
"TestAutograd.test_split_against_slice",
|
"TestAutograd.test_split_against_slice",
|
||||||
@ -51,7 +50,6 @@ cuda_skip = {
|
|||||||
"TestEinsum.test_opt_einsum_test_cases",
|
"TestEinsum.test_opt_einsum_test_cases",
|
||||||
"TestEval.test_multi_output_eval_during_transform",
|
"TestEval.test_multi_output_eval_during_transform",
|
||||||
"TestExportImport.test_export_conv",
|
"TestExportImport.test_export_conv",
|
||||||
"TestFast.test_rope_grad",
|
|
||||||
"TestFFT.test_fft",
|
"TestFFT.test_fft",
|
||||||
"TestFFT.test_fft_big_powers_of_two",
|
"TestFFT.test_fft_big_powers_of_two",
|
||||||
"TestFFT.test_fft_contiguity",
|
"TestFFT.test_fft_contiguity",
|
||||||
@ -89,9 +87,6 @@ cuda_skip = {
|
|||||||
"TestOps.test_argpartition",
|
"TestOps.test_argpartition",
|
||||||
"TestOps.test_array_equal",
|
"TestOps.test_array_equal",
|
||||||
"TestOps.test_as_strided",
|
"TestOps.test_as_strided",
|
||||||
"TestOps.test_atleast_1d",
|
|
||||||
"TestOps.test_atleast_2d",
|
|
||||||
"TestOps.test_atleast_3d",
|
|
||||||
"TestOps.test_binary_ops",
|
"TestOps.test_binary_ops",
|
||||||
"TestOps.test_bitwise_grad",
|
"TestOps.test_bitwise_grad",
|
||||||
"TestOps.test_complex_ops",
|
"TestOps.test_complex_ops",
|
||||||
@ -100,22 +95,16 @@ cuda_skip = {
|
|||||||
"TestOps.test_hadamard",
|
"TestOps.test_hadamard",
|
||||||
"TestOps.test_hadamard_grad_vmap",
|
"TestOps.test_hadamard_grad_vmap",
|
||||||
"TestOps.test_irregular_binary_ops",
|
"TestOps.test_irregular_binary_ops",
|
||||||
"TestOps.test_isfinite",
|
|
||||||
"TestOps.test_kron",
|
"TestOps.test_kron",
|
||||||
"TestOps.test_log",
|
|
||||||
"TestOps.test_log10",
|
|
||||||
"TestOps.test_log1p",
|
"TestOps.test_log1p",
|
||||||
"TestOps.test_log2",
|
|
||||||
"TestOps.test_logaddexp",
|
"TestOps.test_logaddexp",
|
||||||
"TestOps.test_logcumsumexp",
|
"TestOps.test_logcumsumexp",
|
||||||
"TestOps.test_partition",
|
"TestOps.test_partition",
|
||||||
"TestOps.test_scans",
|
"TestOps.test_scans",
|
||||||
"TestOps.test_slice_update_reversed",
|
|
||||||
"TestOps.test_softmax",
|
"TestOps.test_softmax",
|
||||||
"TestOps.test_sort",
|
"TestOps.test_sort",
|
||||||
"TestOps.test_tensordot",
|
"TestOps.test_tensordot",
|
||||||
"TestOps.test_tile",
|
"TestOps.test_tile",
|
||||||
"TestOps.test_view",
|
|
||||||
"TestQuantized.test_gather_matmul_grad",
|
"TestQuantized.test_gather_matmul_grad",
|
||||||
"TestQuantized.test_gather_qmm",
|
"TestQuantized.test_gather_qmm",
|
||||||
"TestQuantized.test_gather_qmm_sorted",
|
"TestQuantized.test_gather_qmm_sorted",
|
||||||
@ -136,7 +125,6 @@ cuda_skip = {
|
|||||||
"TestReduce.test_expand_sums",
|
"TestReduce.test_expand_sums",
|
||||||
"TestReduce.test_many_reduction_axes",
|
"TestReduce.test_many_reduction_axes",
|
||||||
"TestUpsample.test_torch_upsample",
|
"TestUpsample.test_torch_upsample",
|
||||||
"TestVmap.test_unary",
|
|
||||||
"TestVmap.test_vmap_conv",
|
"TestVmap.test_vmap_conv",
|
||||||
"TestVmap.test_vmap_inverse",
|
"TestVmap.test_vmap_inverse",
|
||||||
"TestVmap.test_vmap_svd",
|
"TestVmap.test_vmap_svd",
|
||||||
|
@ -1187,7 +1187,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||||
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1]))
|
||||||
check_slices(
|
check_slices(
|
||||||
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1])
|
np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 2, 1])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multiple slices
|
# Multiple slices
|
||||||
|
@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y.shape, (3,))
|
self.assertEqual(y.shape, (3,))
|
||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_relu_squared(self):
    # Negative and zero inputs map to 0; positive inputs are squared.
    inputs = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
    expected = mx.array([0.0, 0.0, 1.0, 4.0, 9.0])
    out = nn.relu_squared(inputs)
    self.assertTrue(mx.array_equal(out, expected))
    self.assertEqual(out.shape, (5,))
    self.assertEqual(out.dtype, mx.float32)
|
||||||
|
|
||||||
def test_leaky_relu(self):
|
def test_leaky_relu(self):
|
||||||
x = mx.array([1.0, -1.0, 0.0])
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
y = nn.leaky_relu(x)
|
y = nn.leaky_relu(x)
|
||||||
|
@ -2586,17 +2586,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqualArray(result, mx.array(expected))
|
self.assertEqualArray(result, mx.array(expected))
|
||||||
|
|
||||||
def test_atleast_1d(self):
|
def test_atleast_1d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2614,23 +2603,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_1d(mx.array(array))
|
mx_res = mx.atleast_1d(mx.array(array))
|
||||||
np_res = np.atleast_1d(np.array(array))
|
np_res = np.atleast_1d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_atleast_2d(self):
|
def test_atleast_2d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2648,23 +2625,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_2d(mx.array(array))
|
mx_res = mx.atleast_2d(mx.array(array))
|
||||||
np_res = np.atleast_2d(np.array(array))
|
np_res = np.atleast_2d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_atleast_3d(self):
|
def test_atleast_3d(self):
|
||||||
def compare_nested_lists(x, y):
|
|
||||||
if isinstance(x, list) and isinstance(y, list):
|
|
||||||
if len(x) != len(y):
|
|
||||||
return False
|
|
||||||
for i in range(len(x)):
|
|
||||||
if not compare_nested_lists(x[i], y[i]):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return x == y
|
|
||||||
|
|
||||||
# Test 1D input
|
# Test 1D input
|
||||||
arrays = [
|
arrays = [
|
||||||
[1],
|
[1],
|
||||||
@ -2682,10 +2647,9 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
for i, array in enumerate(arrays):
|
for i, array in enumerate(arrays):
|
||||||
mx_res = mx.atleast_3d(mx.array(array))
|
mx_res = mx.atleast_3d(mx.array(array))
|
||||||
np_res = np.atleast_3d(np.array(array))
|
np_res = np.atleast_3d(np.array(array))
|
||||||
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
|
|
||||||
self.assertEqual(mx_res.shape, np_res.shape)
|
self.assertEqual(mx_res.shape, np_res.shape)
|
||||||
self.assertEqual(mx_res.ndim, np_res.ndim)
|
self.assertEqual(mx_res.ndim, np_res.ndim)
|
||||||
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))
|
self.assertTrue(mx.array_equal(mx_res, atleast_arrays[i]))
|
||||||
|
|
||||||
def test_issubdtype(self):
|
def test_issubdtype(self):
|
||||||
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
self.assertTrue(mx.issubdtype(mx.bfloat16, mx.inexact))
|
||||||
|
Loading…
Reference in New Issue
Block a user