diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8a0e9dc2b..6bf6d6697 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -15,6 +15,7 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
 option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
+option(MLX_BUILD_CPU "Build cpu backend" ON)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -112,49 +113,53 @@ elseif (MLX_BUILD_METAL)
     ${QUARTZ_LIB})
 endif()
 
-find_library(ACCELERATE_LIBRARY Accelerate)
-if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
-  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
-  set(MLX_BUILD_ACCELERATE ON)
-  target_link_libraries(mlx ${ACCELERATE_LIBRARY})
-  add_compile_definitions(ACCELERATE_NEW_LAPACK)
+if (MLX_BUILD_CPU)
+  find_library(ACCELERATE_LIBRARY Accelerate)
+  if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+    message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
+    set(MLX_BUILD_ACCELERATE ON)
+    target_link_libraries(mlx ${ACCELERATE_LIBRARY})
+    add_compile_definitions(ACCELERATE_NEW_LAPACK)
+  else()
+    message(STATUS "Accelerate or arm neon not found, using default backend.")
+    set(MLX_BUILD_ACCELERATE OFF)
+    if(${CMAKE_HOST_APPLE})
+      # The blas shipped in macOS SDK is not supported, search homebrew for
+      # openblas instead.
+      set(BLA_VENDOR OpenBLAS)
+      set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
+    endif()
+    # Search and link with lapack.
+    find_package(LAPACK REQUIRED)
+    if (NOT LAPACK_FOUND)
+      message(FATAL_ERROR "Must have LAPACK installed")
+    endif()
+    find_path(LAPACK_INCLUDE_DIRS lapacke.h
+      /usr/include
+      /usr/local/include
+      /usr/local/opt/openblas/include)
+    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
+    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
+    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
+    target_link_libraries(mlx ${LAPACK_LIBRARIES})
+    # List blas after lapack otherwise we may accidentally include an old
+    # version of lapack.h from the include dirs of blas.
+    find_package(BLAS REQUIRED)
+    if (NOT BLAS_FOUND)
+      message(FATAL_ERROR "Must have BLAS installed")
+    endif()
+    # TODO find a cleaner way to do this
+    find_path(BLAS_INCLUDE_DIRS cblas.h
+      /usr/include
+      /usr/local/include
+      $ENV{BLAS_HOME}/include)
+    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
+    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
+    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
+    target_link_libraries(mlx ${BLAS_LIBRARIES})
+  endif()
 else()
-  message(STATUS "Accelerate or arm neon not found, using default backend.")
   set(MLX_BUILD_ACCELERATE OFF)
-  if(${CMAKE_HOST_APPLE})
-    # The blas shipped in macOS SDK is not supported, search homebrew for
-    # openblas instead.
-    set(BLA_VENDOR OpenBLAS)
-    set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
-  endif()
-  # Search and link with lapack.
-  find_package(LAPACK REQUIRED)
-  if (NOT LAPACK_FOUND)
-    message(FATAL_ERROR "Must have LAPACK installed")
-  endif()
-  find_path(LAPACK_INCLUDE_DIRS lapacke.h
-    /usr/include
-    /usr/local/include
-    /usr/local/opt/openblas/include)
-  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
-  message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
-  target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
-  target_link_libraries(mlx ${LAPACK_LIBRARIES})
-  # List blas after lapack otherwise we may accidentally incldue an old version
-  # of lapack.h from the include dirs of blas.
-  find_package(BLAS REQUIRED)
-  if (NOT BLAS_FOUND)
-    message(FATAL_ERROR "Must have BLAS installed")
-  endif()
-  # TODO find a cleaner way to do this
-  find_path(BLAS_INCLUDE_DIRS cblas.h
-    /usr/include
-    /usr/local/include
-    $ENV{BLAS_HOME}/include)
-  message(STATUS "Blas lib " ${BLAS_LIBRARIES})
-  message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
-  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
-  target_link_libraries(mlx ${BLAS_LIBRARIES})
 endif()
 
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
diff --git a/docs/src/install.rst b/docs/src/install.rst
index 252b234e6..213e04f64 100644
--- a/docs/src/install.rst
+++ b/docs/src/install.rst
@@ -153,6 +153,8 @@ should point to the path to the built metal library.
      - OFF
    * - MLX_BUILD_METAL
      - ON
