mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
4 Commits
v0.28.0
...
fce53b61d6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a |
@@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
reduce_dispatch_sum_prod<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
|
||||
@@ -10,7 +10,34 @@ namespace mlx::core::cu {
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
__global__ void set_mm_device_pointers(
|
||||
template <int NDIM>
|
||||
__global__ void set_mm_device_pointers_nd(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data());
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_mm_device_pointers_g(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
@@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers(
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_addmm_device_pointers(
|
||||
template <int NDIM>
|
||||
__global__ void set_addmm_device_pointers_nd(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
int8_t* c_start,
|
||||
int8_t* out_start,
|
||||
int item_size,
|
||||
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
|
||||
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
|
||||
int64_t batch_stride,
|
||||
int batch_count) {
|
||||
auto index = cg::this_grid().thread_rank();
|
||||
if (index >= batch_count) {
|
||||
return;
|
||||
}
|
||||
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
|
||||
index,
|
||||
batch_shape.data(),
|
||||
a_batch_strides.data(),
|
||||
b_batch_strides.data(),
|
||||
c_batch_strides.data());
|
||||
pointers[index] = a_start + item_size * a_offset;
|
||||
pointers[index + batch_count] = b_start + item_size * b_offset;
|
||||
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
|
||||
pointers[index + 3 * batch_count] =
|
||||
out_start + item_size * index * batch_stride;
|
||||
}
|
||||
|
||||
__global__ void set_addmm_device_pointers_g(
|
||||
int8_t** pointers,
|
||||
int8_t* a_start,
|
||||
int8_t* b_start,
|
||||
@@ -89,37 +147,62 @@ void Matmul::run_batched(
|
||||
const mlx::core::Shape& batch_shape,
|
||||
const mlx::core::Strides& a_batch_strides,
|
||||
const mlx::core::Strides& b_batch_strides) {
|
||||
auto batch_count = out.size() / (M_ * N_);
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(out_desc_, batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 3),
|
||||
{static_cast<int>(batch_count * 3)},
|
||||
allocator::malloc(batch_count * sizeof(void*) * 3),
|
||||
{batch_count * 3},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
static_cast<int64_t>(M_) * N_,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
int block_dims = std::min(batch_count, 256);
|
||||
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||
int64_t batch_stride = M_ * N_;
|
||||
int item_size = out.itemsize();
|
||||
|
||||
int ndim = batch_shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers_nd<ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
const_param<ndim_constant()>(b_batch_strides),
|
||||
batch_stride,
|
||||
batch_count);
|
||||
});
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers_g,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
batch_stride,
|
||||
ndim,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
@@ -150,7 +233,7 @@ void Matmul::run_batched(
|
||||
const mlx::core::Strides& c_batch_strides,
|
||||
float alpha,
|
||||
float beta) {
|
||||
auto batch_count = out.size() / (M_ * N_);
|
||||
int batch_count = out.size() / (M_ * N_);
|
||||
set_pointer_mode(a_desc_, batch_count);
|
||||
set_pointer_mode(b_desc_, batch_count);
|
||||
set_pointer_mode(c_desc_, batch_count);
|
||||
@@ -159,30 +242,58 @@ void Matmul::run_batched(
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 4),
|
||||
{static_cast<int>(batch_count * 4)},
|
||||
{batch_count * 4},
|
||||
uint64);
|
||||
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
const_param(c_batch_strides),
|
||||
static_cast<int64_t>(M_) * N_,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
|
||||
int block_dims = std::min(batch_count, 256);
|
||||
int num_blocks = cuda::ceil_div(batch_count, block_dims);
|
||||
int64_t batch_stride = M_ * N_;
|
||||
int item_size = out.itemsize();
|
||||
|
||||
int ndim = batch_shape.size();
|
||||
if (ndim <= 3) {
|
||||
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers_nd<ndim_constant()>,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param<ndim_constant()>(batch_shape),
|
||||
const_param<ndim_constant()>(a_batch_strides),
|
||||
const_param<ndim_constant()>(b_batch_strides),
|
||||
const_param<ndim_constant()>(c_batch_strides),
|
||||
batch_stride,
|
||||
batch_count);
|
||||
});
|
||||
} else {
|
||||
encoder.add_kernel_node(
|
||||
cu::set_addmm_device_pointers_g,
|
||||
num_blocks,
|
||||
block_dims,
|
||||
0,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
c.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
item_size,
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
const_param(c_batch_strides),
|
||||
batch_stride,
|
||||
ndim,
|
||||
batch_count);
|
||||
}
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
|
||||
@@ -134,6 +134,10 @@ instantiate_and_or(and, And)
|
||||
instantiate_and_or(or, Or)
|
||||
|
||||
#define instantiate_sum_prod(name, op) \
|
||||
instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
|
||||
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
|
||||
|
||||
@@ -247,15 +247,25 @@ std::pair<Dtype, Dtype> remap_reduce_types(
|
||||
const std::string& op_name) {
|
||||
if (op_name == "sum" || op_name == "prod") {
|
||||
if (issubdtype(in.dtype(), integer)) {
|
||||
switch (in.dtype().size()) {
|
||||
case 1:
|
||||
switch (in.dtype()) {
|
||||
case uint8:
|
||||
return {uint8, uint32};
|
||||
case uint16:
|
||||
return {uint16, uint32};
|
||||
case uint32:
|
||||
return {uint32, uint32};
|
||||
case uint64:
|
||||
return {uint64, uint64};
|
||||
case int8:
|
||||
return {int8, int32};
|
||||
case 2:
|
||||
case int16:
|
||||
return {int16, int32};
|
||||
case 4:
|
||||
case int32:
|
||||
return {int32, int32};
|
||||
case 8:
|
||||
case int64:
|
||||
return {int64, int64};
|
||||
default:
|
||||
throw std::runtime_error("Unsupported integer type");
|
||||
}
|
||||
}
|
||||
if (in.dtype() == bool_) {
|
||||
|
||||
31
mlx/ops.cpp
31
mlx/ops.cpp
@@ -2381,9 +2381,20 @@ array logsumexp(
|
||||
throw std::invalid_argument(
|
||||
"[logsumexp] Received non-empty axes for array with 0 dimensions.");
|
||||
}
|
||||
bool reduce_last_dim =
|
||||
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
|
||||
if (reduce_last_dim) {
|
||||
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
|
||||
// is [1, 1, ..., N].
|
||||
for (int i = axes.size() - 2; i >= 0; --i) {
|
||||
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
|
||||
reduce_last_dim = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||
if (!is_complex && axes.size() == 1 &&
|
||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
if (!is_complex && reduce_last_dim) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto out_shape = a.shape();
|
||||
out_shape.back() = 1;
|
||||
@@ -3403,10 +3414,20 @@ array softmax(
|
||||
throw std::invalid_argument(
|
||||
"[softmax] Received non-empty axes for array with 0 dimensions.");
|
||||
}
|
||||
|
||||
bool reduce_last_dim =
|
||||
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
|
||||
if (reduce_last_dim) {
|
||||
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
|
||||
// is [1, 1, ..., N].
|
||||
for (int i = axes.size() - 2; i >= 0; --i) {
|
||||
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
|
||||
reduce_last_dim = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
bool is_complex = issubdtype(a.dtype(), complexfloating);
|
||||
if (!is_complex && axes.size() == 1 &&
|
||||
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
|
||||
if (!is_complex && reduce_last_dim) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
return array(
|
||||
a.shape(),
|
||||
|
||||
@@ -2,6 +2,6 @@
|
||||
requires = [
|
||||
"setuptools>=80",
|
||||
"nanobind==2.4.0",
|
||||
"cmake>=3.25",
|
||||
"cmake>=3.25,<4.1",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") {
|
||||
CHECK_EQ(prod(a, Device::gpu).item<int32_t>(), 1);
|
||||
}
|
||||
|
||||
// sum and prod overflow
|
||||
{
|
||||
auto a = full({256, 2, 2}, 1u, uint8);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 256 * 4);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||
|
||||
a = full({65535, 2, 2}, 1u, uint16);
|
||||
CHECK_EQ(sum(a, Device::gpu).item<uint32_t>(), 65535 * 4);
|
||||
CHECK_EQ(prod(a, Device::gpu).item<uint32_t>(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test gpu reduce with axes") {
|
||||
// reducing only some axes and irregular layouts
|
||||
{
|
||||
array a(1.0f);
|
||||
|
||||
@@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") {
|
||||
CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
|
||||
}
|
||||
|
||||
// Test unsigned sum
|
||||
{
|
||||
const int num_elems = 1000;
|
||||
|
||||
auto x = astype(full({num_elems}, 255), uint8);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
|
||||
|
||||
x = astype(full({num_elems}, 65535), uint16);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
|
||||
|
||||
x = full({3, 3, 3}, 10000, uint32);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
|
||||
|
||||
x = full({3, 3, 3}, 10000, uint64);
|
||||
CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
|
||||
}
|
||||
|
||||
// Test prod
|
||||
{
|
||||
auto x = array({});
|
||||
@@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") {
|
||||
CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
|
||||
}
|
||||
|
||||
// Test unsigned prod
|
||||
{
|
||||
auto x = array({255, 255}, {2}, uint8);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
|
||||
|
||||
x = array({65535, 2}, {2}, uint16);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
|
||||
|
||||
x = array({100000, 2}, {2}, uint32);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
|
||||
|
||||
x = array({100000, 2}, {2}, uint64);
|
||||
CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
|
||||
}
|
||||
|
||||
// Test all
|
||||
{
|
||||
auto x = array({});
|
||||
|
||||
Reference in New Issue
Block a user