How to obtain the interval limits from a decision tree with scikit-learn?

Question

You can obtain an array of split points like the one requested as follows:

import numpy as np

# Text description of your tree's node splits
tree_structure = "|--- feature_0 <= 0.08\n|   |--- class: 0\n|--- feature_0 > 0.08\n|   |--- feature_0 <= 8.50\n|   |   |--- feature_0 <= 1.50\n|   |   |   |--- class: 1\n|   |   |--- feature_0 > 1.50\n|   |   |   |--- class: 1\n|   |--- feature_0 > 8.50\n|   |   |--- feature_0 <= 60.25\n|   |   |   |--- class: 0\n|   |   |--- feature_0 > 60.25\n|   |   |   |--- class: 0"

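# Extract the split limits from the tree description.
# Every threshold is printed twice (once as "<=" and once as ">"),
# so read only the "<=" lines and collect them in a set.
limits = set()
lines = tree_structure.split("\n")
for line in lines:
    if "<=" in line:
        limit = float(line.split()[-1])
        limits.add(limit)

# Sort the thresholds and add negative and positive infinity at the ends
limits = [-np.inf] + sorted(limits) + [np.inf]

# Print the result
print(limits)

This code extracts the split limits from the text description of your decision tree and adds negative and positive infinity as the first and last elements, which gives the split-point array you asked for. Note that export_text rounds thresholds to two decimals by default, so the first limit appears as 0.08 rather than its full value.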

Original question:

Say I am using the titanic dataset, with the variable age only:

import numpy as np
import pandas as pd

data = pd.read_csv('https://www.openml.org/data/get_csv/16826755/phpMYEkMl')[["age", "survived"]]
data = data.replace('?', np.nan)
data = data.fillna(0)
print(data)

the result:

         age  survived
0         29         1
1     0.9167         1
2          2         0
3         30         0
4         25         0
...      ...       ...
1304    14.5         0
1305       0         0
1306    26.5         0
1307      27         0
1308      29         0

[1309 rows x 2 columns]

Now I train a decision tree to predict survival from age:

from sklearn.tree import DecisionTreeClassifier
tree_model = DecisionTreeClassifier(max_depth=3)
tree_model.fit(data['age'].to_frame(), data["survived"])

And if I print the structure of the tree:

from sklearn import tree
print(tree.export_text(tree_model))

I obtain:

|--- feature_0 <= 0.08
|   |--- class: 0
|--- feature_0 >  0.08
|   |--- feature_0 <= 8.50
|   |   |--- feature_0 <= 1.50
|   |   |   |--- class: 1
|   |   |--- feature_0 >  1.50
|   |   |   |--- class: 1
|   |--- feature_0 >  8.50
|   |   |--- feature_0 <= 60.25
|   |   |   |--- class: 0
|   |   |--- feature_0 >  60.25
|   |   |   |--- class: 0

This means that the final intervals across the leaf nodes are:

0-0.08; 0.08-1.50; 1.50-8.50; 8.50-60; >60

My question is, how can I capture those limits in an array that looks like this:

[-np.inf, 0.08, 1.5, 8.5, 60, np.inf]

Thank you!

Answer 1

Score: 3

The decision tree classifier, in this case `tree_model`, has an attribute called `tree_` that gives access to low-level attributes.

``` python
print(tree_model.tree_.threshold)

array([ 0.08335, -2.     ,  8.5    ,  1.5    , -2.     , -2.     ,
       60.25   , -2.     , -2.     ])

print(tree_model.tree_.feature)

array([ 0, -2,  0,  0, -2, -2,  0, -2, -2], dtype=int64)
```

The `feature` and `threshold` arrays only apply to split nodes. The values for leaf nodes in these arrays are therefore arbitrary (they show up as the placeholder -2 in the output above).
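To make the split/leaf distinction concrete, here is a minimal sketch that uses plain Python lists copied from the output printed above, so it runs without a fitted model. The name `LEAF_MARKER` is made up for this illustration; with a real model you would read the arrays from `tree_model.tree_` instead.

```python
# Values copied from the printed arrays above (rounded as displayed).
threshold = [0.08335, -2.0, 8.5, 1.5, -2.0, -2.0, 60.25, -2.0, -2.0]
feature = [0, -2, 0, 0, -2, -2, 0, -2, -2]

# Assumption for this sketch: leaf positions carry the placeholder -2,
# as seen in the output above.
LEAF_MARKER = -2

# Node ids that actually split on a feature, and node ids that are leaves.
split_nodes = [i for i, f in enumerate(feature) if f != LEAF_MARKER]
leaf_nodes = [i for i, f in enumerate(feature) if f == LEAF_MARKER]

# Thresholds are only meaningful at the split nodes.
split_thresholds = [threshold[i] for i in split_nodes]

print(split_nodes)       # [0, 2, 3, 6]
print(leaf_nodes)        # [1, 4, 5, 7, 8]
print(split_thresholds)  # [0.08335, 8.5, 1.5, 60.25]
```

Four split nodes and five leaves match the tree printed in the question, which has four distinct thresholds and five `class:` lines.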

To get the split thresholds for a given feature, filter `threshold` with the `feature` array (feature index 0 here, since the model was trained on `age` only).

threshold = tree_model.tree_.threshold
feature = tree_model.tree_.feature
feature_threshold = threshold[feature == 0]
thresholds = sorted(feature_threshold)
print(thresholds)

[0.08335000276565552, 1.5, 8.5, 60.25]
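If the tree were trained on more than one feature, the same filtering could be done once per feature. The sketch below is one possible way to do that in plain Python; the helper name `limits_per_feature` and its `leaf_marker` parameter are hypothetical, and the input lists are copied from the printed arrays above rather than read from a live model.

```python
import math
from collections import defaultdict


def limits_per_feature(feature, threshold, leaf_marker=-2):
    """Group split thresholds by feature index and pad them with +/- infinity.

    Hypothetical helper: `feature` and `threshold` are parallel sequences,
    and positions where `feature` equals `leaf_marker` are skipped as leaves.
    """
    grouped = defaultdict(set)  # a set drops thresholds repeated across branches
    for f, t in zip(feature, threshold):
        if f != leaf_marker:
            grouped[f].add(t)
    return {f: [-math.inf] + sorted(ts) + [math.inf] for f, ts in grouped.items()}


# Values copied from the printed arrays above (rounded as displayed).
threshold = [0.08335, -2.0, 8.5, 1.5, -2.0, -2.0, 60.25, -2.0, -2.0]
feature = [0, -2, 0, 0, -2, -2, 0, -2, -2]

result = limits_per_feature(feature, threshold)
print(result)  # {0: [-inf, 0.08335, 1.5, 8.5, 60.25, inf]}
```

With a fitted model, the same helper should accept `tree_model.tree_.feature` and `tree_model.tree_.threshold` directly, since it only iterates over them.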

To include `np.inf` at both ends, you need to add it yourself.

thresholds = [-np.inf] + thresholds + [np.inf]
print(thresholds)

[-inf, 0.08335000276565552, 1.5, 8.5, 60.25, inf]
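Once the limits are in a sorted list, they can be used to look up which interval an age falls into. The following is a rough sketch using only the standard library; `interval_index` is a made-up name, and `math.inf` stands in for `np.inf`. Because the tree sends `feature <= threshold` to the left branch, each interval is treated as open on the left and closed on the right.

```python
import bisect
import math

# Limits as printed above, with math.inf in place of np.inf.
limits = [-math.inf, 0.08335000276565552, 1.5, 8.5, 60.25, math.inf]


def interval_index(value, limits):
    """Return i such that limits[i] < value <= limits[i + 1]."""
    return bisect.bisect_left(limits, value) - 1


for age in [0, 0.9167, 1.5, 29, 70]:
    i = interval_index(age, limits)
    print(age, "->", (limits[i], limits[i + 1]))
```

For whole columns, `np.digitize(ages, limits, right=True) - 1` should produce the same indices, since `right=True` also uses right-closed bins.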

Reference: Understanding the decision tree structure.

huangapple
  • Posted on March 7, 2023, 22:47:26
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75663472.html