+   * - MLX_BUILD_CPU
+     - ON
    * - MLX_BUILD_PYTHON_BINDINGS
      - OFF
    * - MLX_METAL_DEBUG
@@ -179,10 +181,28 @@ should point to the path to the built metal library.
 
       xcrun -sdk macosx --show-sdk-version
 
+Binary Size Minimization
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+To produce a smaller binary, use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
+and ``BUILD_SHARED_LIBS=ON``.
+
+The MLX CMake build has several additional options for making smaller binaries.
+For example, if you don't need the CPU backend or support for safetensors and
+GGUF, you can do:
+
+.. code-block:: shell
+
+   cmake .. \
+     -DCMAKE_BUILD_TYPE=MinSizeRel \
+     -DBUILD_SHARED_LIBS=ON \
+     -DMLX_BUILD_CPU=OFF \
+     -DMLX_BUILD_SAFETENSORS=OFF \
+     -DMLX_BUILD_GGUF=OFF
+
 Troubleshooting
 ^^^^^^^^^^^^^^^
-
 Metal not found
 ~~~~~~~~~~~~~~~
diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index d2f021af5..c53c3ec7d 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -19,11 +19,16 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
 )
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+if (MLX_BUILD_CPU)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+else()
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
+endif()
+
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
-else()
+elseif(MLX_BUILD_CPU)
   target_sources(
     mlx
     PRIVATE
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index ea0babf18..3e9f87dfa 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -37,6 +37,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp
new file mode 100644
index 000000000..e89ce7d6c
--- /dev/null
+++ b/mlx/backend/common/common.cpp
@@ -0,0 +1,347 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+void AsStrided::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+
+  auto& in = inputs[0];
+
+  if (!in.flags().row_contiguous) {
+    // Just ensuring that inputs[0] came from the ops which would ensure the
+    // input is row contiguous.
+    throw std::runtime_error(
+        "AsStrided must be used with row contiguous arrays only.");
+  }
+
+  // Compute the flags given the shape and strides
+  bool row_contiguous = true, col_contiguous = true;
+  size_t r = 1, c = 1;
+  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
+    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
+    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
+    r *= shape_[i];
+    c *= shape_[j];
+  }
+  auto flags = in.flags();
+  // TODO: Compute the contiguous flag in a better way cause now we are
+  // unnecessarily strict.
+  flags.contiguous = row_contiguous || col_contiguous;
+  flags.row_contiguous = row_contiguous;
+  flags.col_contiguous = col_contiguous;
+
+  // There is no easy way to compute the actual data size so we use out.size().
+  // The contiguous flag will almost certainly not be set so no code should
+  // rely on data_size anyway.
+  size_t data_size = out.size();
+
+  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
+}
+
+void Broadcast::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+  std::vector<size_t> strides(out.ndim(), 0);
+  int diff = out.ndim() - in.ndim();
+  for (int i = in.ndim() - 1; i >= 0; --i) {
+    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
+  }
+  auto flags = in.flags();
+  if (out.size() > in.size()) {
+    flags.row_contiguous = flags.col_contiguous = false;
+  }
+  out.copy_shared_buffer(in, strides, flags, in.data_size());
+}
+
+void Copy::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  out.copy_shared_buffer(inputs[0]);
+}
+
+void CustomVJP::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() > outputs.size());
+  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
+       i++, j++) {
+    outputs[i].copy_shared_buffer(inputs[j]);
+  }
+}
+
+void Depends::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() > outputs.size());
+  for (int i = 0; i < outputs.size(); i++) {
+    outputs[i].copy_shared_buffer(inputs[i]);
+  }
+}
+
+void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  double numel = 1;
+  for (auto ax : axes_) {
+    numel *= inputs[0].shape(ax);
+  }
+
+  if (inverted_) {
+    numel = 1.0 / numel;
+  }
+
+  switch (out.dtype()) {
+    case bool_:
+      *out.data<bool>() = static_cast<bool>(numel);
+      break;
+    case uint8:
+      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
+      break;
+    case uint16:
+      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
+      break;
+    case uint32:
+      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
+      break;
+    case uint64:
+      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
+      break;
+    case int8:
+      *out.data<int8_t>() = static_cast<int8_t>(numel);
+      break;
+    case int16:
+      *out.data<int16_t>() = static_cast<int16_t>(numel);
+      break;
+    case int32:
+      *out.data<int32_t>() = static_cast<int32_t>(numel);
+      break;
+    case int64:
+      *out.data<int64_t>() = static_cast<int64_t>(numel);
+      break;
+    case float16:
+      *out.data<float16_t>() = static_cast<float16_t>(numel);
+      break;
+    case float32:
+      *out.data<float>() = static_cast<float>(numel);
+      break;
+    case bfloat16:
+      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
+      break;
+    case complex64:
+      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
+      break;
+  }
+}
+
+std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
+    const array& in,
+    const array& out) {
+  // Special case for empty arrays or row contiguous arrays
+  if (in.size() == 0 || in.flags().row_contiguous) {
+    return {false, out.strides()};
+  }
+
+  // Special case for scalars
+  if (in.ndim() == 0) {
+    std::vector<size_t> out_strides(out.ndim(), 0);
+    return {false, out_strides};
+  }
+
+  // Firstly let's collapse all the contiguous dimensions of the input
+  auto [shape, _strides] = collapse_contiguous_dims(in);
+  auto& strides = _strides[0];
+
+  // If shapes fit exactly in the contiguous dims then no copy is necessary so
+  // let's check.
+  std::vector<size_t> out_strides;
+  bool copy_necessary = false;
+  int j = 0;
+  for (int i = 0; i < out.ndim(); i++) {
+    int N = out.shape(i);
+    if (j < shape.size() && shape[j] % N == 0) {
+      shape[j] /= N;
+      out_strides.push_back(shape[j] * strides[j]);
+      j += (shape[j] == 1);
+    } else if (N == 1) {
+      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
+      out_strides.push_back(out_strides.back());
+    } else {
+      copy_necessary = true;
+      break;
+    }
+  }
+
+  return {copy_necessary, out_strides};
+}
+
+void Reshape::shared_buffer_reshape(
+    const array& in,
+    const std::vector<size_t>& out_strides,
+    array& out) {
+  auto flags = in.flags();
+  if (flags.row_contiguous) {
+    // For row contiguous reshapes:
+    // - Shallow copy the buffer
+    // - If reshaping into a vector (all singleton dimensions except one) it
+    //   becomes col contiguous again.
+    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
+    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
+  }
+  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+}
+
+void Split::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+
+  auto& in = inputs[0];
+
+  auto compute_new_flags = [](const auto& shape,
+                              const auto& strides,
+                              size_t in_data_size,
+                              auto flags) {
+    size_t data_size = 1;
+    size_t f_stride = 1;
+    size_t b_stride = 1;
+    flags.row_contiguous = true;
+    flags.col_contiguous = true;
+    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
+      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+      f_stride *= shape[i];
+      b_stride *= shape[ri];
+      if (strides[i] > 0) {
+        data_size *= shape[i];
+      }
+    }
+
+    if (data_size == 1) {
+      // Broadcasted scalar array is contiguous.
+      flags.contiguous = true;
+    } else if (data_size == in_data_size) {
+      // Means we sliced a broadcasted dimension so leave the "no holes" flag
+      // alone.
+    } else {
+      // We sliced something. So either we are row or col contiguous or we
+      // punched a hole.
+      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
+    }
+
+    return std::pair{flags, data_size};
+  };
+
+  std::vector<int> indices(1, 0);
+  indices.insert(indices.end(), indices_.begin(), indices_.end());
+  for (int i = 0; i < indices.size(); i++) {
+    size_t offset = indices[i] * in.strides()[axis_];
+    auto [new_flags, data_size] = compute_new_flags(
+        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
+    outputs[i].copy_shared_buffer(
+        in, in.strides(), new_flags, data_size, offset);
+  }
+}
+
+std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
+    const array& in) {
+  int64_t data_offset = 0;
+  bool copy_needed = false;
+  std::vector<int64_t> inp_strides(in.ndim(), 0);
+  for (int i = 0; i < in.ndim(); ++i) {
+    data_offset += start_indices_[i] * in.strides()[i];
+    inp_strides[i] = in.strides()[i] * strides_[i];
+
+    copy_needed |= strides_[i] < 0;
+  }
+
+  return std::make_tuple(copy_needed, data_offset, inp_strides);
+}
+
+void Slice::shared_buffer_slice(
+    const array& in,
+    const std::vector<size_t>& out_strides,
+    size_t data_offset,
+    array& out) {
+  // Compute row/col contiguity
+  auto [data_size, is_row_contiguous, is_col_contiguous] =
+      check_contiguity(out.shape(), out_strides);
+
+  auto flags = in.flags();
+  flags.row_contiguous = is_row_contiguous;
+  flags.col_contiguous = is_col_contiguous;
+
+  if (data_size == 1) {
+    // Broadcasted scalar array is contiguous.
+    flags.contiguous = true;
+  } else if (data_size == in.data_size()) {
+    // Means we sliced a broadcasted dimension so leave the "no holes" flag
+    // alone.
+  } else {
+    // We sliced something. So either we are row or col contiguous or we
+    // punched a hole.
+    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
+  }
+
+  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
+}
+
+std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
+    const array& in) {
+  int64_t data_offset = 0;
+  std::vector<int64_t> inp_strides(in.ndim(), 0);
+  for (int i = 0; i < in.ndim(); ++i) {
+    data_offset += start_indices_[i] * in.strides()[i];
+    inp_strides[i] = in.strides()[i] * strides_[i];
+  }
+
+  return std::make_tuple(data_offset, inp_strides);
+}
+
+void StopGradient::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  out.copy_shared_buffer(inputs[0]);
+}
+
+void Transpose::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  std::vector<size_t> out_strides(out.ndim());
+  auto& in = inputs[0];
+  for (int ax = 0; ax < axes_.size(); ++ax) {
+    out_strides[ax] = in.strides()[axes_[ax]];
+  }
+
+  // Conditions for {row/col}_contiguous
+  // - array must be contiguous (no gaps)
+  // - underlying buffer size should have the same size as the array
+  // - cumulative product of shapes is equal to the strides (we can ignore axes
+  //   with size == 1)
+  //   - in the forward direction (column contiguous)
+  //   - in the reverse direction (row contiguous)
+  // - vectors are both row and col contiguous (hence if both row/col are
+  //   true, they stay true)
+  auto flags = in.flags();
+  if (flags.contiguous && in.data_size() == in.size()) {
+    size_t f_stride = 1;
+    size_t b_stride = 1;
+    flags.col_contiguous = true;
+    flags.row_contiguous = true;
+    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
+      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
+      f_stride *= out.shape(i);
+      flags.row_contiguous &=
+          (out_strides[ri] == b_stride || out.shape(ri) == 1);
+      b_stride *= out.shape(ri);
+    }
+  }
+  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp
index 7c442342e..2dfc78d21 100644
--- a/mlx/backend/common/inverse.cpp
+++ b/mlx/backend/common/inverse.cpp
@@ -2,7 +2,6 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/linalg.h"
 #include "mlx/primitives.h"
 
 #ifdef ACCELERATE_NEW_LAPACK
