// Source listing metadata (left over from the code viewer): 153 lines, 4.5 KiB; tagged as C, although the contents are C++.
#ifndef _MNIST_DATABASE_H
|
||
|
#define _MNIST_DATABASE_H
|
||
|
|
||
|
#include "string"
#include "iostream"
#include "fstream"
#include "vector"
#include <stdexcept>
|
||
|
|
||
|
/// Reverses the byte order of a 32-bit integer.
///
/// The readers in this file pass every header field through this function
/// right after reading it raw from disk.
/// NOTE(review): that usage presumes the host is little-endian and `int` is
/// 32 bits wide, exactly as the original implementation did.
///
/// Declared `inline` because the function is defined in a header; without it,
/// including the header from two translation units is a multiple-definition
/// error at link time.
///
/// @param i Value whose four bytes are to be swapped.
/// @return `i` with its bytes in the opposite order.
inline int ReverseInt(int i)
{
    // Work on the unsigned representation: shifting a byte >= 0x80 into the
    // sign bit of a signed int overflows, which unsigned arithmetic avoids.
    const unsigned int bits = static_cast<unsigned int>(i);
    const unsigned int swapped = ((bits & 0x000000FFu) << 24) |
                                 ((bits & 0x0000FF00u) << 8) |
                                 ((bits & 0x00FF0000u) >> 8) |
                                 ((bits & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}
|
||
|
|
||
|
/// Holds the MNIST training and test sets in memory.
///
/// Images are stored as one `std::vector<double>` of pixel values per image,
/// labels as one `double` per sample.
class mnist_database
{
public:
    /// Builds the database from the MNIST files located under `dir`.
    mnist_database(std::string dir);

    virtual ~mnist_database() {}

    /// Read-only access to the loaded training images.
    const std::vector<std::vector<double> > &train_images();

    /// Read-only access to the loaded test images.
    const std::vector<std::vector<double> > &test_images();

    /// Read-only access to the loaded training labels.
    const std::vector<double> &train_labels();

    /// Read-only access to the loaded test labels.
    const std::vector<double> &test_labels();

    /// Reports the image height and width through the two out-parameters.
    void image_dimension(int &rows, int &cols);

private:
    /// Appends every image found in the stream `fs` to `images`.
    void read_mnist_images(std::ifstream &fs, std::vector<std::vector<double> > &images);

    /// Appends every label found in the stream `fs` to `labels`.
    void read_mnist_labels(std::ifstream &fs, std::vector<double> &labels);

    // Storage, declared in the same order as before so that member
    // construction and destruction order is unchanged.
    std::vector<std::vector<double> > train_images_;
    std::vector<std::vector<double> > test_images_;
    std::vector<double> train_labels_;
    std::vector<double> test_labels_;
};
|
||
|
|
||
|
/// Loads the four MNIST files stored under `dir`.
///
/// Files are read in this order: test images, test labels, training images,
/// training labels. Each file gets its own stream object, so the stream state
/// left behind by one file can never influence the open check of the next.
///
/// Declared `inline` because the definition lives in a header; without it,
/// including the header from two translation units fails at link time.
///
/// @param dir Directory containing `t10k-images.idx3-ubyte`,
///            `t10k-labels.idx1-ubyte`, `train-images.idx3-ubyte` and
///            `train-labels.idx1-ubyte`.
/// @throws std::runtime_error if any of the four files cannot be opened.
inline mnist_database::mnist_database(std::string dir)
{
    {
        std::ifstream in(dir + "/t10k-images.idx3-ubyte", std::ios::binary);
        if (!in) throw std::runtime_error("[mnist_database] Database is not found.");
        read_mnist_images(in, test_images_);
    }
    {
        std::ifstream in(dir + "/t10k-labels.idx1-ubyte", std::ios::binary);
        if (!in) throw std::runtime_error("[mnist_database] Database is not found.");
        read_mnist_labels(in, test_labels_);
    }
    {
        std::ifstream in(dir + "/train-images.idx3-ubyte", std::ios::binary);
        if (!in) throw std::runtime_error("[mnist_database] Database is not found.");
        read_mnist_images(in, train_images_);
    }
    {
        std::ifstream in(dir + "/train-labels.idx1-ubyte", std::ios::binary);
        if (!in) throw std::runtime_error("[mnist_database] Database is not found.");
        read_mnist_labels(in, train_labels_);
    }
}
|
||
|
|
||
|
const std::vector<std::vector<double> > &mnist_database::train_images()
|
||
|
{
|
||
|
return train_images_;
|
||
|
}
|
||
|
|
||
|
const std::vector<std::vector<double> > &mnist_database::test_images()
|
||
|
{
|
||
|
return test_images_;
|
||
|
}
|
||
|
|
||
|
const std::vector<double> &mnist_database::train_labels()
|
||
|
{
|
||
|
return train_labels_;
|
||
|
}
|
||
|
|
||
|
const std::vector<double> &mnist_database::test_labels()
|
||
|
{
|
||
|
return test_labels_;
|
||
|
}
|
||
|
|
||
|
/// Reports the image size through the two out-parameters.
///
/// Both values are the fixed constant 28; the row/column counts stored in the
/// image file headers are not consulted here.
///
/// `inline` keeps this header-defined member function from producing a
/// multiple-definition link error when the header is included more than once
/// across a program.
///
/// @param rows Receives the number of rows per image (always 28).
/// @param cols Receives the number of columns per image (always 28).
inline void mnist_database::image_dimension(int &rows, int &cols)
{
    rows = cols = 28;
}
|
||
|
|
||
|
void mnist_database::read_mnist_images(std::ifstream &fs, std::vector<std::vector<double> > &images)
|
||
|
{
|
||
|
int magic_number = 0;
|
||
|
int number_of_images = 0;
|
||
|
int n_rows = 0;
|
||
|
int n_cols = 0;
|
||
|
unsigned char label;
|
||
|
fs.read((char*)&magic_number, sizeof(magic_number));
|
||
|
fs.read((char*)&number_of_images, sizeof(number_of_images));
|
||
|
fs.read((char*)&n_rows, sizeof(n_rows));
|
||
|
fs.read((char*)&n_cols, sizeof(n_cols));
|
||
|
magic_number = ReverseInt(magic_number);
|
||
|
number_of_images = ReverseInt(number_of_images);
|
||
|
n_rows = ReverseInt(n_rows);
|
||
|
n_cols = ReverseInt(n_cols);
|
||
|
|
||
|
//std::cout << "magic number = " << magic_number << std::endl;
|
||
|
//std::cout << "number of images = " << number_of_images << std::endl;
|
||
|
//std::cout << "rows = " << n_rows << std::endl;
|
||
|
//std::cout << "cols = " << n_cols << std::endl;
|
||
|
|
||
|
std::vector<double> tp;
|
||
|
for (int i = 0; i < number_of_images; i++)
|
||
|
{
|
||
|
tp.clear();
|
||
|
for (int r = 0; r < n_rows; r++)
|
||
|
{
|
||
|
for (int c = 0; c < n_cols; c++)
|
||
|
{
|
||
|
unsigned char image = 0;
|
||
|
fs.read((char*)&image, sizeof(image));
|
||
|
tp.push_back(image);
|
||
|
}
|
||
|
}
|
||
|
images.push_back(tp);
|
||
|
}
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
void mnist_database::read_mnist_labels(std::ifstream &fs, std::vector<double> &labels)
|
||
|
{
|
||
|
int magic_number = 0;
|
||
|
int number_of_images = 0;
|
||
|
fs.read((char*)&magic_number, sizeof(magic_number));
|
||
|
fs.read((char*)&number_of_images, sizeof(number_of_images));
|
||
|
magic_number = ReverseInt(magic_number);
|
||
|
number_of_images = ReverseInt(number_of_images);
|
||
|
|
||
|
//std::cout << "magic number = " << magic_number << std::endl;
|
||
|
//std::cout << "number of images = " << number_of_images << std::endl;
|
||
|
|
||
|
for (int i = 0; i < number_of_images; i++)
|
||
|
{
|
||
|
unsigned char label = 0;
|
||
|
fs.read((char*)&label, sizeof(label));
|
||
|
labels.push_back((double)label);
|
||
|
}
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
#endif // _MNIST_DATABASE_H
|