mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
4 Commits
ibv-backen
...
937ce79660
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
937ce79660 | ||
|
|
208f5441a7 | ||
|
|
b862d842e1 | ||
|
|
f7a400951a |
@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
|||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
message(STATUS "Building Python bindings.")
|
message(STATUS "Building Python bindings.")
|
||||||
find_package(
|
find_package(
|
||||||
Python 3.8
|
Python 3.10
|
||||||
COMPONENTS Interpreter Development.Module
|
COMPONENTS Interpreter Development.Module
|
||||||
REQUIRED)
|
REQUIRED)
|
||||||
execute_process(
|
execute_process(
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
|
|||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||||
|
|
||||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||||
|
|||||||
@@ -3,5 +3,9 @@
|
|||||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||||
|
|
||||||
#ifdef MLX_USE_ACCELERATE
|
#ifdef MLX_USE_ACCELERATE
|
||||||
|
#if defined(__x86_64__)
|
||||||
|
// The accelerate_simd implementation requires NEON -- use the base implementation
|
||||||
|
#else
|
||||||
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|||||||
@@ -338,28 +338,40 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
|||||||
}
|
}
|
||||||
cudaGraphNodeType type;
|
cudaGraphNodeType type;
|
||||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||||
if (type == cudaGraphNodeTypeGraph) {
|
switch (type) {
|
||||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
case cudaGraphNodeTypeGraph: {
|
||||||
cudaGraph_t child;
|
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
cudaGraph_t child;
|
||||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||||
is_updatable &= sub_is_updatable;
|
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||||
key += subkey;
|
is_updatable &= sub_is_updatable;
|
||||||
} else if (type == cudaGraphNodeTypeMemset) {
|
key += subkey;
|
||||||
key += "M";
|
break;
|
||||||
} else if (type != cudaGraphNodeTypeKernel) {
|
|
||||||
is_updatable = false;
|
|
||||||
} else {
|
|
||||||
cudaLaunchAttributeValue cluster_dim;
|
|
||||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
|
||||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
|
||||||
// Only allow dim.x to be greater than 1
|
|
||||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
|
||||||
is_updatable = false;
|
|
||||||
} else {
|
|
||||||
key += "K";
|
|
||||||
key += std::to_string(cluster_dim.clusterDim.x);
|
|
||||||
}
|
}
|
||||||
|
case cudaGraphNodeTypeMemset:
|
||||||
|
key += "M";
|
||||||
|
break;
|
||||||
|
case cudaGraphNodeTypeKernel: {
|
||||||
|
cudaLaunchAttributeValue cluster_dim;
|
||||||
|
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||||
|
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||||
|
// Only allow dim.x to be greater than 1
|
||||||
|
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||||
|
is_updatable = false;
|
||||||
|
} else {
|
||||||
|
key += "K";
|
||||||
|
key += std::to_string(cluster_dim.clusterDim.x);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case cudaGraphNodeTypeWaitEvent:
|
||||||
|
key += "W";
|
||||||
|
break;
|
||||||
|
case cudaGraphNodeTypeEventRecord:
|
||||||
|
key += "R";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
is_updatable = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
key += ")";
|
key += ")";
|
||||||
|
|||||||
Reference in New Issue
Block a user