@@ -93,12 +92,4 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
   inverse_impl(inputs[0], output);
 }
 
-std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
-    const std::vector<array>& inputs,
-    const std::vector<int>& axes) {
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::inv(a, stream())}, {ax}};
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 442b09af0..1d1f66ce9 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <algorithm>
 #include <cassert>
@@ -113,61 +113,6 @@ void AsType::eval(const std::vector<array>& inputs, array& out) {
   copy(in, out, ctype);
 }
 
-void AsStrided::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-
-  auto& in = inputs[0];
-
-  if (!in.flags().row_contiguous) {
-    // Just ensuring that inputs[0] came from the ops which would ensure the
-    // input is row contiguous.
-    throw std::runtime_error(
-        "AsStrided must be used with row contiguous arrays only.");
-  }
-
-  // Compute the flags given the shape and strides
-  bool row_contiguous = true, col_contiguous = true;
-  size_t r = 1, c = 1;
-  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
-    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
-    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
-    r *= shape_[i];
-    c *= shape_[j];
-  }
-  auto flags = in.flags();
-  // TODO: Compute the contiguous flag in a better way cause now we are
-  // unnecessarily strict.
-  flags.contiguous = row_contiguous || col_contiguous;
-  flags.row_contiguous = row_contiguous;
-  flags.col_contiguous = col_contiguous;
-
-  // There is no easy way to compute the actual data size so we use out.size().
-  // The contiguous flag will almost certainly not be set so no code should
-  // rely on data_size anyway.
-  size_t data_size = out.size();
-
-  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
-}
-
-void Broadcast::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
-  if (out.size() == 0) {
-    out.set_data(nullptr);
-    return;
-  }
-  std::vector<size_t> strides(out.ndim(), 0);
-  int diff = out.ndim() - in.ndim();
-  for (int i = in.ndim() - 1; i >= 0; --i) {
-    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
-  }
-  auto flags = in.flags();
-  if (out.size() > in.size()) {
-    flags.row_contiguous = flags.col_contiguous = false;
-  }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
-}
-
 void Ceil::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -214,11 +159,6 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Copy::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
