重写lore estimator
2018-08-12 本文已影响15人
Helen_Cat
import inspect
import logging
import warnings
import threading
import lore.env
from lore.util import timed
# Check up front that the estimator's third-party dependency is installed.
# The original call passed no arguments, so it checked nothing.
# NOTE(review): lore's own estimator modules pass lore.dependencies.* specs
# here (the xgboost module uses XGBOOST + SKLEARN); confirm SKLEARN is the
# right constant for this lore version.
lore.env.require(
    lore.dependencies.SKLEARN
)
from lore.estimators.xgboost import Base
from sklearn.ensemble import RandomForestRegressor
# Module-scoped logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class RFRegression(Base, RandomForestRegressor):
    """Random-forest regressor combined with lore's estimator ``Base``.

    The constructor signature is kept for backward compatibility. Several
    parameters are accepted only so existing callers keep working; they are
    NOT forwarded to the forest:

    * ``base_estimator`` / ``estimator_params``: ``RandomForestRegressor``
      builds its own ``DecisionTreeRegressor`` trees and its constructor does
      not accept these arguments. A warning is emitted if either is set.
    * ``class_weight``: classification-only and not a
      ``RandomForestRegressor`` argument. A warning is emitted if it is set.

    Other behaviour:

    * ``min_sample_leaf`` (sic) is forwarded as sklearn's ``min_samples_leaf``.
      A ``min_samples_leaf=...`` keyword given through ``**kwargs`` takes
      precedence.
    * The xgboost-style aliases ``seed`` and ``nthread`` in ``**kwargs`` are
      mapped to ``random_state`` / ``n_jobs`` when those are left at their
      defaults (``None`` / ``1``).
    * Any other extra keyword is forwarded through ``super()``. ``Base`` comes
      first in the MRO, so it sees them before ``RandomForestRegressor``.
    * ``criterion`` now defaults to ``'mse'``. The previous default, ``'gini'``,
      is a classification criterion and not valid for a regressor.

    NOTE(review): ``'mse'`` and ``max_features="auto"`` follow the
    scikit-learn 0.19/0.20 API. Later releases renamed or removed these
    values, so confirm the pinned sklearn version.

    NOTE(review): with ``bootstrap=False`` and ``max_features="auto"`` (all
    features for regressors, per the sklearn docs), every tree is grown on
    the same data with the same candidate features, so the trees are
    near-identical. Consider ``bootstrap=True``.

    NOTE(review): ``Base`` comes from ``lore.estimators.xgboost``. If its
    ``fit`` passes xgboost-only arguments (e.g. ``eval_set``) to the parent
    ``fit``, it will not work with ``RandomForestRegressor.fit``. Verify
    against ``Base``.
    """

    def __init__(self,
                 base_estimator=None,
                 n_estimators=10,
                 estimator_params=tuple(),
                 bootstrap=False,
                 oob_score=False,
                 n_jobs=1,
                 random_state=None,
                 verbose=0,
                 warm_start=False,
                 criterion='mse',
                 max_depth=None,
                 min_samples_split=2,
                 min_sample_leaf=1,
                 min_weight_fraction_leaf=0,
                 max_features="auto",
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 class_weight=None,
                 **kwargs):
        # xgboost-style aliases. The previous locals()-based remapping could
        # never fire, because random_state / n_jobs are named parameters and
        # so always appear in locals(). Honour an alias only when the
        # canonical argument was left at its default.
        seed = kwargs.pop('seed', None)
        if random_state is None and seed is not None:
            random_state = seed
        nthread = kwargs.pop('nthread', None)
        if n_jobs == 1 and nthread is not None:
            n_jobs = nthread

        # Let the correctly spelled sklearn name override the misspelled
        # public parameter.
        min_samples_leaf = kwargs.pop('min_samples_leaf', min_sample_leaf)

        if base_estimator is not None or estimator_params:
            warnings.warn(
                'RFRegression ignores base_estimator/estimator_params: '
                'RandomForestRegressor always grows DecisionTreeRegressor '
                'trees.',
                stacklevel=2)
        if class_weight is not None:
            warnings.warn(
                'RFRegression ignores class_weight: it is not supported by '
                'RandomForestRegressor.',
                stacklevel=2)

        # Pass keyword arguments only, and only names the forest accepts. The
        # original passed base_estimator positionally and forwarded
        # base_estimator, estimator_params, class_weight and the misspelled
        # min_sample_leaf, which made construction raise TypeError.
        super(RFRegression, self).__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            **kwargs)

        # sklearn's get_params() introspects this __init__ signature and
        # reads attributes of the same names. Store the signature-only
        # parameters so they are found. The effective leaf size is stored so
        # a clone reproduces the fitted configuration.
        # NOTE(review): base_estimator / estimator_params are intentionally
        # left as set by the forest, because its fit relies on them; clone()
        # may therefore still reject this estimator. Confirm if cloning is
        # needed.
        self.min_sample_leaf = min_samples_leaf
        self.class_weight = class_weight