mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	angelos's commit files
This commit is contained in:
		
							
								
								
									
										1071
									
								
								python/src/array.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1071
									
								
								python/src/array.cpp
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										42
									
								
								python/src/device.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								python/src/device.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/device.h"
 | 
			
		||||
#include "mlx/utils.h"
 | 
			
		||||
 | 
			
		||||
namespace py = pybind11;
 | 
			
		||||
using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
void init_device(py::module_& m) {
 | 
			
		||||
  py::enum_<Device::DeviceType>(m, "DeviceType")
 | 
			
		||||
      .value("cpu", Device::DeviceType::cpu)
 | 
			
		||||
      .value("gpu", Device::DeviceType::gpu)
 | 
			
		||||
      .export_values()
 | 
			
		||||
      .def(
 | 
			
		||||
          "__eq__",
 | 
			
		||||
          [](const Device::DeviceType& d1, const Device& d2) {
 | 
			
		||||
            return d1 == d2;
 | 
			
		||||
          },
 | 
			
		||||
          py::prepend());
 | 
			
		||||
 | 
			
		||||
  py::class_<Device>(m, "Device")
 | 
			
		||||
      .def(py::init<Device::DeviceType, int>(), "type"_a, "index"_a = 0)
 | 
			
		||||
      .def_readonly("type", &Device::type)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__repr__",
 | 
			
		||||
          [](const Device& d) {
 | 
			
		||||
            std::ostringstream os;
 | 
			
		||||
            os << d;
 | 
			
		||||
            return os.str();
 | 
			
		||||
          })
 | 
			
		||||
      .def("__eq__", [](const Device& d1, const Device& d2) {
 | 
			
		||||
        return d1 == d2;
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
  py::implicitly_convertible<Device::DeviceType, Device>();
 | 
			
		||||
 | 
			
		||||
  m.def("default_device", &default_device);
 | 
			
		||||
  m.def("set_default_device", &set_default_device, "device"_a);
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										12
									
								
								python/src/metal.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								python/src/metal.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/metal/metal.h"
 | 
			
		||||
 | 
			
		||||
namespace py = pybind11;
 | 
			
		||||
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
void init_metal(py::module_& m) {
 | 
			
		||||
  py::module_ metal = m.def_submodule("metal", "mlx.metal");
 | 
			
		||||
  metal.def("is_available", &metal::is_available);
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										32
									
								
								python/src/stream.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								python/src/stream.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/stream.h"
 | 
			
		||||
#include "mlx/utils.h"
 | 
			
		||||
 | 
			
		||||
namespace py = pybind11;
 | 
			
		||||
using namespace py::literals;
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
 | 
			
		||||
void init_stream(py::module_& m) {
 | 
			
		||||
  py::class_<Stream>(m, "Stream")
 | 
			
		||||
      .def(py::init<int, Device>(), "index"_a, "device"_a)
 | 
			
		||||
      .def_readonly("device", &Stream::device)
 | 
			
		||||
      .def(
 | 
			
		||||
          "__repr__",
 | 
			
		||||
          [](const Stream& s) {
 | 
			
		||||
            std::ostringstream os;
 | 
			
		||||
            os << s;
 | 
			
		||||
            return os.str();
 | 
			
		||||
          })
 | 
			
		||||
      .def("__eq__", [](const Stream& s1, const Stream& s2) {
 | 
			
		||||
        return s1 == s2;
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
  py::implicitly_convertible<Device::DeviceType, Device>();
 | 
			
		||||
 | 
			
		||||
  m.def("default_stream", &default_stream, "device"_a);
 | 
			
		||||
  m.def("set_default_stream", &set_default_stream, "stream"_a);
 | 
			
		||||
  m.def("new_stream", &new_stream, "device"_a);
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user