-}
-
 void Cos::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -243,81 +183,6 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void CustomVJP::eval(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() > outputs.size());
-  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
-       i++, j++) {
-    outputs[i].copy_shared_buffer(inputs[j]);
-  }
-}
-
-void Depends::eval(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() > outputs.size());
-  for (int i = 0; i < outputs.size(); i++) {
-    outputs[i].copy_shared_buffer(inputs[i]);
-  }
-}
-
-void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  double numel = 1;
-  for (auto ax : axes_) {
-    numel *= inputs[0].shape(ax);
-  }
-
-  if (inverted_) {
-    numel = 1.0 / numel;
-  }
-
-  switch (out.dtype()) {
-    case bool_:
-      *out.data<bool>() = static_cast<bool>(numel);
-      break;
-    case uint8:
-      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
-      break;
-    case uint16:
-      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
-      break;
-    case uint32:
-      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
-      break;
-    case uint64:
-      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
-      break;
-    case int8:
-      *out.data<int8_t>() = static_cast<int8_t>(numel);
-      break;
-    case int16:
-      *out.data<int16_t>() = static_cast<int16_t>(numel);
-      break;
-    case int32:
-      *out.data<int32_t>() = static_cast<int32_t>(numel);
-      break;
-    case int64:
-      *out.data<int64_t>() = static_cast<int64_t>(numel);
-      break;
-    case float16:
-      *out.data<float16_t>() = static_cast<float16_t>(numel);
-      break;
-    case float32:
-      *out.data<float>() = static_cast<float>(numel);
-      break;
-    case bfloat16:
-      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
-      break;
-    case complex64:
-      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
-      break;
-  }
-}
-
 void Erf::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -547,63 +412,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
