Address comments

This commit is contained in:
Angelos Katharopoulos
2025-07-07 17:03:28 -07:00
parent 8ea5729ee4
commit 1c589298ec
2 changed files with 5 additions and 1 deletions

View File

@@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"

View File

@@ -4676,6 +4676,11 @@ array segmented_mm(
throw std::invalid_argument(msg.str());
}
if (!issubdtype(segments.dtype(), integer)) {
throw std::invalid_argument(
"[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
}
a = astype(a, out_type, s);
b = astype(b, out_type, s);
segments = astype(segments, uint32, s);