fix extension (#1740)

This commit is contained in:
Awni Hannun 2025-01-02 16:16:16 -08:00 committed by GitHub
parent 491fa95b1f
commit a64a8dfe45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -58,9 +58,9 @@ mx::array axpby(
// Construct the array as the output of the Axpby primitive // Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs // with the broadcasted and upcasted arrays as inputs
return mx::array( return mx::array(
/* const std::vector<int>& shape = */ out_shape, /* const mx::Shape& shape = */ out_shape,
/* mx::Dtype dtype = */ out_dtype, /* mx::Dtype dtype = */ out_dtype,
/* std::unique_ptr<mx::Primitive> primitive = */ /* std::shared_ptr<mx::Primitive> primitive = */
std::make_shared<Axpby>(to_stream(s), alpha, beta), std::make_shared<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<mx::array>& inputs = */ broadcasted_inputs); /* const std::vector<mx::array>& inputs = */ broadcasted_inputs);
} }
@ -279,7 +279,7 @@ void Axpby::eval_gpu(
if (!contiguous_kernel) { if (!contiguous_kernel) {
compute_encoder.set_vector_bytes(x.shape(), 5); compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder.set_vector_bytes(x.strides(), 6); compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder.set_bytes(y.strides(), 7); compute_encoder.set_vector_bytes(y.strides(), 7);
compute_encoder.set_bytes(ndim, 8); compute_encoder.set_bytes(ndim, 8);
} }