-    const array& in,
-    const array& out) {
-  // Special case for empty arrays or row contiguous arrays
-  if (in.size() == 0 || in.flags().row_contiguous) {
-    return {false, out.strides()};
-  }
-
-  // Special case for scalars
-  if (in.ndim() == 0) {
-    std::vector<size_t> out_strides(out.ndim(), 0);
-    return {false, out_strides};
-  }
-
-  // Firstly let's collapse all the contiguous dimensions of the input
-  auto [shape, _strides] = collapse_contiguous_dims(in);
-  auto& strides = _strides[0];
-
-  // If shapes fit exactly in the contiguous dims then no copy is necessary so
-  // let's check.
-  std::vector<size_t> out_strides;
-  bool copy_necessary = false;
-  int j = 0;
-  for (int i = 0; i < out.ndim(); i++) {
-    int N = out.shape(i);
-    if (j < shape.size() && shape[j] % N == 0) {
-      shape[j] /= N;
-      out_strides.push_back(shape[j] * strides[j]);
-      j += (shape[j] == 1);
-    } else if (N == 1) {
-      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
-      out_strides.push_back(out_strides.back());
-    } else {
-      copy_necessary = true;
-      break;
-    }
-  }
-
-  return {copy_necessary, out_strides};
-}
-
-void Reshape::shared_buffer_reshape(
-    const array& in,
-    const std::vector<size_t>& out_strides,
-    array& out) {
-  auto flags = in.flags();
-  if (flags.row_contiguous) {
-    // For row contiguous reshapes:
-    // - Shallow copy the buffer
-    // - If reshaping into a vector (all singleton dimensions except one) it
-    //   becomes col contiguous again.
-    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
-    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
-  }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
-}
-
 void Reshape::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -674,49 +482,6 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
