mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
10 Commits
e89e8b4272
...
v0.29.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bce5f9b2d | ||
|
|
e9eab527eb | ||
|
|
36ca62dba8 | ||
|
|
9cbb1b0148 | ||
|
|
9bfc476d72 | ||
|
|
25e2356316 | ||
|
|
226a1d24e0 | ||
|
|
630350ad3e | ||
|
|
380aeb58ae | ||
|
|
f37389d100 |
38
README.md
38
README.md
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
|
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
|
||||||
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
|
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
|
||||||
[**Examples**](#examples)
|
[**Examples**](#examples)
|
||||||
|
|
||||||
[](https://circleci.com/gh/ml-explore/mlx)
|
[](https://circleci.com/gh/ml-explore/mlx)
|
||||||
|
|
||||||
@@ -11,37 +11,37 @@ brought to you by Apple machine learning research.
|
|||||||
|
|
||||||
Some key features of MLX include:
|
Some key features of MLX include:
|
||||||
|
|
||||||
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
|
||||||
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
|
||||||
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
[Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
|
||||||
the Python API. MLX has higher-level packages like `mlx.nn` and
|
the Python API. MLX has higher-level packages like `mlx.nn` and
|
||||||
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
`mlx.optimizers` with APIs that closely follow PyTorch to simplify building
|
||||||
more complex models.
|
more complex models.
|
||||||
|
|
||||||
- **Composable function transformations**: MLX supports composable function
|
- **Composable function transformations**: MLX supports composable function
|
||||||
transformations for automatic differentiation, automatic vectorization,
|
transformations for automatic differentiation, automatic vectorization,
|
||||||
and computation graph optimization.
|
and computation graph optimization.
|
||||||
|
|
||||||
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
||||||
materialized when needed.
|
materialized when needed.
|
||||||
|
|
||||||
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
||||||
dynamically. Changing the shapes of function arguments does not trigger
|
dynamically. Changing the shapes of function arguments does not trigger
|
||||||
slow compilations, and debugging is simple and intuitive.
|
slow compilations, and debugging is simple and intuitive.
|
||||||
|
|
||||||
- **Multi-device**: Operations can run on any of the supported devices
|
- **Multi-device**: Operations can run on any of the supported devices
|
||||||
(currently the CPU and the GPU).
|
(currently the CPU and the GPU).
|
||||||
|
|
||||||
- **Unified memory**: A notable difference from MLX and other frameworks
|
- **Unified memory**: A notable difference from MLX and other frameworks
|
||||||
is the *unified memory model*. Arrays in MLX live in shared memory.
|
is the *unified memory model*. Arrays in MLX live in shared memory.
|
||||||
Operations on MLX arrays can be performed on any of the supported
|
Operations on MLX arrays can be performed on any of the supported
|
||||||
device types without transferring data.
|
device types without transferring data.
|
||||||
|
|
||||||
MLX is designed by machine learning researchers for machine learning
|
MLX is designed by machine learning researchers for machine learning
|
||||||
researchers. The framework is intended to be user-friendly, but still efficient
|
researchers. The framework is intended to be user-friendly, but still efficient
|
||||||
to train and deploy models. The design of the framework itself is also
|
to train and deploy models. The design of the framework itself is also
|
||||||
conceptually simple. We intend to make it easy for researchers to extend and
|
conceptually simple. We intend to make it easy for researchers to extend and
|
||||||
improve MLX with the goal of quickly exploring new ideas.
|
improve MLX with the goal of quickly exploring new ideas.
|
||||||
|
|
||||||
The design of MLX is inspired by frameworks like
|
The design of MLX is inspired by frameworks like
|
||||||
[NumPy](https://numpy.org/doc/stable/index.html),
|
[NumPy](https://numpy.org/doc/stable/index.html),
|
||||||
@@ -91,7 +91,7 @@ Checkout the
|
|||||||
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
||||||
for more information on building the C++ and Python APIs from source.
|
for more information on building the C++ and Python APIs from source.
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
|
||||||
on contributing to MLX. See the
|
on contributing to MLX. See the
|
||||||
@@ -110,7 +110,7 @@ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
|||||||
MLX useful in your research and wish to cite it, please use the following
|
MLX useful in your research and wish to cite it, please use the following
|
||||||
BibTex entry:
|
BibTex entry:
|
||||||
|
|
||||||
```
|
```text
|
||||||
@software{mlx2023,
|
@software{mlx2023,
|
||||||
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
||||||
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
|||||||
const array& a,
|
const array& a,
|
||||||
const array& b) {
|
const array& b) {
|
||||||
if (a.ndim() == 2) {
|
if (a.ndim() == 2) {
|
||||||
return {{1}, {0}, {0}};
|
return {Shape{1}, Strides{0}, Strides{0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||||
@@ -38,7 +38,7 @@ inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
|||||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||||
collapse_batches(const array& a, const array& b, const array& c) {
|
collapse_batches(const array& a, const array& b, const array& c) {
|
||||||
if (a.ndim() == 2) {
|
if (a.ndim() == 2) {
|
||||||
return {{1}, {0}, {0}, {0}};
|
return {Shape{1}, Strides{0}, Strides{0}, Strides{0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||||
|
|||||||
@@ -131,10 +131,6 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
if (out.dtype() != float32) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
|
||||||
}
|
|
||||||
if (out.size() == 0) {
|
if (out.size() == 0) {
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -15,6 +15,18 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// NaN-aware comparator that places NaNs at the end
|
||||||
|
template <typename T>
|
||||||
|
bool nan_aware_less(T a, T b) {
|
||||||
|
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
|
||||||
|
if (std::isnan(a))
|
||||||
|
return false;
|
||||||
|
if (std::isnan(b))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return a < b;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct StridedIterator {
|
struct StridedIterator {
|
||||||
using iterator_category = std::random_access_iterator_tag;
|
using iterator_category = std::random_access_iterator_tag;
|
||||||
@@ -130,7 +142,7 @@ void sort(array& out, int axis) {
|
|||||||
StridedIterator st(data_ptr, axis_stride, 0);
|
StridedIterator st(data_ptr, axis_stride, 0);
|
||||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
std::stable_sort(st, ed);
|
std::stable_sort(st, ed, nan_aware_less<T>);
|
||||||
src_it.step();
|
src_it.step();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,6 +196,15 @@ void argsort(const array& in, array& out, int axis) {
|
|||||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||||
auto v1 = data_ptr[a * in_stride];
|
auto v1 = data_ptr[a * in_stride];
|
||||||
auto v2 = data_ptr[b * in_stride];
|
auto v2 = data_ptr[b * in_stride];
|
||||||
|
|
||||||
|
// Handle NaNs (place them at the end)
|
||||||
|
if (std::is_floating_point<T>::value) {
|
||||||
|
if (std::isnan(v1))
|
||||||
|
return false;
|
||||||
|
if (std::isnan(v2))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
return v1 < v2 || (v1 == v2 && a < b);
|
return v1 < v2 || (v1 == v2 && a < b);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -219,7 +240,7 @@ void partition(array& out, int axis, int kth) {
|
|||||||
StridedIterator md(data_ptr, axis_stride, kth);
|
StridedIterator md(data_ptr, axis_stride, kth);
|
||||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||||
|
|
||||||
std::nth_element(st, md, ed);
|
std::nth_element(st, md, ed, nan_aware_less<T>);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,6 +297,15 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
|||||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||||
auto v1 = data_ptr[a * in_stride];
|
auto v1 = data_ptr[a * in_stride];
|
||||||
auto v2 = data_ptr[b * in_stride];
|
auto v2 = data_ptr[b * in_stride];
|
||||||
|
|
||||||
|
// Handle NaNs (place them at the end)
|
||||||
|
if (std::is_floating_point<T>::value) {
|
||||||
|
if (std::isnan(v1))
|
||||||
|
return false;
|
||||||
|
if (std::isnan(v2))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
return v1 < v2 || (v1 == v2 && a < b);
|
return v1 < v2 || (v1 == v2 && a < b);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,7 +77,8 @@ struct Real {
|
|||||||
struct Sigmoid {
|
struct Sigmoid {
|
||||||
template <int N, typename T>
|
template <int N, typename T>
|
||||||
Simd<T, N> operator()(Simd<T, N> x) {
|
Simd<T, N> operator()(Simd<T, N> x) {
|
||||||
return 1.0f / (1.0f + simd::exp(-x));
|
auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
|
||||||
|
return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
|
||||||
}
|
}
|
||||||
SINGLE()
|
SINGLE()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -170,6 +170,10 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
|||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
# Supress warnings: note: parameter passing for argument of type
|
||||||
|
# ‘std::pair<float, float>’ when C++17 is enabled changed to match C++14 in GCC
|
||||||
|
# 10.1
|
||||||
|
target_compile_options(mlx PRIVATE -Wno-psabi)
|
||||||
|
|
||||||
# Install CCCL headers for JIT.
|
# Install CCCL headers for JIT.
|
||||||
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
|
||||||
|
|||||||
@@ -30,15 +30,20 @@ SmallSizePool::SmallSizePool() {
|
|||||||
next_free_ = buffer_;
|
next_free_ = buffer_;
|
||||||
|
|
||||||
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
|
||||||
|
|
||||||
|
int device_count = 0;
|
||||||
|
CHECK_CUDA_ERROR(cudaGetDeviceCount(&device_count));
|
||||||
|
for (int i = 0; i < device_count; ++i) {
|
||||||
#if CUDART_VERSION >= 13000
|
#if CUDART_VERSION >= 13000
|
||||||
cudaMemLocation loc;
|
cudaMemLocation loc;
|
||||||
loc.type = cudaMemLocationTypeDevice;
|
loc.type = cudaMemLocationTypeDevice;
|
||||||
loc.id = 0;
|
loc.id = i;
|
||||||
#else
|
#else
|
||||||
int loc = 0;
|
int loc = i;
|
||||||
#endif // CUDART_VERSION >= 13000
|
#endif // CUDART_VERSION >= 13000
|
||||||
CHECK_CUDA_ERROR(
|
CHECK_CUDA_ERROR(
|
||||||
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, loc));
|
cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
|
||||||
|
}
|
||||||
|
|
||||||
auto curr = next_free_;
|
auto curr = next_free_;
|
||||||
for (size_t i = 1; i < num_blocks; ++i) {
|
for (size_t i = 1; i < num_blocks; ++i) {
|
||||||
|
|||||||
@@ -382,20 +382,19 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (op_graph) {
|
if (op_graph) {
|
||||||
// Setup inputs and outputs.
|
|
||||||
register_args(encoder, backend_type, in, wt, out, out_);
|
|
||||||
|
|
||||||
// Find a plan for the graph and execute it.
|
// Find a plan for the graph and execute it.
|
||||||
auto plan = find_cudnn_plan_from_op_graph(
|
auto plan = find_cudnn_plan_from_op_graph(
|
||||||
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
encoder.device().cudnn_handle(), backend_type, dtype, *op_graph);
|
||||||
if (!plan) {
|
if (plan) {
|
||||||
throw std::runtime_error("[conv] Unable to find an execution plan.");
|
// Setup inputs and outputs.
|
||||||
}
|
register_args(encoder, backend_type, in, wt, out, out_);
|
||||||
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
|
||||||
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
auto [x, w, y] = dispatch_args(backend_type, in, wt, out);
|
||||||
conv_cache().emplace(
|
if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) {
|
||||||
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
conv_cache().emplace(
|
||||||
return;
|
cache_key, std::make_pair(backend_type, std::move(*plan)));
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -210,6 +210,9 @@ std::optional<cudnn_frontend::ExecutionPlan> find_cudnn_plan_from_op_graph(
|
|||||||
Dtype dtype,
|
Dtype dtype,
|
||||||
cudnn_frontend::OperationGraph& op_graph) {
|
cudnn_frontend::OperationGraph& op_graph) {
|
||||||
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
|
auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph);
|
||||||
|
if (engine_configs.empty()) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
|
return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -257,8 +257,8 @@ struct Round {
|
|||||||
struct Sigmoid {
|
struct Sigmoid {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
T y = 1 / (1 + exp(-abs(x)));
|
T y = 1 / (1 + exp(abs(x)));
|
||||||
return (x < 0) ? 1 - y : y;
|
return (x < 0) ? y : 1 - y;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,284 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/common/unary.h"
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
|
||||||
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
||||||
#include "mlx/dtype_utils.h"
|
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
#include <cooperative_groups.h>
|
|
||||||
#include <nvtx3/nvtx3.hpp>
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
namespace cu {
|
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
|
||||||
__global__ void unary_v(const In* in, Out* out, IdxT size) {
|
|
||||||
IdxT index = cg::this_grid().thread_rank();
|
|
||||||
|
|
||||||
if ((index + 1) * N_READS > size) {
|
|
||||||
for (IdxT i = index * N_READS; i < size; ++i) {
|
|
||||||
out[i] = Op{}(in[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto in_vec = load_vector<N_READS>(in, index);
|
|
||||||
|
|
||||||
AlignedVector<Out, N_READS> out_vec;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
|
||||||
out_vec[i] = Op{}(in_vec[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
store_vector<N_READS>(out, index, out_vec);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out, typename IdxT, int N_READS>
|
|
||||||
__global__ void unary_g(
|
|
||||||
const In* in,
|
|
||||||
Out* out,
|
|
||||||
IdxT size_rest,
|
|
||||||
const __grid_constant__ Shape shape,
|
|
||||||
const __grid_constant__ Strides strides,
|
|
||||||
int ndim) {
|
|
||||||
auto block = cg::this_thread_block();
|
|
||||||
auto grid = cg::this_grid();
|
|
||||||
IdxT index_rest =
|
|
||||||
grid.block_index().y * block.dim_threads().y + block.thread_index().y;
|
|
||||||
if (index_rest >= size_rest) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shape_x = shape[ndim - 1];
|
|
||||||
auto stride_x = strides[ndim - 1];
|
|
||||||
IdxT index_x =
|
|
||||||
grid.block_index().x * block.dim_threads().x + block.thread_index().x;
|
|
||||||
auto idx =
|
|
||||||
elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim);
|
|
||||||
auto in_vec =
|
|
||||||
load_vector<N_READS>(in + idx, index_x, shape_x, stride_x, In(0));
|
|
||||||
AlignedVector<Out, N_READS> out_vec;
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < N_READS; ++i) {
|
|
||||||
out_vec[i] = Op{}(in_vec[i]);
|
|
||||||
}
|
|
||||||
store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op, typename In, typename Out>
|
|
||||||
constexpr bool supports_unary_op() {
|
|
||||||
if (std::is_same_v<Op, Abs> || std::is_same_v<Op, Negative> ||
|
|
||||||
std::is_same_v<Op, Sign> || std::is_same_v<Op, Square>) {
|
|
||||||
return std::is_same_v<In, Out>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, ArcCosh> || std::is_same_v<Op, ArcSinh> ||
|
|
||||||
std::is_same_v<Op, ArcTanh> || std::is_same_v<Op, Erf> ||
|
|
||||||
std::is_same_v<Op, ErfInv> || std::is_same_v<Op, Expm1> ||
|
|
||||||
std::is_same_v<Op, Sigmoid>) {
|
|
||||||
return std::is_same_v<In, Out> && is_floating_v<In>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, BitwiseInvert>) {
|
|
||||||
return std::is_same_v<In, Out> && std::is_integral_v<In> &&
|
|
||||||
!std::is_same_v<In, bool>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, Ceil> || std::is_same_v<Op, Floor>) {
|
|
||||||
return std::is_same_v<In, Out> && !mlx::core::is_complex_v<In>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, Conjugate>) {
|
|
||||||
return std::is_same_v<In, Out> && mlx::core::is_complex_v<In>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, ArcCos> || std::is_same_v<Op, ArcSin> ||
|
|
||||||
std::is_same_v<Op, ArcTan> || std::is_same_v<Op, Cos> ||
|
|
||||||
std::is_same_v<Op, Cosh> || std::is_same_v<Op, Exp> ||
|
|
||||||
std::is_same_v<Op, Log> || std::is_same_v<Op, Log2> ||
|
|
||||||
std::is_same_v<Op, Log10> || std::is_same_v<Op, Log1p> ||
|
|
||||||
std::is_same_v<Op, Round> || std::is_same_v<Op, Rsqrt> ||
|
|
||||||
std::is_same_v<Op, Sqrt> || std::is_same_v<Op, Sin> ||
|
|
||||||
std::is_same_v<Op, Sinh> || std::is_same_v<Op, Tan> ||
|
|
||||||
std::is_same_v<Op, Tanh>) {
|
|
||||||
return std::is_same_v<In, Out> && is_inexact_v<In>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, Imag> || std::is_same_v<Op, Real>) {
|
|
||||||
return mlx::core::is_complex_v<In> && std::is_same_v<Out, float>;
|
|
||||||
}
|
|
||||||
if (std::is_same_v<Op, LogicalNot>) {
|
|
||||||
return std::is_same_v<In, Out> && std::is_same_v<In, bool>;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cu
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void unary_op_gpu_inplace(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
array& out,
|
|
||||||
const char* op,
|
|
||||||
const Stream& s) {
|
|
||||||
auto& in = inputs[0];
|
|
||||||
if (in.size() == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
bool contig = in.flags().contiguous;
|
|
||||||
bool large;
|
|
||||||
if (!contig) {
|
|
||||||
large = in.data_size() > INT32_MAX || out.size() > INT32_MAX;
|
|
||||||
} else {
|
|
||||||
large = in.data_size() > UINT32_MAX;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& encoder = cu::get_command_encoder(s);
|
|
||||||
encoder.set_input_array(in);
|
|
||||||
encoder.set_output_array(out);
|
|
||||||
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
|
|
||||||
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
|
|
||||||
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
|
|
||||||
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
|
|
||||||
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
|
|
||||||
dispatch_bool(large, [&](auto large) {
|
|
||||||
using InType = cuda_type_t<CTYPE_IN>;
|
|
||||||
using OutType = cuda_type_t<CTYPE_OUT>;
|
|
||||||
if (contig) {
|
|
||||||
using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
|
|
||||||
constexpr int N_READS = 16 / sizeof(OutType);
|
|
||||||
auto [num_blocks, block_dims] = get_launch_args(
|
|
||||||
out.data_size(), out.shape(), out.strides(), large, N_READS);
|
|
||||||
encoder.add_kernel_node(
|
|
||||||
cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
|
|
||||||
num_blocks,
|
|
||||||
block_dims,
|
|
||||||
0,
|
|
||||||
in.data<InType>(),
|
|
||||||
out.data<OutType>(),
|
|
||||||
out.data_size());
|
|
||||||
} else {
|
|
||||||
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
|
|
||||||
auto [shape, strides] = collapse_contiguous_dims(in);
|
|
||||||
auto ndim = shape.size();
|
|
||||||
int work_per_thread = 1;
|
|
||||||
auto kernel = cu::unary_g<Op, InType, OutType, IdxT, 1>;
|
|
||||||
auto dim0 = ndim > 0 ? shape.back() : 1;
|
|
||||||
auto rest = out.size() / dim0;
|
|
||||||
if (dim0 >= 4) {
|
|
||||||
kernel = cu::unary_g<Op, InType, OutType, IdxT, 4>;
|
|
||||||
work_per_thread = 4;
|
|
||||||
}
|
|
||||||
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
|
|
||||||
auto block_dims = get_block_dims(dim0, rest, 1);
|
|
||||||
uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x);
|
|
||||||
uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y);
|
|
||||||
encoder.add_kernel_node(
|
|
||||||
kernel,
|
|
||||||
{num_blocks_x, num_blocks_y},
|
|
||||||
block_dims,
|
|
||||||
0,
|
|
||||||
in.data<InType>(),
|
|
||||||
out.data<OutType>(),
|
|
||||||
rest,
|
|
||||||
const_param(shape),
|
|
||||||
const_param(strides),
|
|
||||||
ndim);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error(fmt::format(
|
|
||||||
"Can not do unary op {} on input of {} with output of {}.",
|
|
||||||
op,
|
|
||||||
dtype_to_string(in.dtype()),
|
|
||||||
dtype_to_string(out.dtype())));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void unary_op_gpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
array& out,
|
|
||||||
const char* op,
|
|
||||||
const Stream& s) {
|
|
||||||
set_unary_output_data(inputs[0], out);
|
|
||||||
unary_op_gpu_inplace<Op>(inputs, out, op, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define UNARY_GPU(func) \
|
|
||||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
|
||||||
nvtx3::scoped_range r(#func "::eval_gpu"); \
|
|
||||||
auto& s = out.primitive().stream(); \
|
|
||||||
unary_op_gpu<cu::func>(inputs, out, name(), s); \
|
|
||||||
}
|
|
||||||
|
|
||||||
UNARY_GPU(Abs)
|
|
||||||
UNARY_GPU(ArcCos)
|
|
||||||
UNARY_GPU(ArcCosh)
|
|
||||||
UNARY_GPU(ArcSin)
|
|
||||||
UNARY_GPU(ArcSinh)
|
|
||||||
UNARY_GPU(ArcTan)
|
|
||||||
UNARY_GPU(ArcTanh)
|
|
||||||
UNARY_GPU(BitwiseInvert)
|
|
||||||
UNARY_GPU(Ceil)
|
|
||||||
UNARY_GPU(Conjugate)
|
|
||||||
UNARY_GPU(Cos)
|
|
||||||
UNARY_GPU(Cosh)
|
|
||||||
UNARY_GPU(Erf)
|
|
||||||
UNARY_GPU(ErfInv)
|
|
||||||
UNARY_GPU(Exp)
|
|
||||||
UNARY_GPU(Expm1)
|
|
||||||
UNARY_GPU(Floor)
|
|
||||||
UNARY_GPU(Imag)
|
|
||||||
UNARY_GPU(Log1p)
|
|
||||||
UNARY_GPU(LogicalNot)
|
|
||||||
UNARY_GPU(Negative)
|
|
||||||
UNARY_GPU(Real)
|
|
||||||
UNARY_GPU(Sigmoid)
|
|
||||||
UNARY_GPU(Sign)
|
|
||||||
UNARY_GPU(Sin)
|
|
||||||
UNARY_GPU(Sinh)
|
|
||||||
UNARY_GPU(Square)
|
|
||||||
UNARY_GPU(Tan)
|
|
||||||
UNARY_GPU(Tanh)
|
|
||||||
|
|
||||||
void Log::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("Log::eval_gpu");
|
|
||||||
auto& s = out.primitive().stream();
|
|
||||||
switch (base_) {
|
|
||||||
case Base::e:
|
|
||||||
unary_op_gpu<cu::Log>(inputs, out, name(), s);
|
|
||||||
break;
|
|
||||||
case Base::two:
|
|
||||||
unary_op_gpu<cu::Log2>(inputs, out, name(), s);
|
|
||||||
break;
|
|
||||||
case Base::ten:
|
|
||||||
unary_op_gpu<cu::Log10>(inputs, out, name(), s);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("Round::eval_gpu");
|
|
||||||
assert(inputs.size() == 1);
|
|
||||||
const auto& in = inputs[0];
|
|
||||||
auto& s = out.primitive().stream();
|
|
||||||
if (issubdtype(in.dtype(), inexact)) {
|
|
||||||
unary_op_gpu<cu::Round>(inputs, out, name(), s);
|
|
||||||
} else {
|
|
||||||
// No-op integer types
|
|
||||||
out.copy_shared_buffer(in);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
||||||
nvtx3::scoped_range r("Sort::eval_gpu");
|
|
||||||
auto& s = out.primitive().stream();
|
|
||||||
if (recip_) {
|
|
||||||
unary_op_gpu<cu::Rsqrt>(inputs, out, "Rsqrt", s);
|
|
||||||
} else {
|
|
||||||
unary_op_gpu<cu::Sqrt>(inputs, out, "Sqrt", s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -19,11 +19,28 @@ METAL_FUNC void thread_swap(thread T& a, thread T& b) {
|
|||||||
b = w;
|
b = w;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename = void>
|
||||||
|
struct Init {
|
||||||
|
static constexpr constant T v = Limits<T>::max;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct Init<T, metal::enable_if_t<metal::is_floating_point_v<T>>> {
|
||||||
|
static constexpr constant T v = metal::numeric_limits<T>::quiet_NaN();
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct LessThan {
|
struct LessThan {
|
||||||
static constexpr constant T init = Limits<T>::max;
|
static constexpr constant T init = Init<T>::v;
|
||||||
|
METAL_FUNC bool operator()(T a, T b) const {
|
||||||
METAL_FUNC bool operator()(T a, T b) {
|
if constexpr (
|
||||||
|
metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
|
||||||
|
bool an = isnan(a);
|
||||||
|
bool bn = isnan(b);
|
||||||
|
if (an | bn) {
|
||||||
|
return (!an) & bn;
|
||||||
|
}
|
||||||
|
}
|
||||||
return a < b;
|
return a < b;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -309,8 +309,8 @@ struct Round {
|
|||||||
struct Sigmoid {
|
struct Sigmoid {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T operator()(T x) {
|
T operator()(T x) {
|
||||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
auto y = 1 / (1 + metal::exp(metal::abs(x)));
|
||||||
return (x < 0) ? 1 - y : y;
|
return (x < 0) ? y : 1 - y;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,9 @@
|
|||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
|
// Can be tuned with MLX_NCCL_TIMEOUT
|
||||||
|
constexpr int nccl_timeout = 300000; // miliseconds
|
||||||
|
|
||||||
#define CHECK_CUDA(cmd) \
|
#define CHECK_CUDA(cmd) \
|
||||||
do { \
|
do { \
|
||||||
cudaError_t e = cmd; \
|
cudaError_t e = cmd; \
|
||||||
@@ -181,8 +184,9 @@ inline void bootstrap_unique_id(
|
|||||||
close(sock);
|
close(sock);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Here just wanted to make show that rank 0 has enough time to bind
|
// Here we want to make sure that rank 0 has enough time to bind
|
||||||
// so we will retry to connect until max attempts
|
// so we will retry to connect until elapsed time exceeds nccl_timeout
|
||||||
|
// this is particularity important for multinode setup
|
||||||
|
|
||||||
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
int sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||||
if (sock < 0) {
|
if (sock < 0) {
|
||||||
@@ -200,32 +204,41 @@ inline void bootstrap_unique_id(
|
|||||||
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
memcpy(&serv.sin_addr, he->h_addr_list[0], he->h_length);
|
||||||
serv.sin_port = htons(port);
|
serv.sin_port = htons(port);
|
||||||
|
|
||||||
const int max_retries = 30;
|
const int timeout_ms = env::nccl_timeout(nccl_timeout);
|
||||||
int attempt = 0;
|
|
||||||
bool connected = false;
|
bool connected = false;
|
||||||
|
|
||||||
bool do_log = std::getenv("NCCL_DEBUG") == "INFO";
|
const char* dbg = std::getenv("NCCL_DEBUG");
|
||||||
for (attempt = 0; attempt < max_retries; ++attempt) {
|
bool do_log = (dbg && std::string(dbg) == "INFO");
|
||||||
|
|
||||||
|
auto start = std::chrono::steady_clock::now();
|
||||||
|
int attempt = 0;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||||
|
std::chrono::steady_clock::now() - start)
|
||||||
|
.count();
|
||||||
|
if (elapsed_ms > timeout_ms)
|
||||||
|
break;
|
||||||
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
if (connect(sock, reinterpret_cast<sockaddr*>(&serv), sizeof(serv)) ==
|
||||||
0) {
|
0) {
|
||||||
connected = true;
|
connected = true;
|
||||||
if (do_log) {
|
if (do_log) {
|
||||||
std::cout << "[Rank " << rank
|
std::cout << "[Rank " << rank << "] Connected successfully after "
|
||||||
<< "] Connected successfully on attempt " << attempt + 1
|
<< elapsed_ms << " miliseconds" << std::endl;
|
||||||
<< std::endl;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (errno != ECONNREFUSED) {
|
if (errno != ECONNREFUSED) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
++attempt;
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!connected) {
|
if (!connected) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Rank " << rank << "] connect() failed after " << attempt
|
msg << "[Rank " << rank << "] connect() failed after " << timeout_ms
|
||||||
<< " retries: " << strerror(errno);
|
<< " milliseconds and " << attempt << " retries: " << strerror(errno);
|
||||||
close(sock);
|
close(sock);
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
@@ -256,7 +269,6 @@ class NCCLGroup : public GroupImpl {
|
|||||||
|
|
||||||
~NCCLGroup() {
|
~NCCLGroup() {
|
||||||
ncclCommDestroy(comm_);
|
ncclCommDestroy(comm_);
|
||||||
ncclGroupEnd();
|
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -165,6 +165,11 @@ inline bool enable_tf32() {
|
|||||||
return enable_tf32_;
|
return enable_tf32_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline int nccl_timeout(int default_value) {
|
||||||
|
static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value);
|
||||||
|
return nccl_timeout;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#define MLX_VERSION_MAJOR 0
|
#define MLX_VERSION_MAJOR 0
|
||||||
#define MLX_VERSION_MINOR 29
|
#define MLX_VERSION_MINOR 29
|
||||||
#define MLX_VERSION_PATCH 2
|
#define MLX_VERSION_PATCH 3
|
||||||
#define MLX_VERSION_NUMERIC \
|
#define MLX_VERSION_NUMERIC \
|
||||||
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
(100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
|
||||||
|
|
||||||
|
|||||||
@@ -712,6 +712,15 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
expected = beta * c + alpha * (a @ b)
|
expected = beta * c + alpha * (a @ b)
|
||||||
self.assertTrue(mx.allclose(expected, out))
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
# Test half precision
|
||||||
|
for t, tol in [(mx.float16, 1e-3), (mx.bfloat16, 1e-2)]:
|
||||||
|
c = mx.ones((32, 32)).astype(t)
|
||||||
|
a = mx.random.uniform(shape=(32, 32)).astype(t)
|
||||||
|
b = mx.random.uniform(shape=(32, 32)).astype(t)
|
||||||
|
out = mx.addmm(c, a, b)
|
||||||
|
expected = a @ b + c
|
||||||
|
self.assertTrue(mx.allclose(out, expected, rtol=tol, atol=tol))
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
def make_ref_addmm(alpha, beta):
|
def make_ref_addmm(alpha, beta):
|
||||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||||
|
|||||||
@@ -1041,6 +1041,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
expected = 1 / (1 + np.exp(-a, dtype=np.float32))
|
expected = 1 / (1 + np.exp(-a, dtype=np.float32))
|
||||||
self.assertTrue(np.allclose(result, expected))
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Low precision
|
||||||
|
a = mx.array(-8.0).astype(mx.float16)
|
||||||
|
self.assertNotEqual(mx.sigmoid(a).item(), 0.0)
|
||||||
|
a = mx.array(8.0).astype(mx.float16)
|
||||||
|
self.assertNotEqual(mx.sigmoid(a).item(), 1.0)
|
||||||
|
|
||||||
def test_allclose(self):
|
def test_allclose(self):
|
||||||
a = mx.array(1.0)
|
a = mx.array(1.0)
|
||||||
b = mx.array(1.0)
|
b = mx.array(1.0)
|
||||||
@@ -3094,8 +3100,6 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
out = mx.depends(b, c)
|
out = mx.depends(b, c)
|
||||||
self.assertTrue(mx.array_equal(out, b))
|
self.assertTrue(mx.array_equal(out, b))
|
||||||
|
|
||||||
|
|
||||||
class TestBroadcast(mlx_tests.MLXTestCase):
|
|
||||||
def test_broadcast_shapes(self):
|
def test_broadcast_shapes(self):
|
||||||
# Basic broadcasting
|
# Basic broadcasting
|
||||||
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
|
self.assertEqual(mx.broadcast_shapes((1, 2, 3), (3,)), (1, 2, 3))
|
||||||
@@ -3134,6 +3138,12 @@ class TestBroadcast(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
mx.broadcast_shapes()
|
mx.broadcast_shapes()
|
||||||
|
|
||||||
|
def test_sort_nan(self):
|
||||||
|
x = mx.array([3.0, mx.nan, 2.0, 0.0])
|
||||||
|
expected = mx.array([0.0, 2.0, 3.0, mx.nan])
|
||||||
|
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
|
||||||
|
x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user