How XGBoost computes feature importance

Overview

Notes on xgboost's feature importance. The official documentation does not describe the calculation in much detail, and searching did not turn up much either, so this is a memo.

Environment

macOS Mojave 10.14.5
python 3.7.4
xgboost 0.90

Types of feature importance

Feature importance calculations fall into three broad categories (a sketch that retrieves each type follows the three definitions below).

weight

The number of times the feature is used as a split variable, counted across all trees. Easy to understand.

gain

The improvement in the objective function obtained from splits on the feature, across all trees. Hard to understand without studying the theory behind xgboost.

cover

The sum of the second-order gradients (Hessians) of the objective function at splits on the feature, across all trees. Also hard to understand without knowing how xgboost works internally. However, when the objective is squared error the second-order gradient is 1 for every sample, so the sum of second-order gradients equals the number of samples.
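
For reference, a minimal sketch that retrieves each type via get_score (the dataset and model here are hypothetical placeholders):

import xgboost as xgb
from sklearn.datasets import make_regression

# Hypothetical toy data; any regression dataset will do.
X, y = make_regression(n_samples=500, n_features=5, random_state=0)

model = xgb.XGBRegressor(objective='reg:squarederror', n_estimators=10)
model.fit(X, y)

# get_score also accepts the 'total_*' variants described later.
booster = model.get_booster()
for imp_type in ['weight', 'gain', 'cover', 'total_gain', 'total_cover']:
    print(imp_type, booster.get_score(importance_type=imp_type))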

How the values are computed

We can confirm the calculation by tracing through the xgboost source code.

The feature_importances_ property

feature_importances_ is defined in python/site-packages/xgboost/sklearn.py (as a property, not a method). It normalizes the values returned by get_score and returns them as fractions that sum to 1.

@property
def feature_importances_(self):
    """
    Feature importances property
                                                                                          
    .. note:: Feature importance is defined only for tree boosters
                                                                                          
        Feature importance is only defined when the decision tree model is chosen as base
        learner (`booster=gbtree`). It is not defined for other base learner types, such
        as linear learners (`booster=gblinear`).
                                                                                          
    Returns
    -------
    feature_importances_ : array of shape ``[n_features]``
                                                                                          
    """
    if getattr(self, 'booster', None) is not None and self.booster != 'gbtree':
        raise AttributeError('Feature importance is not defined for Booster type {}'
                             .format(self.booster))
    b = self.get_booster()
    score = b.get_score(importance_type=self.importance_type)
    all_features = [score.get(f, 0.) for f in b.feature_names]
    all_features = np.array(all_features, dtype=np.float32)
    return all_features / all_features.sum()
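In other words, the property just normalizes get_score so that the values sum to 1. A quick sketch to confirm, reusing the hypothetical model from the earlier example:

import numpy as np

booster = model.get_booster()
score = booster.get_score(importance_type=model.importance_type)
raw = np.array([score.get(f, 0.) for f in booster.feature_names],
               dtype=np.float32)

# Both lines should print the same normalized values.
print(raw / raw.sum())
print(model.feature_importances_)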

The get_score function

get_score is defined in python/site-packages/xgboost/core.py. It walks every generated tree in the dump and collects the statistics.

def get_score(self, fmap='', importance_type='weight'):
    """Get feature importance of each feature.
    Importance type can be defined as:
                                                                                              
    * 'weight': the number of times a feature is used to split the data across all trees.
    * 'gain': the average gain across all splits the feature is used in.
    * 'cover': the average coverage across all splits the feature is used in.
    * 'total_gain': the total gain across all splits the feature is used in.
    * 'total_cover': the total coverage across all splits the feature is used in.
                                                                                              
    .. note:: Feature importance is defined only for tree boosters
                                                                                              
        Feature importance is only defined when the decision tree model is chosen as base
        learner (`booster=gbtree`). It is not defined for other base learner types, such
        as linear learners (`booster=gblinear`).
                                                                                              
    Parameters
    ----------
    fmap: str (optional)
       The name of feature map file.
    importance_type: str, default 'weight'
        One of the importance types defined above.
    """
    if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
        raise ValueError('Feature importance is not defined for Booster type {}'
                         .format(self.booster))
                                                                                              
    allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover']
    if importance_type not in allowed_importance_types:
        msg = ("importance_type mismatch, got '{}', expected one of " +
               repr(allowed_importance_types))
        raise ValueError(msg.format(importance_type))
                                                                                              
    # if it's weight, then omap stores the number of missing values
    if importance_type == 'weight':
        # do a simpler tree dump to save time
        trees = self.get_dump(fmap, with_stats=False)
                                                                                              
        fmap = {}
        for tree in trees:
            for line in tree.split('\n'):
                # look for the opening square bracket
                arr = line.split('[')
                # if no opening bracket (leaf node), ignore this line
                if len(arr) == 1:
                    continue
                                                                                              
                # extract feature name from string between []
                fid = arr[1].split(']')[0].split('<')[0]
                                                                                              
                if fid not in fmap:
                    # if the feature hasn't been seen yet
                    fmap[fid] = 1
                else:
                    fmap[fid] += 1
                                                                                              
        return fmap
                                                                                              
    average_over_splits = True
    if importance_type == 'total_gain':
        importance_type = 'gain'
        average_over_splits = False
    elif importance_type == 'total_cover':
        importance_type = 'cover'
        average_over_splits = False
                                                                                              
    trees = self.get_dump(fmap, with_stats=True)
                                                                                              
    importance_type += '='
    fmap = {}
    gmap = {}
    for tree in trees:
        for line in tree.split('\n'):
            # look for the opening square bracket
            arr = line.split('[')
            # if no opening bracket (leaf node), ignore this line
            if len(arr) == 1:
                continue
                                                                                              
            # look for the closing bracket, extract only info within that bracket
            fid = arr[1].split(']')
                                                                                              
            # extract gain or cover from string after closing bracket
            g = float(fid[1].split(importance_type)[1].split(',')[0])
                                                                                              
            # extract feature name from string before closing bracket
            fid = fid[0].split('<')[0]
                                                                                              
            if fid not in fmap:
                # if the feature hasn't been seen yet
                fmap[fid] = 1
                gmap[fid] = g
            else:
                fmap[fid] += 1
                gmap[fid] += g
                                                                                              
    # calculate average value (gain/cover) for each feature
    if average_over_splits:
        for fid in gmap:
            gmap[fid] = gmap[fid] / fmap[fid]
                                                                                              
    return gmap
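As the code shows, the statistics are recovered by plain string manipulation on the tree dump. A standalone sketch of how a single node line (the first line of the example dump below) is decomposed:

# Leaf lines contain no '[' and are skipped by get_score.
line = '0:[f5<0.755322099] yes=1,no=2,missing=1,gain=2173.84375,cover=2977'

inside, after = line.split('[')[1].split(']')

fid = inside.split('<')[0]                             # 'f5'
gain = float(after.split('gain=')[1].split(',')[0])    # 2173.84375
cover = float(after.split('cover=')[1].split(',')[0])  # 2977.0
print(fid, gain, cover)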

Inspecting the generated trees

To see the actual definition of a generated tree, do the following:

trees = best_model.get_booster().get_dump('', with_stats=True)
print(trees[0])

Example output

0:[f5<0.755322099] yes=1,no=2,missing=1,gain=2173.84375,cover=2977
    1:[f8<-0.585839391] yes=3,no=4,missing=3,gain=47.46875,cover=2246
        3:[f3<0.0519423634] yes=7,no=8,missing=7,gain=94.8320312,cover=345
            7:[f23<-3.19843745] yes=13,no=14,missing=13,gain=39.0761719,cover=81
                13:leaf=-0.0423907638,cover=9
                14:[f3<-0.00912035629] yes=21,no=22,missing=21,gain=1.49291992,cover=72
                    21:leaf=-0.0760522038,cover=66
                    22:leaf=-0.0463605523,cover=6
            8:leaf=-0.0875011981,cover=264
        4:leaf=-0.0899999961,cover=1901
    2:[f5<1.22153115] yes=5,no=6,missing=5,gain=144.21875,cover=731
        5:[f23<0.321514517] yes=9,no=10,missing=9,gain=74.71875,cover=531
            9:[f20<-1.14524913] yes=15,no=16,missing=15,gain=39.2675781,cover=222
                15:leaf=-0.0813433379,cover=46
                16:[f31<2.09428215] yes=23,no=24,missing=23,gain=22.152832,cover=176
                    23:[f35<-0.367280573] yes=27,no=28,missing=27,gain=12.0170898,cover=169
                        27:[f20<-0.726363778] yes=29,no=30,missing=29,gain=9.59008789,cover=68
                            29:leaf=-0.0391683429,cover=9
                            30:leaf=-0.0619690083,cover=59
                        28:leaf=-0.0712869242,cover=101
                    24:leaf=-0.0376744941,cover=7
            10:[f3<-0.513715267] yes=17,no=18,missing=17,gain=50.3417969,cover=309
                17:[f24<-1.03701448] yes=25,no=26,missing=25,gain=1.18652344,cover=142
                    25:leaf=-0.0647445917,cover=56
                    26:leaf=-0.0772461221,cover=86
                18:leaf=-0.084925279,cover=167
        6:[f21<-0.358709186] yes=11,no=12,missing=11,gain=58.1689453,cover=200
            11:leaf=-0.0578844398,cover=120
            12:[f8<0.016414443] yes=19,no=20,missing=19,gain=9.87451172,cover=80
                19:leaf=-0.0619722642,cover=37
                20:leaf=-0.0793117732,cover=43

get_score walks strings like this to compute gain and cover. Since the objective here is squared error, the cover values above equal the number of samples in each node. As the dump also makes clear, the tree definitions do not store the feature names themselves, only indices such as f5.
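
If you want readable names in the dump, one option is to pass a feature map file to get_dump. A minimal sketch, assuming the real column names are at hand (the names here are hypothetical; the feature map format is one "index name type" line per feature, with q marking a quantitative feature):

# Hypothetical column names, in the same order as the training matrix.
feature_names = ['age', 'income', 'height']

# Write the feature map: one 'index name type' line per feature.
with open('fmap.txt', 'w') as f:
    for i, name in enumerate(feature_names):
        f.write('{} {} q\n'.format(i, name))

trees = best_model.get_booster().get_dump('fmap.txt', with_stats=True)
print(trees[0])  # splits now show the real names instead of f5, f8, ...

Training through the sklearn API on a pandas DataFrame should also attach the column names to the booster as feature_names.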