-    const array& in) {
-  int64_t data_offset = 0;
-  bool copy_needed = false;
-  std::vector<int64_t> inp_strides(in.ndim(), 0);
-  for (int i = 0; i < in.ndim(); ++i) {
-    data_offset += start_indices_[i] * in.strides()[i];
-    inp_strides[i] = in.strides()[i] * strides_[i];
-
-    copy_needed |= strides_[i] < 0;
-  }
-
-  return std::make_tuple(copy_needed, data_offset, inp_strides);
-}
-
-void Slice::shared_buffer_slice(
-    const array& in,
-    const std::vector<size_t>& out_strides,
-    size_t data_offset,
-    array& out) {
-  // Compute row/col contiguity
-  auto [data_size, is_row_contiguous, is_col_contiguous] =
-      check_contiguity(out.shape(), out_strides);
-
-  auto flags = in.flags();
-  flags.row_contiguous = is_row_contiguous;
-  flags.col_contiguous = is_col_contiguous;
-
-  if (data_size == 1) {
-    // Broadcasted scalar array is contiguous.
-    flags.contiguous = true;
-  } else if (data_size == in.data_size()) {
-    // Means we sliced a broadcasted dimension so leave the "no holes" flag
-    // alone.
-  } else {
-    // We sliced something. So either we are row or col contiguous or we
-    // punched a hole.
-    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
-  }
-
-  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
-}
-
 void Slice::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   if (out.size() == 0) {
@@ -748,18 +513,6 @@
   }
 }
 
-std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
-    const array& in) {
-  int64_t data_offset = 0;
-  std::vector<int64_t> inp_strides(in.ndim(), 0);
-  for (int i = 0; i < in.ndim(); ++i) {
-    data_offset += start_indices_[i] * in.strides()[i];
-    inp_strides[i] = in.strides()[i] * strides_[i];
-  }
-
-  return std::make_tuple(data_offset, inp_strides);
-}
-
 void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 2);
   if (out.size() == 0) {
@@ -797,58 +550,6 @@
       /* CopyType ctype = */ CopyType::GeneralGeneral);
 }
 
