diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index b01eeec7e..13ce88a62 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -1,6 +1,5 @@ // Copyright © 2023-2024 Apple Inc. #include -#include #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/utils.h" diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 255b0307c..7161a39b2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4676,6 +4676,11 @@ array segmented_mm( throw std::invalid_argument(msg.str()); } + if (!issubdtype(segments.dtype(), integer)) { + throw std::invalid_argument( + "[segmented_mm] Got segments with invalid dtype. Segments must be integral."); + } + a = astype(a, out_type, s); b = astype(b, out_type, s); segments = astype(segments, uint32, s);