mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 03:06:39 +08:00
Move some dims utils to common (#2223)
This commit is contained in:
parent
54a71f270a
commit
f76ee1ffd2
@ -1,9 +1,16 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Returns the text that `primitive->print` writes to an output stream.
std::string get_primitive_string(Primitive* primitive) {
  // Capture the printed form in an in-memory stream and hand back its
  // contents as a string.
  std::ostringstream printed;
  primitive->print(printed);
  return printed.str();
}
|
||||||
|
|
||||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const std::vector<Strides>& strides,
|
const std::vector<Strides>& strides,
|
||||||
@ -101,4 +108,105 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
|||||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compute power-of-two thread block dimensions that fit the given input
// dimensions.
//
// Each returned dimension is the largest power of two that does not exceed the
// corresponding input dimension, subject to a shared budget: the exponents of
// the three dimensions sum to at most `pow2`, so the block holds at most
// 2^pow2 threads. The budget is handed out round-robin (dim0, dim1, dim2, dim0,
// ...) one doubling at a time, so the dimensions grow evenly until one of
// them cannot grow any more.
//
// A non-positive `pow2` yields a 1x1x1 block.
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
  const int dims[3] = {dim0, dim1, dim2};
  int pows[3] = {0, 0, 0};
  int sum = 0;

  // Keep making passes over the three dimensions until either the budget is
  // spent or a full pass fails to grow anything.
  bool grew = true;
  while (grew && sum < pow2) {
    grew = false;
    // The budget is re-checked before every dimension (not only at the end of
    // a pass) so the total can never step over `pow2`.
    for (int axis = 0; axis < 3 && sum < pow2; ++axis) {
      if (dims[axis] >= (1 << (pows[axis] + 1))) {
        ++pows[axis];
        ++sum;
        grew = true;
      }
    }
  }

  return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
}
|
||||||
|
|
||||||
|
// Factor the axes of `shape` into a two-extent grid, returned with the larger
// extent first and a trailing 1.
//
// Throws std::runtime_error if either extent ends up above UINT32_MAX.
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
  size_t extent_a = 1;
  size_t extent_b = 1;
  for (int axis = 0; axis < shape.size(); ++axis) {
    // Axes with a stride of 0 are broadcasted and take no part in the grid.
    if (strides[axis] != 0) {
      // Keep filling the first extent while the product stays below
      // UINT32_MAX; any axis that would not fit spills into the second one.
      const bool fits_first = extent_a * shape[axis] < UINT32_MAX;
      (fits_first ? extent_a : extent_b) *= shape[axis];
    }
  }

  if (extent_a > UINT32_MAX || extent_b > UINT32_MAX) {
    throw std::runtime_error("Unable to safely factor shape.");
  }

  // Report the larger extent first.
  if (extent_b > extent_a) {
    std::swap(extent_a, extent_b);
  }
  return std::make_tuple(
      static_cast<uint32_t>(extent_a), static_cast<uint32_t>(extent_b), 1);
}
|
||||||
|
|
||||||
|
// Factor the axes of `shape` into a two-extent grid whose total size is the
// product of the non-broadcast axes divided by `divisor`. The result is
// returned with the larger extent first and a trailing 1.
//
// Axes with a stride of 0 are skipped. An axis whose extent divides the
// remaining divisor is absorbed into the divisor instead of the grid; any
// divisor left over is divided out of whichever grid extent it divides
// evenly.
//
// Throws std::runtime_error if either extent ends up above UINT32_MAX or if
// the divisor could not be fully divided out.
Dims get_2d_grid_dims_common(
    const Shape& shape,
    const Strides& strides,
    size_t divisor) {
  size_t grid_x = 1;
  size_t grid_y = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (strides[i] == 0) {
      continue;
    }

    // No need to add this axis, we can just remove it from the divisor. A
    // zero-sized axis must not reach the modulo / division (that would divide
    // by zero); it is multiplied into the grid below instead.
    if (shape[i] != 0 && divisor % shape[i] == 0) {
      divisor /= shape[i];
      continue;
    }

    // Fill grid_x while the product stays below UINT32_MAX, spill the rest
    // into grid_y.
    if (grid_x * shape[i] < UINT32_MAX) {
      grid_x *= shape[i];
    } else {
      grid_y *= shape[i];
    }

    // Try to get rid of whatever is left of the divisor as soon as one of the
    // extents becomes a multiple of it.
    if (divisor > 1) {
      if (grid_x % divisor == 0) {
        grid_x /= divisor;
        divisor = 1;
      } else if (grid_y % divisor == 0) {
        grid_y /= divisor;
        divisor = 1;
      }
    }
  }

  if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
    throw std::runtime_error("Unable to safely factor shape.");
  }

  // Report the larger extent first.
  if (grid_y > grid_x) {
    std::swap(grid_x, grid_y);
  }
  return std::make_tuple(
      static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -2,12 +2,15 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
std::string get_primitive_string(Primitive* primitive);
|
||||||
|
|
||||||
inline int64_t
|
inline int64_t
|
||||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||||
int64_t loc = 0;
|
int64_t loc = 0;
|
||||||
@ -70,6 +73,28 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
|||||||
const array& a,
|
const array& a,
|
||||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||||
|
|
||||||
|
// Compute the thread block dimensions which fit the given
|
||||||
|
// input dimensions.
|
||||||
|
// - The thread block dimensions will be powers of two
|
||||||
|
// - The thread block size will be at most 2^pow2
|
||||||
|
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||||
|
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||||
|
|
||||||
|
// Computes a 2D grid where each element is < UINT_MAX
|
||||||
|
// Assumes:
|
||||||
|
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||||
|
// - shape and strides correspond to a contiguous (no holes) but
|
||||||
|
// possibly broadcasted array
|
||||||
|
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
||||||
|
|
||||||
|
// Same as above but we do an implicit division with divisor.
|
||||||
|
// Basically, equivalent to factorizing
|
||||||
|
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||||
|
Dims get_2d_grid_dims_common(
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides,
|
||||||
|
size_t divisor);
|
||||||
|
|
||||||
struct ContiguousIterator {
|
struct ContiguousIterator {
|
||||||
inline void step() {
|
inline void step() {
|
||||||
int dims = shape_.size();
|
int dims = shape_.size();
|
||||||
|
@ -11,6 +11,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
26
mlx/backend/cuda/kernel_utils.cu
Normal file
26
mlx/backend/cuda/kernel_utils.cu
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// CUDA flavour of get_block_dims_common: forwards the arguments and repackages
// the resulting tuple as a dim3.
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2) {
  const auto [block_x, block_y, block_z] =
      get_block_dims_common(dim0, dim1, dim2, pow2);
  return dim3(block_x, block_y, block_z);
}
|
||||||
|
|
||||||
|
// CUDA flavour of get_2d_grid_dims_common: forwards the arguments and
// repackages the resulting tuple as a dim3.
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides) {
  const auto [grid_x, grid_y, grid_z] = get_2d_grid_dims_common(shape, strides);
  return dim3(grid_x, grid_y, grid_z);
}
|
||||||
|
|
||||||
|
// CUDA flavour of the divisor-taking get_2d_grid_dims_common overload:
// forwards the arguments and repackages the resulting tuple as a dim3.
dim3 get_2d_grid_dims(
    const Shape& shape,
    const Strides& strides,
    size_t divisor) {
  const auto [grid_x, grid_y, grid_z] =
      get_2d_grid_dims_common(shape, strides, divisor);
  return dim3(grid_x, grid_y, grid_z);
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -1,7 +1,13 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
// This file includes host-only utilities for writing CUDA kernels, the difference
|
||||||
|
// from backend/cuda/kernels/utils.cuh is that the latter file only includes
|
||||||
|
// device-only code.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
#include <cuComplex.h>
|
#include <cuComplex.h>
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
@ -32,4 +38,12 @@ struct CTypeToCudaType<complex64_t> {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
using cuda_type_t = typename CTypeToCudaType<T>::type;
|
||||||
|
|
||||||
|
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
|
||||||
|
dim3 get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||||
|
dim3 get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||||
|
dim3 get_2d_grid_dims(
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides,
|
||||||
|
size_t divisor);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
@ -1,7 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/dtype_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/arange.cuh"
|
#include "mlx/backend/cuda/kernels/arange.cuh"
|
||||||
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
// This file includes utilities that are used by C++ code (i.e. .cpp files).
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
using namespace mlx;
|
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -59,109 +58,20 @@ std::string type_to_name(const array& a) {
|
|||||||
return type_to_name(a.dtype());
|
return type_to_name(a.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2) {
|
||||||
int pows[3] = {0, 0, 0};
|
Dims dims = get_block_dims_common(dim0, dim1, dim2, pow2);
|
||||||
int sum = 0;
|
return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||||
while (true) {
|
|
||||||
int presum = sum;
|
|
||||||
// Check all the pows
|
|
||||||
if (dim0 >= (1 << (pows[0] + 1))) {
|
|
||||||
pows[0]++;
|
|
||||||
sum++;
|
|
||||||
}
|
|
||||||
if (sum == 10) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (dim1 >= (1 << (pows[1] + 1))) {
|
|
||||||
pows[1]++;
|
|
||||||
sum++;
|
|
||||||
}
|
|
||||||
if (sum == 10) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (dim2 >= (1 << (pows[2] + 1))) {
|
|
||||||
pows[2]++;
|
|
||||||
sum++;
|
|
||||||
}
|
|
||||||
if (sum == presum || sum == pow2) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
|
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides) {
|
||||||
// Dims with strides of 0 are ignored as they
|
Dims dims = get_2d_grid_dims_common(shape, strides);
|
||||||
// correspond to broadcasted dimensions
|
return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||||
size_t grid_x = 1;
|
|
||||||
size_t grid_y = 1;
|
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
|
||||||
if (strides[i] == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (grid_x * shape[i] < UINT32_MAX) {
|
|
||||||
grid_x *= shape[i];
|
|
||||||
} else {
|
|
||||||
grid_y *= shape[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
|
||||||
}
|
|
||||||
if (grid_y > grid_x) {
|
|
||||||
std::swap(grid_x, grid_y);
|
|
||||||
}
|
|
||||||
return MTL::Size(
|
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Size
|
MTL::Size
|
||||||
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {
|
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor) {
|
||||||
// Compute the 2d grid dimensions such that the total size of the grid is
|
Dims dims = get_2d_grid_dims_common(shape, strides, divisor);
|
||||||
// divided by divisor.
|
return MTL::Size(std::get<0>(dims), std::get<1>(dims), std::get<2>(dims));
|
||||||
size_t grid_x = 1;
|
|
||||||
size_t grid_y = 1;
|
|
||||||
for (int i = 0; i < shape.size(); ++i) {
|
|
||||||
if (strides[i] == 0) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No need to add this shape we can just remove it from the divisor.
|
|
||||||
if (divisor % shape[i] == 0) {
|
|
||||||
divisor /= shape[i];
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (grid_x * shape[i] < UINT32_MAX) {
|
|
||||||
grid_x *= shape[i];
|
|
||||||
} else {
|
|
||||||
grid_y *= shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (divisor > 1) {
|
|
||||||
if (grid_x % divisor == 0) {
|
|
||||||
grid_x /= divisor;
|
|
||||||
divisor = 1;
|
|
||||||
} else if (grid_y % divisor == 0) {
|
|
||||||
grid_y /= divisor;
|
|
||||||
divisor = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
|
|
||||||
throw std::runtime_error("Unable to safely factor shape.");
|
|
||||||
}
|
|
||||||
if (grid_y > grid_x) {
|
|
||||||
std::swap(grid_x, grid_y);
|
|
||||||
}
|
|
||||||
return MTL::Size(
|
|
||||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive) {
|
|
||||||
std::ostringstream op_t;
|
|
||||||
primitive->print(op_t);
|
|
||||||
return op_t.str();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -13,22 +13,9 @@ namespace mlx::core {
|
|||||||
std::string type_to_name(const Dtype& t);
|
std::string type_to_name(const Dtype& t);
|
||||||
std::string type_to_name(const array& a);
|
std::string type_to_name(const array& a);
|
||||||
|
|
||||||
// Compute the thread block dimensions which fit the given
|
// Compute the grid and block dimensions, check backend/common/utils.h for docs.
|
||||||
// input dimensions.
|
|
||||||
// - The thread block dimensions will be powers of two
|
|
||||||
// - The thread block size will be less than 2^pow2
|
|
||||||
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||||
|
|
||||||
// Computes a 2D grid where each element is < UINT_MAX
|
|
||||||
// Assumes:
|
|
||||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
|
||||||
// - shape and strides correspond to a contiguous (no holes) but
|
|
||||||
// possibly broadcasted array
|
|
||||||
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);
|
||||||
|
|
||||||
// Same as above but we do an implicit division with divisor.
|
|
||||||
// Basically, equivalent to factorizing
|
|
||||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
|
||||||
MTL::Size
|
MTL::Size
|
||||||
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);
|
get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);
|
||||||
|
|
||||||
@ -58,8 +45,6 @@ inline void debug_set_primitive_buffer_label(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_primitive_string(Primitive* primitive);
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
|
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
|
||||||
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
|
!std::is_same_v<T, char> && !std::is_same_v<T, signed char> &&
|
||||||
|
Loading…
Reference in New Issue
Block a user