mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
4 Commits
ibv-backen
...
937ce79660
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
937ce79660 | ||
|
|
208f5441a7 | ||
|
|
b862d842e1 | ||
|
|
f7a400951a |
@@ -273,7 +273,7 @@ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||
message(STATUS "Building Python bindings.")
|
||||
find_package(
|
||||
Python 3.8
|
||||
Python 3.10
|
||||
COMPONENTS Interpreter Development.Module
|
||||
REQUIRED)
|
||||
execute_process(
|
||||
|
||||
@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
|
||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||
|
||||
@@ -3,5 +3,9 @@
|
||||
#include "mlx/backend/cpu/simd/base_simd.h"
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#if defined(__x86_64__)
|
||||
// the accelerate_simd implementation require neon -- use base implementation
|
||||
#else
|
||||
#include "mlx/backend/cpu/simd/accelerate_simd.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -338,28 +338,40 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
||||
}
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type == cudaGraphNodeTypeGraph) {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||
is_updatable &= sub_is_updatable;
|
||||
key += subkey;
|
||||
} else if (type == cudaGraphNodeTypeMemset) {
|
||||
key += "M";
|
||||
} else if (type != cudaGraphNodeTypeKernel) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only allow dim.x to be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
key += "K";
|
||||
key += std::to_string(cluster_dim.clusterDim.x);
|
||||
switch (type) {
|
||||
case cudaGraphNodeTypeGraph: {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||
is_updatable &= sub_is_updatable;
|
||||
key += subkey;
|
||||
break;
|
||||
}
|
||||
case cudaGraphNodeTypeMemset:
|
||||
key += "M";
|
||||
break;
|
||||
case cudaGraphNodeTypeKernel: {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only allow dim.x to be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
key += "K";
|
||||
key += std::to_string(cluster_dim.clusterDim.x);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case cudaGraphNodeTypeWaitEvent:
|
||||
key += "W";
|
||||
break;
|
||||
case cudaGraphNodeTypeEventRecord:
|
||||
key += "R";
|
||||
break;
|
||||
default:
|
||||
is_updatable = false;
|
||||
}
|
||||
}
|
||||
key += ")";
|
||||
|
||||
Reference in New Issue
Block a user