【框架caffe】1:caffe.proto——caffe中用到
2018-09-11 本文已影响24人
yuanCruise
它位于…\src\caffe\proto目录下,在这个文件夹下还有一个.pb.cc和一个.pb.h文件,这两个文件都是由caffe.proto编译而来的。
0. 基于Protobuf标准的解释(Protobuf在caffe中到底是怎么实现的)
首先我们需要编写一个 proto 文件,定义我们程序中需要处理的结构化数据,在 protobuf 的术语中,结构化数据被称为 Message。proto 文件非常类似 java 或者 C 语言的数据定义。
Package caffe;
Message BlobProto{...}
如上所示的这个组合,表达的意思为:定义一个命名空间,在该命名空间下定义一个类BlobProto(每一个Message都会生成一个类)。下面对Message做一个简单的介绍:
Message中的field:
//而且这些field有三种形式:
//1. Required是必须有值的,
//2. optional是可选项,
//3. repeated表示后面单元为相同类型的一组向量。
Message的tag:
//每个message里面的每个field都对应一个tag,分别是1~15或者以上,比如required string number=1;
//这个数字就是用来在生成的二进制文件中搜索查询的标签(怪不得会快呢^_^)。
//关于这个数字,1到15会花费1byte的编码空间,16到2047花费2byte。所以一般建议把那些频繁使用的名字的标签设为1到15之间的值~
Message的enum:
//enum枚举类型调用方法:caffe::BlobProto::枚举类型里面的变量。
然后通过编译这个.proto文件之后就会生成一个.pb.cc和一个.pb.h文件。
protoc -I=$SRC_DIR --cpp_out=$DST_DIR $SRC_DIR/addressbook.proto
这个.pb.cc文件里面会自动生成一些函数:
Set_+field():该函数用来设定值
has_():该函数来检查当前Message中是否有某个field
clear_():该函数用来清理field
mutable_():该函数用来设置string的值
_size():该函数用于获取重复的个数
//还有一些函数
void CopyFrom();
void MergeFrom();
void CopyFrom();
void MergeFrom;
void Clear();
bool IsInitialized() const;
int ByteSize() const;
bool MergePartialFromCodedStream();
void SerializeWithCachedSizes() const;
SerializeWithCachedSizesToArray() const;
int GetCachedSize()
void SharedCtor();
void SharedDtor();
void SetCachedSize() const;
生成这些函数之后要在其他文件中用到数据序列化,反序列化的时候只需要:
#include "caffe.pb.h"
并且可以通过如下所示的操作来定义类对象,这些类对象可以用.pb.cc中的各个函数。
//定义一个类对象(Message BoblProto)
caffe:: BlobProto* blobproto1
blobproto1.函数(对象不是指针类型)
blobproto1->函数(对象是指针类型)
Message的enum:
//enum枚举类型调用方法:caffe::BlobProto::枚举类型里面的变量。
1. caffe.proto的代码框架如下:
2. caffe.proto中的几个重要的Message(数据类型)
Message类别:
属于blob的:BlobProto, BlobProtoVector, Datum。
属于layer的:FillerParameter, LayerParameter,
ArgMaxParameter,TransformationParameter, LossParameter, AccuracyParameter,
ConcatParameter, ContrastiveLossParameter, ConvolutionParameter,
DataParameter, DropoutParameter, DummyDataParameter, EltwiseParameter,
ExpParameter, HDF5DataParameter, HDF5OutputParameter, HingeLossParameter,
ImageDataParameter, InfogainLossParameter, InnerProductParameter,
LRNParameter, MemoryDataParameter, MVNParameter, PoolingParameter,
PowerParameter, PythonParameter, ReLUParameter, SigmoidParameter,
SliceParameter, SoftmaxParameter, TanHParameter, ThresholdParameter等。
属于net的:NetParameter, SolverParameter, SolverState, NetState, NetStateRule,
ParamSpec。
NetParameter弄清楚NetParameter类的组成,也就明白了.Caffemodel的
具体数据构成;
SolverState类记录的是当前迭代状态和参数设置,与.solverstate文件有
关系;
<0> BlobProto
message BlobProto {//blob的属性以及blob中的数据(data\diff)
optional int32 num = 1 [default = 0];
optional int32 channels = 2 [default = 0];
optional int32 height = 3 [default = 0];
optional int32 width = 4 [default = 0];
repeated float data = 5 [packed = true];
repeated float diff = 6 [packed = true];
}
<1> Datum
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
optional bytes data = 4;//真实的图像数据,以字节存储(bytes)
optional int32 label = 5;
repeated float float_data = 6;//datum也能存float类型的数据(float)
}
<2> LayerParameter
message LayerParameter {
repeated string bottom = 2; //输入的blob的名字(string)
repeated string top = 3; //输出的blob的名字(string)
optional string name = 4; //层的名字
enum LayerType { //层的枚举(enum,和c++中的enum一样)
NONE = 0;
ACCURACY = 1;
BNLL = 2;
CONCAT = 3;
CONVOLUTION = 4;
DATA = 5;
DROPOUT = 6;
EUCLIDEAN_LOSS = 7;
ELTWISE_PRODUCT = 25;
FLATTEN = 8;
HDF5_DATA = 9;
HDF5_OUTPUT = 10;
HINGE_LOSS = 28;
IM2COL = 11;
IMAGE_DATA = 12;
INFOGAIN_LOSS = 13;
INNER_PRODUCT = 14;
LRN = 15;
MEMORY_DATA = 29;
MULTINOMIAL_LOGISTIC_LOSS = 16;
POOLING = 17;
POWER = 26;
RELU = 18;
SIGMOID = 19;
SIGMOID_CROSS_ENTROPY_LOSS = 27;
SOFTMAX = 20;
SOFTMAX_LOSS = 21;
SPLIT = 22;
TANH = 23;
WINDOW_DATA = 24;
}
optional LayerType type = 5; // 层的类型
repeated BlobProto blobs = 6; //blobs的数值参数
repeated float blobs_lr = 7; //学习速率(repeated),如果你想那个设置一个blob的学习速率,你需要设置所有blob的学习速率。
repeated float weight_decay = 8; //权值衰减(repeated)
// 相对于某一特定层的参数(optional)
optional ConcatParameter concat_param = 9;
optional ConvolutionParameter convolution_param = 10;
optional DataParameter data_param = 11;
optional DropoutParameter dropout_param = 12;
optional HDF5DataParameter hdf5_data_param = 13;
optional HDF5OutputParameter hdf5_output_param = 14;
optional ImageDataParameter image_data_param = 15;
optional InfogainLossParameter infogain_loss_param = 16;
optional InnerProductParameter inner_product_param = 17;
optional LRNParameter lrn_param = 18;
optional MemoryDataParameter memory_data_param = 22;
optional PoolingParameter pooling_param = 19;
optional PowerParameter power_param = 21;
optional WindowDataParameter window_data_param = 20;
optional V0LayerParameter layer = 1;
}
<3> NetParameter
message NetParameter {
optional string name = 1;//网络的名字
repeated LayerParameter layers = 2; //repeated类似于数组
repeated string input = 3;//输入层blob的名字
repeated int32 input_dim = 4;//输入层blob的维度,应该等于(4*#input)
optional bool force_backward = 5 [default = false];//网络是否进行反向传播。如果设置为否,则由网络的结构和学习速率来决定是否进行反向传播。
}
<4> SolverParameter
message SolverParameter {
optional string train_net = 1; // 训练网络的proto file
optional string test_net = 2; // 测试网络的proto file
optional int32 test_iter = 3 [default = 0]; // 每次测试时的迭代次数
optional int32 test_interval = 4 [default = 0]; // 两次测试的间隔迭代次数
optional bool test_compute_loss = 19 [default = false];
optional float base_lr = 5; // 基本学习率
optional int32 display = 6; // 两次显示的间隔迭代次数
optional int32 max_iter = 7; // 最大迭代次数
optional string lr_policy = 8; // 学习速率衰减方式
optional float gamma = 9; // 关于梯度下降的一个参数
optional float power = 10; // 计算学习率的一个参数
optional float momentum = 11; // 动量
optional float weight_decay = 12; // 权值衰减
optional int32 stepsize = 13; // 学习速率的衰减步长
optional int32 snapshot = 14 [default = 0]; // snapshot的间隔
optional string snapshot_prefix = 15; // snapshot的前缀
optional bool snapshot_diff = 16 [default = false]; // 是否对于 diff 进行 snapshot
enum SolverMode {
CPU = 0;
GPU = 1;
}
optional SolverMode solver_mode = 17 [default = GPU]; // solver的模式,默认为GPU
optional int32 device_id = 18 [default = 0]; // GPU的ID
optional int64 random_seed = 20 [default = -1]; // 随机数种子
}