-void Split::eval(
-    const std::vector<array>& inputs,
-    std::vector<array>& outputs) {
-  assert(inputs.size() == 1);
-
-  auto& in = inputs[0];
-
-  auto compute_new_flags = [](const auto& shape,
-                              const auto& strides,
-                              size_t in_data_size,
-                              auto flags) {
-    size_t data_size = 1;
-    size_t f_stride = 1;
-    size_t b_stride = 1;
-    flags.row_contiguous = true;
-    flags.col_contiguous = true;
-    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
-      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
-      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
-      f_stride *= shape[i];
-      b_stride *= shape[ri];
-      if (strides[i] > 0) {
-        data_size *= shape[i];
-      }
-    }
-
-    if (data_size == 1) {
-      // Broadcasted scalar array is contiguous.
-      flags.contiguous = true;
-    } else if (data_size == in_data_size) {
-      // Means we sliced a broadcasted dimension so leave the "no holes" flag
-      // alone.
-    } else {
-      // We sliced something. So either we are row or col contiguous or we
-      // punched a hole.
-      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
-    }
-
-    return std::pair{flags, data_size};
-  };
-
-  std::vector<int> indices(1, 0);
-  indices.insert(indices.end(), indices_.begin(), indices_.end());
-  for (int i = 0; i < indices.size(); i++) {
-    size_t offset = indices[i] * in.strides()[axis_];
-    auto [new_flags, data_size] = compute_new_flags(
-        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
-    outputs[i].copy_shared_buffer(
-        in, in.strides(), new_flags, data_size, offset);
-  }
-}
-
 void Square::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
@@ -865,11 +566,6 @@ void Sqrt::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void StopGradient::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
-}
-
 void Tan::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -894,38 +590,4 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Transpose::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  std::vector<size_t> out_strides(out.ndim());
-  auto& in = inputs[0];
-  for (int ax = 0; ax < axes_.size(); ++ax) {
-    out_strides[ax] = in.strides()[axes_[ax]];
-  }
-
-  // Conditions for {row/col}_contiguous
-  // - array must be contiguous (no gaps)
-  // - underlying buffer size should have the same size as the array
-  // - cumulative product of shapes is equal to the strides (we can ignore axes
-  //   with size == 1)
-  //   - in the forward direction (column contiguous)
-  //   - in the reverse direction (row contiguous)
-  // - vectors are both row and col contiguous (hence if both row/col are
-  //   true, they stay true)
-  auto flags = in.flags();
-  if (flags.contiguous && in.data_size() == in.size()) {
-    size_t f_stride = 1;
-    size_t b_stride = 1;
-    flags.col_contiguous = true;
-    flags.row_contiguous = true;
-    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
-      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
-      f_stride *= out.shape(i);
-      flags.row_contiguous &=
-          (out_strides[ri] == b_stride || out.shape(ri) == 1);
-      b_stride *= out.shape(ri);
-    }
-  }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
index 0b56339aa..412f06297 100644
--- a/mlx/backend/common/svd.cpp
+++ b/mlx/backend/common/svd.cpp
@@ -3,7 +3,6 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/lapack_helper.h"
-#include "mlx/linalg.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -145,12 +144,4 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
   svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
 }
 
-std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
-    const std::vector<array>& inputs,
-    const std::vector<int>& axes) {
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
-}
-
 } // namespace mlx::core
