mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-25 09:51:17 +08:00

No CPU option for binary minimization (#1105)

* no cpu build option
* docs
* fix

This commit is contained in:
parent e7f9710499
commit 7178ac0111
CMakeLists.txt

```diff
@@ -15,6 +15,7 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
 option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
 option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
 option(MLX_BUILD_METAL "Build metal backend" ON)
+option(MLX_BUILD_CPU "Build cpu backend" ON)
 option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
 option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
 option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
```
CMakeLists.txt — the Accelerate/LAPACK/BLAS configuration is now gated on the new option (@@ -112,49 +113,53 @@, in the branch after `elseif (MLX_BUILD_METAL) ... ${QUARTZ_LIB}) endif()`). The pre-existing logic is unchanged apart from being nested one level deeper; when the CPU backend is disabled, the block reduces to turning Accelerate off. The resulting code:

```cmake
if (MLX_BUILD_CPU)
  find_library(ACCELERATE_LIBRARY Accelerate)
  if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
    message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
    set(MLX_BUILD_ACCELERATE ON)
    target_link_libraries(mlx ${ACCELERATE_LIBRARY})
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
  else()
    message(STATUS "Accelerate or arm neon not found, using default backend.")
    set(MLX_BUILD_ACCELERATE OFF)
    if(${CMAKE_HOST_APPLE})
      # The blas shipped in macOS SDK is not supported, search homebrew for
      # openblas instead.
      set(BLA_VENDOR OpenBLAS)
      set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
    endif()
    # Search and link with lapack.
    find_package(LAPACK REQUIRED)
    if (NOT LAPACK_FOUND)
      message(FATAL_ERROR "Must have LAPACK installed")
    endif()
    find_path(LAPACK_INCLUDE_DIRS lapacke.h
      /usr/include
      /usr/local/include
      /usr/local/opt/openblas/include)
    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
    target_link_libraries(mlx ${LAPACK_LIBRARIES})
    # List blas after lapack otherwise we may accidentally include an old
    # version of lapack.h from the include dirs of blas.
    find_package(BLAS REQUIRED)
    if (NOT BLAS_FOUND)
      message(FATAL_ERROR "Must have BLAS installed")
    endif()
    # TODO find a cleaner way to do this
    find_path(BLAS_INCLUDE_DIRS cblas.h
      /usr/include
      /usr/local/include
      $ENV{BLAS_HOME}/include)
    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
    target_link_libraries(mlx ${BLAS_LIBRARIES})
  endif()
else()
  set(MLX_BUILD_ACCELERATE OFF)
endif()

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
```
Build documentation — the build-options table gains a row for the new flag:

```diff
@@ -153,6 +153,8 @@ should point to the path to the built metal library.
      - OFF
    * - MLX_BUILD_METAL
      - ON
+   * - MLX_BUILD_CPU
+     - ON
    * - MLX_BUILD_PYTHON_BINDINGS
      - OFF
    * - MLX_METAL_DEBUG
```
Build documentation — a new "Binary Size Minimization" section is added (@@ -179,10 +181,28 @@) after the SDK-version note (`xcrun -sdk macosx --show-sdk-version`) and before the existing "Troubleshooting" / "Metal not found" headings:

Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary, use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:

```shell
cmake .. \
  -DCMAKE_BUILD_TYPE=MinSizeRel \
  -DBUILD_SHARED_LIBS=ON \
  -DMLX_BUILD_CPU=OFF \
  -DMLX_BUILD_SAFETENSORS=OFF \
  -DMLX_BUILD_GGUF=OFF
```
mlx/CMakeLists.txt — the backend subdirectory is now chosen by the new option, and the Accelerate fallback sources are only added when the CPU backend is on:

```diff
@@ -19,11 +19,16 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
 )
 
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+if (MLX_BUILD_CPU)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
+else()
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
+endif()
 
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
-else()
+elseif(MLX_BUILD_CPU)
   target_sources(
     mlx
     PRIVATE
```
mlx/backend/common/CMakeLists.txt — the new common.cpp is added to the source list:

```diff
@@ -37,6 +37,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
```
mlx/backend/common/common.cpp — new file (+347 lines). It collects the metadata-only evaluations (shared-buffer views, flag bookkeeping) that both the CPU and no-CPU builds need:

```cpp
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void AsStrided::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  if (!in.flags().row_contiguous) {
    // Just ensuring that inputs[0] came from the ops which would ensure the
    // input is row contiguous.
    throw std::runtime_error(
        "AsStrided must be used with row contiguous arrays only.");
  }

  // Compute the flags given the shape and strides
  bool row_contiguous = true, col_contiguous = true;
  size_t r = 1, c = 1;
  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
    r *= shape_[i];
    c *= shape_[j];
  }
  auto flags = in.flags();
  // TODO: Compute the contiguous flag in a better way cause now we are
  // unnecessarily strict.
  flags.contiguous = row_contiguous || col_contiguous;
  flags.row_contiguous = row_contiguous;
  flags.col_contiguous = col_contiguous;

  // There is no easy way to compute the actual data size so we use out.size().
  // The contiguous flag will almost certainly not be set so no code should
  // rely on data_size anyway.
  size_t data_size = out.size();

  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}

void Broadcast::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  const auto& in = inputs[0];
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  std::vector<size_t> strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);
}

void CustomVJP::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
       i++, j++) {
    outputs[i].copy_shared_buffer(inputs[j]);
  }
}

void Depends::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0; i < outputs.size(); i++) {
    outputs[i].copy_shared_buffer(inputs[i]);
  }
}

void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  double numel = 1;
  for (auto ax : axes_) {
    numel *= inputs[0].shape(ax);
  }

  if (inverted_) {
    numel = 1.0 / numel;
  }

  switch (out.dtype()) {
    case bool_:
      *out.data<bool>() = static_cast<bool>(numel);
      break;
    case uint8:
      *out.data<uint8_t>() = static_cast<uint8_t>(numel);
      break;
    case uint16:
      *out.data<uint16_t>() = static_cast<uint16_t>(numel);
      break;
    case uint32:
      *out.data<uint32_t>() = static_cast<uint32_t>(numel);
      break;
    case uint64:
      *out.data<uint64_t>() = static_cast<uint64_t>(numel);
      break;
    case int8:
      *out.data<int8_t>() = static_cast<int8_t>(numel);
      break;
    case int16:
      *out.data<int16_t>() = static_cast<int16_t>(numel);
      break;
    case int32:
      *out.data<int32_t>() = static_cast<int32_t>(numel);
      break;
    case int64:
      *out.data<int64_t>() = static_cast<int64_t>(numel);
      break;
    case float16:
      *out.data<float16_t>() = static_cast<float16_t>(numel);
      break;
    case float32:
      *out.data<float>() = static_cast<float>(numel);
      break;
    case bfloat16:
      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
      break;
    case complex64:
      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
      break;
  }
}

std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
    const array& in,
    const array& out) {
  // Special case for empty arrays or row contiguous arrays
  if (in.size() == 0 || in.flags().row_contiguous) {
    return {false, out.strides()};
  }

  // Special case for scalars
  if (in.ndim() == 0) {
    std::vector<size_t> out_strides(out.ndim(), 0);
    return {false, out_strides};
  }

  // Firstly let's collapse all the contiguous dimensions of the input
  auto [shape, _strides] = collapse_contiguous_dims(in);
  auto& strides = _strides[0];

  // If shapes fit exactly in the contiguous dims then no copy is necessary so
  // let's check.
  std::vector<size_t> out_strides;
  bool copy_necessary = false;
  int j = 0;
  for (int i = 0; i < out.ndim(); i++) {
    int N = out.shape(i);
    if (j < shape.size() && shape[j] % N == 0) {
      shape[j] /= N;
      out_strides.push_back(shape[j] * strides[j]);
      j += (shape[j] == 1);
    } else if (N == 1) {
      // i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
      out_strides.push_back(out_strides.back());
    } else {
      copy_necessary = true;
      break;
    }
  }

  return {copy_necessary, out_strides};
}

void Reshape::shared_buffer_reshape(
    const array& in,
    const std::vector<size_t>& out_strides,
    array& out) {
  auto flags = in.flags();
  if (flags.row_contiguous) {
    // For row contiguous reshapes:
    // - Shallow copy the buffer
    // - If reshaping into a vector (all singleton dimensions except one) it
    //   becomes col contiguous again.
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
  }
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

void Split::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);

  auto& in = inputs[0];

  auto compute_new_flags = [](const auto& shape,
                              const auto& strides,
                              size_t in_data_size,
                              auto flags) {
    size_t data_size = 1;
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.row_contiguous = true;
    flags.col_contiguous = true;
    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
      f_stride *= shape[i];
      b_stride *= shape[ri];
      if (strides[i] > 0) {
        data_size *= shape[i];
      }
    }

    if (data_size == 1) {
      // Broadcasted scalar array is contiguous.
      flags.contiguous = true;
    } else if (data_size == in_data_size) {
      // Means we sliced a broadcasted dimension so leave the "no holes" flag
      // alone.
    } else {
      // We sliced something. So either we are row or col contiguous or we
      // punched a hole.
      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
    }

    return std::pair<decltype(flags), size_t>{flags, data_size};
  };

  std::vector<int> indices(1, 0);
  indices.insert(indices.end(), indices_.begin(), indices_.end());
  for (int i = 0; i < indices.size(); i++) {
    size_t offset = indices[i] * in.strides()[axis_];
    auto [new_flags, data_size] = compute_new_flags(
        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
    outputs[i].copy_shared_buffer(
        in, in.strides(), new_flags, data_size, offset);
  }
}

std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];

    copy_needed |= strides_[i] < 0;
  }

  return std::make_tuple(copy_needed, data_offset, inp_strides);
}

void Slice::shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    size_t data_offset,
    array& out) {
  // Compute row/col contiguity
  auto [data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(out.shape(), out_strides);

  auto flags = in.flags();
  flags.row_contiguous = is_row_contiguous;
  flags.col_contiguous = is_col_contiguous;

  if (data_size == 1) {
    // Broadcasted scalar array is contiguous.
    flags.contiguous = true;
  } else if (data_size == in.data_size()) {
    // Means we sliced a broadcasted dimension so leave the "no holes" flag
    // alone.
  } else {
    // We sliced something. So either we are row or col contiguous or we
    // punched a hole.
    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
  }

  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}

std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
    const array& in) {
  int64_t data_offset = 0;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices_[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides_[i];
  }

  return std::make_tuple(data_offset, inp_strides);
}

void StopGradient::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  out.copy_shared_buffer(inputs[0]);
}

void Transpose::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  std::vector<size_t> out_strides(out.ndim());
  auto& in = inputs[0];
  for (int ax = 0; ax < axes_.size(); ++ax) {
    out_strides[ax] = in.strides()[axes_[ax]];
  }

  // Conditions for {row/col}_contiguous
  // - array must be contiguous (no gaps)
  // - underlying buffer size should have the same size as the array
  // - cumulative product of shapes is equal to the strides (we can ignore axes
  //   with size == 1)
  //   - in the forward direction (column contiguous)
  //   - in the reverse direction (row contiguous)
  // - vectors are both row and col contiguous (hence if both row/col are
  //   true, they stay true)
  auto flags = in.flags();
  if (flags.contiguous && in.data_size() == in.size()) {
    size_t f_stride = 1;
    size_t b_stride = 1;
    flags.col_contiguous = true;
    flags.row_contiguous = true;
    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
      f_stride *= out.shape(i);
      flags.row_contiguous &=
          (out_strides[ri] == b_stride || out.shape(ri) == 1);
      b_stride *= out.shape(ri);
    }
  }
  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}

} // namespace mlx::core
```
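Broadcast::eval above never touches the data: it builds an output stride vector in which every broadcasted input dimension (size 1, or missing on the left) gets stride 0, then shares the input buffer. A minimal standalone sketch of that stride rule (the `broadcast_strides` helper and the example shapes are illustrative, not part of MLX):

```cpp
#include <cstdio>
#include <vector>

// Mirror of the stride loop in Broadcast::eval: dims of size 1, and dims
// missing on the left, are given stride 0 so the same element is reused.
std::vector<size_t> broadcast_strides(
    const std::vector<int>& in_shape,
    const std::vector<size_t>& in_strides,
    int out_ndim) {
  std::vector<size_t> strides(out_ndim, 0);
  int diff = out_ndim - static_cast<int>(in_shape.size());
  for (int i = static_cast<int>(in_shape.size()) - 1; i >= 0; --i) {
    strides[i + diff] = (in_shape[i] == 1) ? 0 : in_strides[i];
  }
  return strides;
}

int main() {
  // A row-contiguous (3, 1) array has strides (1, 1); broadcast it to (2, 3, 4).
  auto s = broadcast_strides({3, 1}, {1, 1}, 3);
  for (auto v : s) {
    std::printf("%zu ", v); // prints: 0 1 0
  }
  return 0;
}
```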
mlx/backend/common/inverse.cpp — Inverse::vmap moves out of the backend (it reappears in mlx/primitives.cpp below), so the linalg include is no longer needed:

```diff
@@ -2,7 +2,6 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/linalg.h"
 #include "mlx/primitives.h"
 
 #ifdef ACCELERATE_NEW_LAPACK
@@ -93,12 +92,4 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
   inverse_impl(inputs[0], output);
 }
 
-std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
-    const std::vector<array>& inputs,
-    const std::vector<int>& axes) {
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::inv(a, stream())}, {ax}};
-}
-
 } // namespace mlx::core
```
mlx/backend/common/primitives.cpp — the copyright year is updated:

```diff
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <algorithm>
 #include <cassert>
```

The remaining hunks in this file (@@ -113,61 +113,6 @@, @@ -214,11 +159,6 @@, @@ -243,81 +183,6 @@, @@ -547,63 +412,6 @@, @@ -674,49 +482,6 @@, @@ -748,18 +513,6 @@, @@ -797,58 +550,6 @@, @@ -865,11 +566,6 @@, @@ -894,38 +590,4 @@) delete the implementations of AsStrided::eval, Broadcast::eval, Copy::eval, CustomVJP::eval, Depends::eval, NumberOfElements::eval, Reshape::prepare_reshape, Reshape::shared_buffer_reshape, Slice::prepare_slice, Slice::shared_buffer_slice, SliceUpdate::prepare_slice, Split::eval, StopGradient::eval, and Transpose::eval. The deleted bodies are identical to the versions added in common.cpp above; the surrounding primitives (AsType, Ceil, Conjugate, Cos, Cosh, Erf, RandomBits, Reshape::eval, Sinh, Slice::eval, SliceUpdate::eval, Square, Sqrt, Tan, Tanh) are unchanged.
mlx/backend/common/svd.cpp — likewise, SVD::vmap moves out of the backend:

```diff
@@ -3,7 +3,6 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
 #include "mlx/backend/common/lapack_helper.h"
-#include "mlx/linalg.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
@@ -145,12 +144,4 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
   svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
 }
 
-std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
-    const std::vector<array>& inputs,
-    const std::vector<int>& axes) {
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
-}
-
 } // namespace mlx::core
```
mlx/backend/no_cpu/CMakeLists.txt — new file (+9 lines):

```cmake
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/../common/load.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/../common/common.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/../common/compiled_nocpu.cpp
)
```
mlx/backend/no_cpu/primitives.cpp — new file (+108 lines):

```cpp
// Copyright © 2024 Apple Inc.

#include "mlx/primitives.h"

#define NO_CPU_MULTI(func)                                             \
  void func::eval_cpu(                                                 \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    throw std::runtime_error(#func " has no CPU implementation.");     \
  }

#define NO_CPU(func)                                                  \
  void func::eval_cpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no CPU implementation.");    \
  }

namespace mlx::core {

NO_CPU(Abs)
NO_CPU(Add)
NO_CPU(AddMM)
NO_CPU(Arange)
NO_CPU(ArcCos)
NO_CPU(ArcCosh)
NO_CPU(ArcSin)
NO_CPU(ArcSinh)
NO_CPU(ArcTan)
NO_CPU(ArcTan2)
NO_CPU(ArcTanh)
NO_CPU(ArgPartition)
NO_CPU(ArgReduce)
NO_CPU(ArgSort)
NO_CPU(AsType)
NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BlockMaskedMM)
NO_CPU(BlockSparseMM)
NO_CPU(Broadcast)
NO_CPU(Ceil)
NO_CPU(Concatenate)
NO_CPU(Conjugate)
NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)
NO_CPU(Cosh)
NO_CPU_MULTI(CustomVJP)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
NO_CPU(Exp)
NO_CPU(Expm1)
NO_CPU(FFT)
NO_CPU(Floor)
NO_CPU(Full)
NO_CPU(Gather)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Less)
NO_CPU(LessEqual)
NO_CPU(Load)
NO_CPU(Log)
NO_CPU(Log1p)
NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU(Matmul)
NO_CPU(Maximum)
NO_CPU(Minimum)
NO_CPU(Multiply)
NO_CPU(Negative)
NO_CPU(NotEqual)
NO_CPU(Pad)
NO_CPU(Partition)
NO_CPU(Power)
NO_CPU_MULTI(QRF)
NO_CPU(QuantizedMatmul)
NO_CPU(RandomBits)
NO_CPU(Reduce)
NO_CPU(Reshape)
NO_CPU(Round)
NO_CPU(Scan)
NO_CPU(Scatter)
NO_CPU(Select)
NO_CPU(Sigmoid)
NO_CPU(Sign)
NO_CPU(Sin)
NO_CPU(Sinh)
NO_CPU(Slice)
NO_CPU(SliceUpdate)
NO_CPU(Softmax)
NO_CPU(Sort)
NO_CPU_MULTI(Split)
NO_CPU(Square)
NO_CPU(Sqrt)
NO_CPU(StopGradient)
NO_CPU(Subtract)
NO_CPU_MULTI(SVD)
NO_CPU(Tan)
NO_CPU(Tanh)
NO_CPU(Transpose)
NO_CPU(Inverse)

} // namespace mlx::core
```
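The two macros stamp out `eval_cpu` overrides that satisfy the linker but fail loudly at run time. For instance, `NO_CPU(Abs)` expands (after string-literal concatenation of `#func`) to:

```cpp
// Expansion of NO_CPU(Abs): a stub that links cleanly in a no-CPU build
// but throws if a CPU evaluation of Abs is ever requested.
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  throw std::runtime_error("Abs has no CPU implementation.");
}
```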
mlx/primitives.cpp — the vmap rules move here from the backend files, next to the rest of the primitive transformations, so they stay available even when no CPU backend is compiled in:

```diff
@@ -8,6 +8,7 @@
 
 #include "mlx/backend/common/utils.h"
 #include "mlx/fft.h"
+#include "mlx/linalg.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -3578,4 +3579,20 @@ bool NumberOfElements::is_equivalent(const Primitive& other) const {
       dtype_ == n_other.dtype_;
 }
 
+std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
+}
+
+std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  auto ax = axes[0] >= 0 ? 0 : -1;
+  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  return {{linalg::inv(a, stream())}, {ax}};
+}
+
 } // namespace mlx::core
```
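Both rules share the same axis bookkeeping: a batched input (`axes[0] >= 0`) produces outputs batched along axis 0, and `moveaxis` first brings the batch axis to the front when it is not already there. A self-contained sketch of just that bookkeeping, with made-up axis values and plain integers standing in for arrays:

```cpp
#include <cstdio>

int main() {
  // in_axis < 0 means "this input is not batched" in the vmap rules above.
  for (int in_axis : {-1, 0, 2}) {
    int out_axis = in_axis >= 0 ? 0 : -1; // batched outputs carry axis 0
    bool move = in_axis > 0;              // moveaxis(in, in_axis, 0) needed?
    std::printf("in_axis=%2d -> out_axis=%2d, moveaxis=%s\n",
                in_axis, out_axis, move ? "yes" : "no");
  }
  return 0;
}
```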