(Caffe)基本类DataReader、QueuePair、B
2016-04-07 本文已影响290人
沤江一流
本文从CSDN上转移过来:
http://blog.csdn.net/mounty_fsc/article/details/51088361
1 简介
QueuePair与Body
是DataReader
的内部类。一个DataReader
对应一个任务,一个Body生成一个线程来读取数据库(如examples/mnist/mnist_train_lmdb
)。QueuePair
为前面两者之间的衔接、通信。
2 源代码
/**
* @brief Reads data from a source to queues available to data layers.
* A single reading thread is created per source, even if multiple solvers
* are running in parallel, e.g. for multi-GPU training. This makes sure
* databases are read sequentially, and that each solver accesses a different
* subset of the database. Data is distributed to solvers in a round-robin
* way to keep parallel training deterministic.
*/
class DataReader {
public:
...
protected:
// Queue pairs are shared between a body and its readers
class QueuePair {
public:
explicit QueuePair(int size);
~QueuePair();
BlockingQueue<Datum*> free_;
BlockingQueue<Datum*> full_;
};
// A single body is created per source
class Body : public InternalThread {
public:
...
protected:
void InternalThreadEntry();
void read_one(db::Cursor* cursor, QueuePair* qp);
const LayerParameter param_;
BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;
...
};
...
const shared_ptr<QueuePair> queue_pair_;
shared_ptr<Body> body_;
static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;
};
2 类QueuePair
DataReader::QueuePair::QueuePair(int size) {
// Initialize the free queue with requested number of datums
for (int i = 0; i < size; ++i) {
free_.push(new Datum());
}
}
说明:
- 一个
QueuePair
对应一个任务队列,从数据库(如examples/mnist/mnist_train_lmdb
)中读取size
个样本 -
BlockingQueue
为一个线程安全的队列容器,其模板类型可能是Datum
,Batch
等。此处装的是Datum
。 -
BlockingQueue<Datum*> free_为Datum
队列,均为新new
出来的,没有包含原始数据(图像)信息 -
BlockingQueue<Datum*> full_
为从数据库读取信息后的队列,包含了原始数据(图像)信息 -
Datum
为一个样本单元,关于Datum
的定义,参见caffe.proto
文件,一般来说,Datum
对应于一张图像(及其label
)
3 类Body
DataReader::Body::Body(const LayerParameter& param)
: param_(param),
new_queue_pairs_() {
StartInternalThread();
}
说明:
-
Body
类继承了InternalThread
(详见博文)。在构造函数了开启这个线程 -
Body
类重载了DataReader::Body::InternalThreadEntry()
函数,从数据库读取数据的操作在该函数中实现,见本文第5节
4 类DataReader
DataReader
类的构造函数如下:
map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
static boost::mutex bodies_mutex_;
DataReader::DataReader(const LayerParameter& param)
: queue_pair_(new QueuePair( //
param.data_param().prefetch() * param.data_param().batch_size())) {
// Get or create a body
boost::mutex::scoped_lock lock(bodies_mutex_);
string key = source_key(param);
weak_ptr<Body>& weak = bodies_[key];
body_ = weak.lock();
if (!body_) {
body_.reset(new Body(param));
bodies_[key] = weak_ptr<Body>(body_);
}
body_->new_queue_pairs_.push(queue_pair_);
}
说明:
- 一个数据库只可能有
Body
对象,如examples/mnist/mnist_train_lmdb
不管在任何线程的任何DataReader
对象中,都只会有一个Body
对象,因为bodies_
是静态的: - 所以有,一个
Body
的对象也可以有多个DataReader
对象 - 此外有,一个
DataReader
对象可以有多个Body
对象,即map<string,weak_ptr<Body>> bodies_
- 由代码5,6行及16行可知,每一个DataReader对应一个读的任务,即从数据库(如examples/mnist/mnist_train_lmdb)中读取param.data_param().prefetch() * param.data_param().batch_size()(LeNet5中默认为4×64)个样本
- 由此可见,一个DataReader为一个任务,通过QueuePair(也对应于该任务)“通知”Body某个数据库中读去N个样本
- 由代码13行可知,某个数据库(如examples/mnist/mnist_train_lmdb)对应的Body若不存在,将新建一个Body来处理该数据库,也可以理解成新建一个唯一对应于该数据库的线程来处理该数据可。
5 函数DataReader::Body::InternalThreadEntry
void DataReader::Body::InternalThreadEntry() {
...
vector<shared_ptr<QueuePair> > qps;
try {
...
// To ensure deterministic runs, only start running once all solvers
// are ready. But solvers need to peek on one item during initialization,
// so read one item, then wait for the next solver.
for (int i = 0; i < solver_count; ++i) {
shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
read_one(cursor.get(), qp.get());
qps.push_back(qp);
}
// Main loop
while (!must_stop()) {
for (int i = 0; i < solver_count; ++i) {
read_one(cursor.get(), qps[i].get());
}
...
}
} catch (boost::thread_interrupted&) {
// Interrupted exception is expected on shutdown
}
}
说明:
-
read_one()
从QueuePair的free_
中取出<font color="red">一个</font>Datum
,从数据库读入数据至Datum
,然后放入full_
中 - 由第4节16行可知,一个新的任务(
DataReader
)到来时,将把一个命令队列(QueuePair
)放入到某个数据库(Body
)的缓冲命令队列中(new_queue_pairs_
) - 9到13行从每个solver的任务中读取一个
Datum
,在15到18行从数据库中循环读出数据 - <u>该线程何时停止呢?</u>