diff --git a/mlx/backend/no_cpu/CMakeLists.txt b/mlx/backend/no_cpu/CMakeLists.txt
new file mode 100644
index 000000000..f3f6d4250
--- /dev/null
+++ b/mlx/backend/no_cpu/CMakeLists.txt
@@ -0,0 +1,9 @@
+target_sources(
+  mlx
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
+)
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
new file mode 100644
index 000000000..6c9430f18
--- /dev/null
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -0,0 +1,108 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/primitives.h"
+
+#define NO_CPU_MULTI(func)                                             \
+  void func::eval_cpu(                                                 \
+      const std::vector<array>& inputs, std::vector<array>& outputs) { \
+    throw std::runtime_error(#func " has no CPU implementation.");     \
+  }
+
+#define NO_CPU(func)                                                  \
+  void func::eval_cpu(const std::vector<array>& inputs, array& out) { \
+    throw std::runtime_error(#func " has no CPU implementation.");    \
+  }
+
+namespace mlx::core {
+
+NO_CPU(Abs)
+NO_CPU(Add)
+NO_CPU(AddMM)
+NO_CPU(Arange)
+NO_CPU(ArcCos)
+NO_CPU(ArcCosh)
+NO_CPU(ArcSin)
+NO_CPU(ArcSinh)
+NO_CPU(ArcTan)
+NO_CPU(ArcTan2)
+NO_CPU(ArcTanh)
+NO_CPU(ArgPartition)
+NO_CPU(ArgReduce)
+NO_CPU(ArgSort)
+NO_CPU(AsType)
+NO_CPU(AsStrided)
+NO_CPU(BitwiseBinary)
+NO_CPU(BlockMaskedMM)
+NO_CPU(BlockSparseMM)
+NO_CPU(Broadcast)
+NO_CPU(Ceil)
+NO_CPU(Concatenate)
+NO_CPU(Conjugate)
+NO_CPU(Convolution)
+NO_CPU(Copy)
+NO_CPU(Cos)
+NO_CPU(Cosh)
+NO_CPU_MULTI(CustomVJP)
+NO_CPU_MULTI(Depends)
+NO_CPU(Divide)
+NO_CPU_MULTI(DivMod)
+NO_CPU(NumberOfElements)
+NO_CPU(Remainder)
+NO_CPU(Equal)
+NO_CPU(Erf)
+NO_CPU(ErfInv)
+NO_CPU(Exp)
+NO_CPU(Expm1)
+NO_CPU(FFT)
+NO_CPU(Floor)
+NO_CPU(Full)
+NO_CPU(Gather)
+NO_CPU(Greater)
+NO_CPU(GreaterEqual)
+NO_CPU(Less)
+NO_CPU(LessEqual)
+NO_CPU(Load)
+NO_CPU(Log)
+NO_CPU(Log1p)
+NO_CPU(LogicalNot)
+NO_CPU(LogicalAnd)
+NO_CPU(LogicalOr)
+NO_CPU(LogAddExp)
+NO_CPU(Matmul)
+NO_CPU(Maximum)
+NO_CPU(Minimum)
+NO_CPU(Multiply)
+NO_CPU(Negative)
+NO_CPU(NotEqual)
+NO_CPU(Pad)
+NO_CPU(Partition)
+NO_CPU(Power)
+NO_CPU_MULTI(QRF)
+NO_CPU(QuantizedMatmul)
+NO_CPU(RandomBits)
+NO_CPU(Reduce)
+NO_CPU(Reshape)
+NO_CPU(Round)
+NO_CPU(Scan)
+NO_CPU(Scatter)
+NO_CPU(Select)
+NO_CPU(Sigmoid)
+NO_CPU(Sign)
+NO_CPU(Sin)
+NO_CPU(Sinh)
+NO_CPU(Slice)
+NO_CPU(SliceUpdate)
+NO_CPU(Softmax)
+NO_CPU(Sort)
+NO_CPU_MULTI(Split)
+NO_CPU(Square)
+NO_CPU(Sqrt)
+NO_CPU(StopGradient)
+NO_CPU(Subtract)
+NO_CPU_MULTI(SVD)
+NO_CPU(Tan)
+NO_CPU(Tanh)
+NO_CPU(Transpose)
+NO_CPU(Inverse)
+
+} // namespace mlx::core
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index b938c6afb..b5f144384 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -8,6 +8,7 @@
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/fft.h"
+#include "mlx/linalg.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -3578,4 +3579,20 @@ bool NumberOfElements::is_equivalent(const Primitive& other) const {
       dtype_ == n_other.dtype_;
 }
 
+std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::inv(a, stream())}, {ax}};
+}
+
 } // namespace mlx::core