gctl_ai/data/MNIST/mnist_database.h

153 lines
4.5 KiB
C
Raw Normal View History

2024-09-10 20:15:33 +08:00
#ifndef _MNIST_DATABASE_H
#define _MNIST_DATABASE_H
#include "string"
#include "iostream"
#include "fstream"
#include "vector"
#include <cstddef>
#include <stdexcept>
/**
 * Reverses the byte order of the low 32 bits of an integer (byte 0 <-> byte 3,
 * byte 1 <-> byte 2), e.g. 0x01020304 becomes 0x04030201.
 *
 * Declared inline because this definition lives in a header: without it, every
 * translation unit that includes the header emits its own external definition
 * and the program fails to link.
 *
 * @param i value whose bytes are swapped (assumes a 32-bit int).
 * @return the byte-swapped value.
 */
inline int ReverseInt(int i)
{
    // Work on an unsigned copy so no shift ever moves a bit into (or out of) the sign bit.
    const unsigned int u = static_cast<unsigned int>(i);
    const unsigned int swapped = ((u & 0xFFu) << 24)
                               | ((u & 0xFF00u) << 8)
                               | ((u >> 8) & 0xFF00u)
                               | ((u >> 24) & 0xFFu);
    return static_cast<int>(swapped);
}
/**
 * In-memory copy of an MNIST-style image/label database.
 *
 * The constructor reads four files from a directory (test images, test labels,
 * training images, training labels) and keeps their contents as vectors of
 * double. The accessors hand out read-only references to those vectors.
 */
class mnist_database
{
public:
    // Loads the four database files found under `dir`.
    // Throws std::runtime_error if any of them cannot be opened.
    mnist_database(std::string dir);

    virtual ~mnist_database(){}

    // Training set: one inner vector of pixel values per image, and one label per image.
    const std::vector<std::vector<double> > &train_images();
    const std::vector<double> &train_labels();

    // Test set: same layout as the training set.
    const std::vector<std::vector<double> > &test_images();
    const std::vector<double> &test_labels();

    // Writes the image height and width into `rows` and `cols`.
    void image_dimension(int &rows, int &cols);

private:
    // Append the images / labels stored in an already-opened file to the given container.
    void read_mnist_images(std::ifstream &fs, std::vector<std::vector<double> > &images);
    void read_mnist_labels(std::ifstream &fs, std::vector<double> &labels);

    std::vector<std::vector<double> > train_images_;
    std::vector<std::vector<double> > test_images_;
    std::vector<double> train_labels_;
    std::vector<double> test_labels_;
};
/**
 * Loads the whole database from the directory `dir`.
 *
 * Files are read in this order: t10k-images.idx3-ubyte, t10k-labels.idx1-ubyte,
 * train-images.idx3-ubyte, train-labels.idx1-ubyte.
 *
 * Declared inline because this definition lives in a header; a non-inline
 * definition would be emitted by every including translation unit and break linking.
 *
 * @param dir directory that contains the four database files.
 * @throws std::runtime_error if any of the files cannot be opened.
 */
inline mnist_database::mnist_database(std::string dir)
{
    std::ifstream infile;

    // Opens dir/name in binary mode on the shared stream, or throws if that fails.
    const auto open_or_throw = [&dir, &infile](const char *name)
    {
        const std::string file = dir + "/" + name;
        infile.open(file, std::ios::binary);
        if (!infile) throw std::runtime_error("[mnist_database] Database is not found.");
    };

    open_or_throw("t10k-images.idx3-ubyte");
    read_mnist_images(infile, test_images_);
    infile.close();

    open_or_throw("t10k-labels.idx1-ubyte");
    read_mnist_labels(infile, test_labels_);
    infile.close();

    open_or_throw("train-images.idx3-ubyte");
    read_mnist_images(infile, train_images_);
    infile.close();

    open_or_throw("train-labels.idx1-ubyte");
    read_mnist_labels(infile, train_labels_);
    infile.close();
}
/**
 * Returns the training images loaded by the constructor.
 * Each inner vector holds the pixel values of one image, row after row.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 */
inline const std::vector<std::vector<double> > &mnist_database::train_images()
{
    return train_images_;
}
/**
 * Returns the test images loaded by the constructor.
 * Each inner vector holds the pixel values of one image, row after row.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 */
inline const std::vector<std::vector<double> > &mnist_database::test_images()
{
    return test_images_;
}
/**
 * Returns the training labels loaded by the constructor, one value per image.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 */
inline const std::vector<double> &mnist_database::train_labels()
{
    return train_labels_;
}
/**
 * Returns the test labels loaded by the constructor, one value per image.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 */
inline const std::vector<double> &mnist_database::test_labels()
{
    return test_labels_;
}
/**
 * Reports the image size as 28 x 28.
 *
 * NOTE: the value is hard-coded; it is not taken from the row/column counts
 * stored in the image file headers.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 *
 * @param rows receives the number of pixel rows (28).
 * @param cols receives the number of pixel columns (28).
 */
