mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Address comments
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <fmt/format.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
@@ -4676,6 +4676,11 @@ array segmented_mm(
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (!issubdtype(segments.dtype(), integer)) {
|
||||
throw std::invalid_argument(
|
||||
"[segmented_mm] Got segments with invalid dtype. Segments must be integral.");
|
||||
}
|
||||
|
||||
a = astype(a, out_type, s);
|
||||
b = astype(b, out_type, s);
|
||||
segments = astype(segments, uint32, s);
|
||||
|
||||
Reference in New Issue
Block a user