network

2017-05-04  本文已影响214人  陈继科

与 layer 一样，network 也是该深度学习框架中的一个重要数据结构。

network 结构体的主要成员如下（类型、名称、意义）：

/*
 * Members of the network struct (the surrounding struct declaration is not
 * shown in this excerpt). Comments are based on forward_network,
 * backward_network and update_network shown further down; entries marked
 * "presumably" are guesses from the field name and still need confirming.
 */
int n;// number of layers in `layers`; forward/backward_network loop over net.n layers
int batch;// multiplied by `subdivisions` in update_network to get the batch size passed to update()
int *seen;
float epoch;// presumably the current training epoch — TODO confirm
int subdivisions;// see `batch`: update_network passes batch * subdivisions to each layer's update()
float momentum;// forwarded unchanged to each layer's update() by update_network
float decay;// forwarded unchanged to each layer's update() by update_network
layer *layers;// the net.n layers: run first-to-last in forward_network, last-to-first in backward_network
float *output;
learning_rate_policy policy;// presumably selects the learning-rate schedule used by get_current_rate — confirm

float learning_rate;// presumably the base learning rate for that schedule — confirm
float gamma;// presumably a learning-rate schedule parameter (see policy) — confirm
float scale;// presumably a learning-rate schedule parameter (see policy) — confirm
float power;// presumably a learning-rate schedule parameter (see policy) — confirm
int time_steps;
int step;// presumably a learning-rate schedule parameter (see policy) — confirm
int max_batches;
float *scales;
int   *steps;
int num_steps;
int burn_in;

int adam;// presumably a flag enabling the Adam optimizer together with B1, B2, eps — confirm
float B1;
float B2;
float eps;

int inputs;
int outputs;
int truths;
int notruth;
int h, w, c;
int max_crop;
int min_crop;
int center;
float angle;
float aspect;
float exposure;
float saturation;
float hue;

int gpu_index;
tree *hierarchy;



float *input;// input of the layer being run: the previous layer's output (the original input for layer 0)
float *truth;// forward_network sets this to the output of any layer whose `truth` flag is set
float *delta;// in backward_network: layer k-1's delta (the original value for layer 0)
float *workspace;
int train;
int index;// index of the layer currently running; set by forward_network and backward_network
float *cost;// NOTE(review): presumably filled in by calc_network_cost — confirm

#ifdef GPU
/* Only present when GPU is defined; presumably GPU-side counterparts of the fields above. */
float *input_gpu;
float *truth_gpu;
float *delta_gpu;
float *output_gpu;
#endif

围绕 network 还提供了许多重要函数：
/*
 * Functions declared for the network type. Only forward_network,
 * backward_network and update_network have their definitions shown in this
 * article; the rest are listed as declarations only.
 */
/* Called by update_network; its result is multiplied by each layer's
 * learning_rate_scale before being passed to that layer's update(). */
float get_current_rate(network net);
int get_current_batch(network net);
void free_network(network net);
void compare_networks(network n1, network n2, data d);
char *get_layer_string(LAYER_TYPE a);

network make_network(int n);
/* Forward pass, backward pass and parameter update (definitions follow). */
void forward_network(network net);
void backward_network(network net);
void update_network(network net);

float train_network(network net, data d);
float train_network_sgd(network net, data d, int n);
float train_network_datum(network net);

matrix network_predict_data(network net, data test);
float *network_predict(network net, float *input);
float network_accuracy(network net, data d);
float *network_accuracies(network net, data d, int n);
float network_accuracy_multi(network net, data d, int n);
void top_predictions(network net, int n, int *index);
image get_network_image(network net);
image get_network_image_layer(network net, int i);
layer get_network_output_layer(network net);
int get_predicted_class_network(network net);
void print_network(network net);
void visualize_network(network net);
/* Unlike most functions here these take network *, presumably so they can
 * modify the caller's network in place — confirm. */
int resize_network(network *net, int w, int h);
void set_batch_network(network *net, int b);
network load_network(char *cfg, char *weights, int clear);
load_args get_base_args(network net);
/* Called by forward_network after the last layer has run. */
void calc_network_cost(network net);
下面先重点看其中三个函数，这有助于理解整体代码。

## forward_network(...)

前向传播：按顺序调用每一层的 forward，并把上一层的输出作为下一层的输入。

/*
 * Forward pass: run every layer from first to last, feeding each layer's
 * output in as the next layer's input, then compute the network cost.
 *
 * `net` is passed by value, so the assignments to net.index, net.input and
 * net.truth below only change this local copy -- which is exactly the copy
 * each layer's forward() receives.
 */
void forward_network(network net)
{
    for (int idx = 0; idx < net.n; ++idx) {
        layer cur = net.layers[idx];
        net.index = idx;

        /* Clear this layer's delta buffer (if it has one) before running it. */
        if (cur.delta) {
            fill_cpu(cur.outputs * cur.batch, 0, cur.delta, 1);
        }

        cur.forward(cur, net);

        /* The output of this layer is the input of the next one. */
        net.input = cur.output;

        /* A layer with its `truth` flag set also becomes net.truth for the
         * layers that follow. */
        if (cur.truth) {
            net.truth = cur.output;
        }
    }
    calc_network_cost(net);
}

## backward_network(...)

反向传播（BP），计算梯度：从最后一层开始向前依次调用每一层的 backward。

/*
 * Backward pass: walk the layers from last to first and call each layer's
 * backward(). The walk stops as soon as it reaches a layer whose
 * `stopbackward` flag is set; that layer and all earlier ones are skipped.
 *
 * For layer k > 0, net.input and net.delta are pointed at layer k-1's output
 * and delta. For layer 0 they are restored from the network exactly as it was
 * passed in.
 */
void backward_network(network net)
{
    network saved = net;

    for (int k = net.n - 1; k >= 0; --k) {
        layer cur = net.layers[k];
        if (cur.stopbackward) {
            break;
        }

        if (k > 0) {
            layer below = net.layers[k - 1];
            net.input = below.output;
            net.delta = below.delta;
        } else {
            /* First layer: fall back to the caller's original input/delta. */
            net = saved;
        }

        net.index = k;
        cur.backward(cur, net);
    }
}

## update_network(...)

更新参数（parameters）：用当前学习率依次调用每一层的 update。

/*
 * Parameter update: look up the current learning rate once, then call
 * update() on every layer that provides one. Each layer receives the rate
 * scaled by its own learning_rate_scale, along with the network-wide
 * momentum and decay and a batch size of batch * subdivisions.
 */
void update_network(network net)
{
    const float rate = get_current_rate(net);
    const int update_batch = net.batch * net.subdivisions;

    for (int idx = 0; idx < net.n; ++idx) {
        layer cur = net.layers[idx];
        if (!cur.update) {
            continue;   /* layer has no update() to call */
        }
        cur.update(cur, update_batch, rate * cur.learning_rate_scale,
                   net.momentum, net.decay);
    }
}
上一篇下一篇

猜你喜欢

热点阅读