inline void mnist_database::image_dimension(int &rows, int &cols)
{
    rows = 28;
    cols = 28;
}
/**
 * Reads an IDX3 image file from an already-opened binary stream and appends
 * every image to `images`.
 *
 * The header consists of four 32-bit big-endian integers: magic number (2051),
 * image count, rows, columns. Each image follows as rows*cols unsigned bytes,
 * stored row after row; every byte is converted to double.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 *
 * @param fs     stream positioned at the start of the file.
 * @param images destination; existing content is kept and new images are appended.
 * @throws std::runtime_error if the header is unreadable or invalid, or the
 *         pixel data ends before all announced images were read.
 */
inline void mnist_database::read_mnist_images(std::ifstream &fs, std::vector<std::vector<double> > &images)
{
    // Decode header fields byte by byte so the result does not depend on host endianness.
    const auto read_be32 = [&fs]() -> int
    {
        unsigned char bytes[4] = {0, 0, 0, 0};
        fs.read(reinterpret_cast<char*>(bytes), sizeof(bytes));
        const unsigned int value = (static_cast<unsigned int>(bytes[0]) << 24)
                                 | (static_cast<unsigned int>(bytes[1]) << 16)
                                 | (static_cast<unsigned int>(bytes[2]) << 8)
                                 |  static_cast<unsigned int>(bytes[3]);
        return static_cast<int>(value);
    };

    const int idx3_magic = 2051; // 0x00000803: unsigned-byte data, 3 dimensions
    const int magic_number = read_be32();
    const int number_of_images = read_be32();
    const int n_rows = read_be32();
    const int n_cols = read_be32();

    if (!fs) throw std::runtime_error("[mnist_database] Failed to read the image file header.");
    if (magic_number != idx3_magic) throw std::runtime_error("[mnist_database] Invalid image file magic number.");
    if (number_of_images < 0 || n_rows < 0 || n_cols < 0)
        throw std::runtime_error("[mnist_database] Invalid image file header.");

    const std::size_t pixels_per_image = static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(n_cols);

    // One bulk read per image into a reusable byte buffer instead of one read per pixel.
    std::vector<unsigned char> raw(pixels_per_image);
    images.reserve(images.size() + static_cast<std::size_t>(number_of_images));

    for (int i = 0; i < number_of_images; i++)
    {
        if (pixels_per_image > 0)
        {
            fs.read(reinterpret_cast<char*>(raw.data()), static_cast<std::streamsize>(pixels_per_image));
        }
        // A short read means the file holds fewer images than its header announces.
        if (!fs) throw std::runtime_error("[mnist_database] Image file is truncated.");

        images.emplace_back(raw.begin(), raw.end());
    }
}
/**
 * Reads an IDX1 label file from an already-opened binary stream and appends
 * every label to `labels`.
 *
 * The header consists of two 32-bit big-endian integers: magic number (2049)
 * and label count. The labels follow as one unsigned byte each; every byte is
 * converted to double.
 *
 * Declared inline because this definition lives in a header; without it the
 * function is defined once per including translation unit and linking fails.
 *
 * @param fs     stream positioned at the start of the file.
 * @param labels destination; existing content is kept and new labels are appended.
 * @throws std::runtime_error if the header is unreadable or invalid, or the
 *         file ends before all announced labels were read.
 */
inline void mnist_database::read_mnist_labels(std::ifstream &fs, std::vector<double> &labels)
{
    // Decode header fields byte by byte so the result does not depend on host endianness.
    const auto read_be32 = [&fs]() -> int
    {
        unsigned char bytes[4] = {0, 0, 0, 0};
        fs.read(reinterpret_cast<char*>(bytes), sizeof(bytes));
        const unsigned int value = (static_cast<unsigned int>(bytes[0]) << 24)
                                 | (static_cast<unsigned int>(bytes[1]) << 16)
                                 | (static_cast<unsigned int>(bytes[2]) << 8)
                                 |  static_cast<unsigned int>(bytes[3]);
        return static_cast<int>(value);
    };

    const int idx1_magic = 2049; // 0x00000801: unsigned-byte data, 1 dimension
    const int magic_number = read_be32();
    const int number_of_labels = read_be32();

    if (!fs) throw std::runtime_error("[mnist_database] Failed to read the label file header.");
    if (magic_number != idx1_magic) throw std::runtime_error("[mnist_database] Invalid label file magic number.");
    if (number_of_labels < 0) throw std::runtime_error("[mnist_database] Invalid label file header.");

    // Read all labels with a single bulk read instead of one read per byte.
    std::vector<unsigned char> raw(static_cast<std::size_t>(number_of_labels));
    if (!raw.empty())
    {
        fs.read(reinterpret_cast<char*>(raw.data()), static_cast<std::streamsize>(raw.size()));
    }
    // A short read means the file holds fewer labels than its header announces.
    if (!fs) throw std::runtime_error("[mnist_database] Label file is truncated.");

    labels.insert(labels.end(), raw.begin(), raw.end());
}
#endif // _MNIST_DATABASE_H