mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
2 Commits
27232db1ba
...
b862d842e1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b862d842e1 | ||
|
|
f7a400951a |
@@ -186,7 +186,7 @@ Boolean masks follow NumPy semantics:
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.arange(1000).reshape(10, 10, 10)
|
||||
>>> a[mx.random.randn(10, 10) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
>>> a[mx.random.normal((10, 10)) > 0.0] = 0 # valid: mask covers axes 0 and 1
|
||||
|
||||
The mask of shape ``(10, 10)`` applies to the first two axes, so ``a[mask]``
|
||||
selects the 1-D slices ``a[i, j, :]`` where ``mask[i, j]`` is ``True``.
|
||||
|
||||
@@ -338,28 +338,40 @@ std::pair<std::string, bool> subgraph_to_key(cudaGraph_t graph) {
|
||||
}
|
||||
cudaGraphNodeType type;
|
||||
CHECK_CUDA_ERROR(cudaGraphNodeGetType(node, &type));
|
||||
if (type == cudaGraphNodeTypeGraph) {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||
is_updatable &= sub_is_updatable;
|
||||
key += subkey;
|
||||
} else if (type == cudaGraphNodeTypeMemset) {
|
||||
key += "M";
|
||||
} else if (type != cudaGraphNodeTypeKernel) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only allow dim.x to be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
key += "K";
|
||||
key += std::to_string(cluster_dim.clusterDim.x);
|
||||
switch (type) {
|
||||
case cudaGraphNodeTypeGraph: {
|
||||
// Try to be updatable for a structure like graph -> graph -> kernel
|
||||
cudaGraph_t child;
|
||||
CHECK_CUDA_ERROR(cudaGraphChildGraphNodeGetGraph(node, &child));
|
||||
auto [subkey, sub_is_updatable] = subgraph_to_key(child);
|
||||
is_updatable &= sub_is_updatable;
|
||||
key += subkey;
|
||||
break;
|
||||
}
|
||||
case cudaGraphNodeTypeMemset:
|
||||
key += "M";
|
||||
break;
|
||||
case cudaGraphNodeTypeKernel: {
|
||||
cudaLaunchAttributeValue cluster_dim;
|
||||
CHECK_CUDA_ERROR(cudaGraphKernelNodeGetAttribute(
|
||||
node, cudaLaunchAttributeClusterDimension, &cluster_dim));
|
||||
// Only allow dim.x to be greater than 1
|
||||
if (cluster_dim.clusterDim.y > 1 || cluster_dim.clusterDim.z > 1) {
|
||||
is_updatable = false;
|
||||
} else {
|
||||
key += "K";
|
||||
key += std::to_string(cluster_dim.clusterDim.x);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case cudaGraphNodeTypeWaitEvent:
|
||||
key += "W";
|
||||
break;
|
||||
case cudaGraphNodeTypeEventRecord:
|
||||
key += "R";
|
||||
break;
|
||||
default:
|
||||
is_updatable = false;
|
||||
}
|
||||
}
|
||||
key += ")";
|
||||
|
||||
Reference in New Issue
Block a user