mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
feat: Add SVD primitive GPU evaluation support
- Implement SVD::eval_gpu in Metal primitives backend - Add proper float32/float64 type dispatch - Include clear error messages for unsupported double precision - Connect SVD primitive to Metal backend implementation - Enable GPU path for SVD operations in MLX
This commit is contained in:
parent
54125e5ff5
commit
f4789ab8b9
@ -348,7 +348,10 @@ void SVD::eval_gpu(
|
|||||||
svd_metal_impl<float>(inputs[0], outputs, compute_uv_, d, s);
|
svd_metal_impl<float>(inputs[0], outputs, compute_uv_, d, s);
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
svd_metal_impl<double>(inputs[0], outputs, compute_uv_, d, s);
|
// Metal does not support double precision, fall back to CPU
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[SVD::eval_gpu] Double precision not supported on Metal GPU. "
|
||||||
|
"Use mx.set_default_device(mx.cpu) for float64 SVD operations.");
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
Loading…
Reference in New Issue
Block a user