gcforest的官方代码详解

2018-05-14  本文已影响0人  wendy_要努力努力再努力

本文采用的是v1.1版本,github地址https://github.com/kingfengji/gcForest
代码主要分为两部分:examples文件夹下是主代码.py和配置文件.json;libs文件夹下是代码中用到的库

主代码的实现

from gcforest.gcforest import GCForest
gc = GCForest(config) # should be a dict
X_train_enc = gc.fit_transform(X_train, y_train)
y_pred = gc.predict(X_test)

lib库的详解

gcforest.py 整个框架的实现
fgnet.py 多粒度部分,FineGrained的实现
cascade/cascade_classifier 级联分类器的实现
datasets/.... 包含一系列数据集的定义
estimator/... 包含决策树在进行评估用到的函数(多种分类器的预估)
layer/... 包含不同的层操作,如连接、池化、滑窗等
utils/.. 包含各种功能函数,譬如计算准确率、win_vote、win_avg、get_windows等

json配置文件的详解

参数介绍


训练的配置,分三类情况:

  1. 采用默认的模型
def get_toy_config():
    config = {}
    ca_config = {}
    ca_config["random_state"] = 0  # 0 or 1
    ca_config["max_layers"] = 100  #最大的层数,layer对应论文中的level
    ca_config["early_stopping_rounds"] = 3  #如果出现某层的三层以内的准确率都没有提升,层中止
    ca_config["n_classes"] = 3      #判别的类别数量
    ca_config["estimators"] = []  
    ca_config["estimators"].append(
            {"n_folds": 5, "type": "XGBClassifier", "n_estimators": 10, "max_depth": 5,
             "objective": "multi:softprob", "silent": True, "nthread": -1, "learning_rate": 0.1} )
    ca_config["estimators"].append({"n_folds": 5, "type": "RandomForestClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})
    ca_config["estimators"].append({"n_folds": 5, "type": "ExtraTreesClassifier", "n_estimators": 10, "max_depth": None, "n_jobs": -1})
    ca_config["estimators"].append({"n_folds": 5, "type": "LogisticRegression"})
    config["cascade"] = ca_config    #共使用了四个基学习器
    return config

支持的基本分类器:
RandomForestClassifier
XGBClassifier
ExtraTreesClassifier
LogisticRegression
SGDClassifier

你可以通过下述方式手动添加任何分类器:

lib/gcforest/estimators/__init__.py
  1. 只有级联(cascade)部分
{
"cascade": {
    "random_state": 0,
    "max_layers": 100,
    "early_stopping_rounds": 3,
    "n_classes": 10,
    "estimators": [
        {"n_folds":5,"type":"XGBClassifier","n_estimators":10,"max_depth":5,"objective":"multi:softprob", "silent":true, "nthread":-1, "learning_rate":0.1},
        {"n_folds":5,"type":"RandomForestClassifier","n_estimators":10,"max_depth":null,"n_jobs":-1},
        {"n_folds":5,"type":"ExtraTreesClassifier","n_estimators":10,"max_depth":null,"n_jobs":-1},
        {"n_folds":5,"type":"LogisticRegression"}
    ]
}
}
  1. “multi fine-grained + cascade” 两部分
    滑动窗口的大小: {[d/16], [d/8], [d/4]},d代表输入特征的数量;
    "look_indexs_cycle": [
    [0, 1],
    [2, 3],
    [4, 5]]
    代表级联多粒度的方式,第一层级联0、1森林的输出,第二层级联2、3森林的输出,第三层级联4、5森林的输出
{
"net":{
"outputs": ["pool1/7x7/ets", "pool1/7x7/rf", "pool1/10x10/ets", "pool1/10x10/rf", "pool1/13x13/ets", "pool1/13x13/rf"],
"layers":[
// win1/7x7
    {
        "type":"FGWinLayer",
        "name":"win1/7x7",
        "bottoms": ["X","y"],
        "tops":["win1/7x7/ets", "win1/7x7/rf"],
        "n_classes": 10,
        "estimators": [
            {"n_folds":3,"type":"ExtraTreesClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10},
            {"n_folds":3,"type":"RandomForestClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10}
        ],
        "stride_x": 2,
        "stride_y": 2,
        "win_x":7,
        "win_y":7
    },
// win1/10x10
    {
        "type":"FGWinLayer",
        "name":"win1/10x10",
        "bottoms": ["X","y"],
        "tops":["win1/10x10/ets", "win1/10x10/rf"],
        "n_classes": 10,
        "estimators": [
            {"n_folds":3,"type":"ExtraTreesClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10},
            {"n_folds":3,"type":"RandomForestClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10}
        ],
        "stride_x": 2,
        "stride_y": 2,
        "win_x":10,
        "win_y":10
    },
// win1/13x13
    {
        "type":"FGWinLayer",
        "name":"win1/13x13",
        "bottoms": ["X","y"],
        "tops":["win1/13x13/ets", "win1/13x13/rf"],
        "n_classes": 10,
        "estimators": [
            {"n_folds":3,"type":"ExtraTreesClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10},
            {"n_folds":3,"type":"RandomForestClassifier","n_estimators":20,"max_depth":10,"n_jobs":-1,"min_samples_leaf":10}
        ],
        "stride_x": 2,
        "stride_y": 2,
        "win_x":13,
        "win_y":13
    },
// pool1
    {
        "type":"FGPoolLayer",
        "name":"pool1",
        "bottoms": ["win1/7x7/ets", "win1/7x7/rf", "win1/10x10/ets", "win1/10x10/rf", "win1/13x13/ets", "win1/13x13/rf"],
        "tops": ["pool1/7x7/ets", "pool1/7x7/rf", "pool1/10x10/ets", "pool1/10x10/rf", "pool1/13x13/ets", "pool1/13x13/rf"],
        "pool_method": "avg",
        "win_x":2,
        "win_y":2
    }
]

},

"cascade": {
    "random_state": 0,
    "max_layers": 100,
    "early_stopping_rounds": 3,
    "look_indexs_cycle": [
        [0, 1],
        [2, 3],
        [4, 5]
    ],
    "n_classes": 10,
    "estimators": [
        {"n_folds":5,"type":"ExtraTreesClassifier","n_estimators":1000,"max_depth":null,"n_jobs":-1},
        {"n_folds":5,"type":"RandomForestClassifier","n_estimators":1000,"max_depth":null,"n_jobs":-1}
    ]
}
}
上一篇下一篇

猜你喜欢

热点阅读