// Loader for the MNIST handwritten-digit database in its original IDX file
// format (big-endian header fields followed by raw unsigned-byte data).
#ifndef MNIST_DATABASE_H
#define MNIST_DATABASE_H

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

/// Swaps the byte order of a 32-bit integer (big-endian <-> little-endian).
///
/// Kept for source compatibility with existing callers; mnist_database itself
/// decodes the big-endian IDX header fields directly and does not use it.
/// The swap is performed on an unsigned copy so that a byte >= 0x80 moving
/// into the top position does not overflow a signed int (UB in the original).
/// Like the original, this assumes a 32-bit int. `inline` because it is
/// defined in a header that may be included from several translation units.
inline int ReverseInt(int i)
{
    const unsigned int u = static_cast<unsigned int>(i);
    const unsigned int swapped = ((u & 0x000000FFu) << 24) |
                                 ((u & 0x0000FF00u) << 8) |
                                 ((u >> 8) & 0x0000FF00u) |
                                 ((u >> 24) & 0x000000FFu);
    return static_cast<int>(swapped);
}

/// Loads the four MNIST IDX files from a directory into memory on construction.
///
/// Expected file names inside the directory:
///   t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte,
///   train-images.idx3-ubyte, train-labels.idx1-ubyte
///
/// Each image is stored as a flat, row-major vector of rows*cols pixel values
/// (the raw bytes 0..255 converted to pixel_type); each label is the raw label
/// byte converted to label_type.
///
/// All member functions are defined in-class, which makes them implicitly
/// inline and therefore safe to include from multiple translation units.
class mnist_database
{
public:
    // NOTE(review): the element types were not legible in the source this was
    // reconstructed from. double matches the `(double)label` conversion the
    // original loader performed -- confirm against existing callers.
    using pixel_type = double;
    using label_type = double;

    /// Reads test images, test labels, train images and train labels (in that
    /// order) from `dir`.
    ///
    /// @param dir  directory containing the four MNIST files.
    /// @throws std::runtime_error if a file cannot be opened, has the wrong
    ///         magic number, is truncated, if an image file and its label file
    ///         disagree on the sample count, or if train and test images have
    ///         different dimensions.
    mnist_database(std::string dir)
    {
        int test_rows = 0;
        int test_cols = 0;

        load_images(dir + "/t10k-images.idx3-ubyte", test_images_, test_rows, test_cols);
        load_labels(dir + "/t10k-labels.idx1-ubyte", test_labels_);
        load_images(dir + "/train-images.idx3-ubyte", train_images_, rows_, cols_);
        load_labels(dir + "/train-labels.idx1-ubyte", train_labels_);

        // Images and labels are paired by index, so the counts must agree.
        if (test_images_.size() != test_labels_.size() ||
            train_images_.size() != train_labels_.size())
            throw std::runtime_error("[mnist_database] Image and label counts differ.");

        // image_dimension() reports a single size for both sets.
        if (test_rows != rows_ || test_cols != cols_)
            throw std::runtime_error("[mnist_database] Train and test images have different dimensions.");
    }

    virtual ~mnist_database() = default;

    const std::vector<std::vector<pixel_type>> &train_images() const { return train_images_; }
    const std::vector<std::vector<pixel_type>> &test_images() const { return test_images_; }
    const std::vector<label_type> &train_labels() const { return train_labels_; }
    const std::vector<label_type> &test_labels() const { return test_labels_; }

    /// Reports the image size read from the IDX headers (28x28 for MNIST).
    void image_dimension(int &rows, int &cols) const
    {
        rows = rows_;
        cols = cols_;
    }

private:
    /// Reads one big-endian unsigned 32-bit IDX header field, independent of
    /// the host byte order.
    /// @throws std::runtime_error if fewer than 4 bytes remain.
    static std::uint32_t read_be_u32(std::istream &fs)
    {
        unsigned char bytes[4] = {0, 0, 0, 0};
        if (!fs.read(reinterpret_cast<char *>(bytes), sizeof(bytes)))
            throw std::runtime_error("[mnist_database] File is truncated.");
        return (static_cast<std::uint32_t>(bytes[0]) << 24) |
               (static_cast<std::uint32_t>(bytes[1]) << 16) |
               (static_cast<std::uint32_t>(bytes[2]) << 8) |
               static_cast<std::uint32_t>(bytes[3]);
    }

    /// Opens `path` in binary mode and appends its images to `images`.
    /// The stream is closed when this function returns.
    void load_images(const std::string &path,
                     std::vector<std::vector<pixel_type>> &images,
                     int &rows, int &cols)
    {
        std::ifstream fs(path, std::ios::binary);
        if (!fs)
            throw std::runtime_error("[mnist_database] Database is not found.");
        read_mnist_images(fs, images, rows, cols);
    }

    /// Opens `path` in binary mode and appends its labels to `labels`.
    void load_labels(const std::string &path, std::vector<label_type> &labels)
    {
        std::ifstream fs(path, std::ios::binary);
        if (!fs)
            throw std::runtime_error("[mnist_database] Database is not found.");
        read_mnist_labels(fs, labels);
    }

    /// Parses an IDX3 image file: magic 0x00000803, image count, rows, cols,
    /// then count*rows*cols pixel bytes. Appends one flat row-major vector per
    /// image to `images` and reports the dimensions through `rows`/`cols`.
    void read_mnist_images(std::ifstream &fs,
                           std::vector<std::vector<pixel_type>> &images,
                           int &rows, int &cols)
    {
        const std::uint32_t kImageMagic = 0x00000803;  // 2051
        if (read_be_u32(fs) != kImageMagic)
            throw std::runtime_error("[mnist_database] Not an IDX3 image file (bad magic number).");

        const std::uint32_t count = read_be_u32(fs);
        const std::uint32_t n_rows = read_be_u32(fs);
        const std::uint32_t n_cols = read_be_u32(fs);
        const std::size_t pixels = static_cast<std::size_t>(n_rows) * n_cols;

        // One raw byte buffer reused for every image instead of a read per pixel.
        std::vector<unsigned char> raw(pixels);
        for (std::uint32_t i = 0; i < count; ++i) {
            if (!fs.read(reinterpret_cast<char *>(raw.data()),
                         static_cast<std::streamsize>(pixels)))
                throw std::runtime_error("[mnist_database] File is truncated.");
            images.emplace_back(raw.begin(), raw.end());  // converts bytes to pixel_type
        }

        rows = static_cast<int>(n_rows);
        cols = static_cast<int>(n_cols);
    }

    /// Parses an IDX1 label file: magic 0x00000801, label count, then one byte
    /// per label. Appends each label to `labels`.
    void read_mnist_labels(std::ifstream &fs, std::vector<label_type> &labels)
    {
        const std::uint32_t kLabelMagic = 0x00000801;  // 2049
        if (read_be_u32(fs) != kLabelMagic)
            throw std::runtime_error("[mnist_database] Not an IDX1 label file (bad magic number).");

        const std::uint32_t count = read_be_u32(fs);
        for (std::uint32_t i = 0; i < count; ++i) {
            char label = 0;
            if (!fs.get(label))
                throw std::runtime_error("[mnist_database] File is truncated.");
            // Go through unsigned char so bytes >= 0x80 are not sign-extended.
            labels.push_back(static_cast<label_type>(static_cast<unsigned char>(label)));
        }
    }

    std::vector<std::vector<pixel_type>> train_images_, test_images_;
    std::vector<label_type> train_labels_, test_labels_;
    int rows_ = 0;  // image height from the IDX header
    int cols_ = 0;  // image width from the IDX header
};

#endif  // MNIST_DATABASE_H