¿Cómo extraer las reglas de decisión de scikit-learn decision-tree?


157

¿Puedo extraer las reglas de decisión subyacentes (o 'rutas de decisión') de un árbol entrenado en un árbol de decisión como una lista de texto?

Algo como:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Gracias por tu ayuda.



¿Alguna vez encontró una respuesta a este problema? Tengo que exportar las reglas del árbol de decisiones en un formato de paso de datos SAS que es casi exactamente como lo tiene en la lista.
Zelazny7

1
Puede usar el paquete sklearn-porter para exportar y transpilar árboles de decisión (también bosque aleatorio y árboles potenciados) a C, Java, JavaScript y otros.
Darius

Respuestas:


139

Creo que esta respuesta es más correcta que las otras respuestas aquí:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Esto imprime una función Python válida. Aquí hay un ejemplo de salida para un árbol que está tratando de devolver su entrada, un número entre 0 y 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Aquí hay algunos obstáculos que veo en otras respuestas:

  1. Usar tree_.threshold == -2para decidir si un nodo es una hoja no es una buena idea. ¿Qué pasa si es un nodo de decisión real con un umbral de -2? En cambio, debe mirar tree.featureo tree.children_*.
  2. La línea se features = [feature_names[i] for i in tree_.feature]bloquea con mi versión de sklearn, porque algunos valores de tree.tree_.featureson -2 (específicamente para los nodos hoja).
  3. No es necesario tener múltiples sentencias if en la función recursiva, solo una está bien.

1
Este código funciona muy bien para mí. Sin embargo, tengo más de 500 nombres de funciones, por lo que el código de salida es casi imposible de entender para un humano. ¿Hay alguna forma de permitirme ingresar solo los feature_names que me interesan en la función?
user3768495

1
Estoy de acuerdo con el comentario anterior. IIUC, print "{}return {}".format(indent, tree_.value[node])debe cambiarse a print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))para que la función devuelva el índice de clase.
soupault

1
@paulkernfeld Ah, sí, veo que puedes hacer un bucle RandomForestClassifier.estimators_, pero no pude averiguar cómo combinar los resultados de los estimadores.
Nathan Lloyd el

66
No pude hacer que esto funcionara en python 3, los bits _tree no parecen funcionar nunca y TREE_UNDEFINED no estaba definido. Este enlace me ayudó. Si bien el código exportado no se puede ejecutar directamente en Python, es similar a C y bastante fácil de traducir a otros idiomas: web.archive.org/web/20171005203850/http://www.kdnuggets.com/…
Josiah

1
@Josiah, agregue () a las declaraciones de impresión para que funcione en python3. ej. print "bla"=>print("bla")
Nir

48

Creé mi propia función para extraer las reglas de los árboles de decisión creados por sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Esta función comienza primero con los nodos (identificados por -1 en las matrices secundarias) y luego encuentra recursivamente a los padres. A esto lo llamo el "linaje" de un nodo. En el camino, tomo los valores que necesito para crear la lógica SAS if / then / else:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Los conjuntos de tuplas a continuación contienen todo lo que necesito para crear sentencias SAS if / then / else. No me gusta usar dobloques en SAS, por eso creo lógica que describe la ruta completa de un nodo. El número entero único después de las tuplas es la ID del nodo terminal en una ruta. Todas las tuplas anteriores se combinan para crear ese nodo.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Salida GraphViz del árbol de ejemplo


Es este tipo de árbol es correcto, ya que col1 está llegando de nuevo uno es col1 <= 0.50000 y uno col1 <= 2.5000 en caso afirmativo, ¿es cualquier tipo de recursividad instalacciones se utiliza en la biblioteca
jayant Singh

la rama derecha tendría registros en medio (0.5, 2.5]. Los árboles están hechos con particiones recursivas. No hay nada que impida que una variable se seleccione varias veces.
Zelazny7

bien puede explicar la parte recursividad lo que sucede xactly porque yo he utilizado en mi código y el resultado similar se observa
jayant Singh

38

Modifiqué el código enviado por Zelazny7 para imprimir un pseudocódigo:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

si llama get_code(dt, df.columns)al mismo ejemplo obtendrá:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

1
¿Puede decir qué significa exactamente [[1. 0.]] en la declaración de retorno en la salida anterior No soy un tipo de Python, pero estoy trabajando en el mismo tipo de cosas. Por lo tanto, será bueno para mí si prueba algunos detalles para que sea más fácil para mí.
Subhradip Bose

1
@ user3156186 Significa que hay un objeto en la clase '0' y cero objetos en la clase '1'
Daniele

1
@Daniele, ¿sabes cómo se ordenan las clases? Supongo que es alfanumérico, pero no he encontrado confirmación en ningún lado.
IanS

¡Gracias! Para el escenario de caso límite donde el valor umbral es en realidad -2, es posible que necesitemos cambiar (threshold[node] != -2)a ( left[node] != -1)(similar al método siguiente para obtener identificadores de nodos secundarios)
tlingf

@Daniele, ¿alguna idea de cómo hacer que su función "get_code" "devuelva" un valor y no "imprimirlo", porque necesito enviarlo a otra función?
RoyaumeIX

17

Scikit Learn introdujo un nuevo método delicioso llamado export_texten la versión 0.21 (mayo de 2019) para extraer las reglas de un árbol. Documentación aquí . Ya no es necesario crear una función personalizada.

Una vez que haya ajustado su modelo, solo necesita dos líneas de código. Primero, importa export_text:

from sklearn.tree.export import export_text

Segundo, crea un objeto que contendrá tus reglas. Para que las reglas se vean más legibles, use el feature_namesargumento y pase una lista de los nombres de sus características. Por ejemplo, si se llama a su modelo modely sus características se nombran en un marco de datos llamado X_train, puede crear un objeto llamado tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Luego simplemente imprima o guarde tree_rules. Su salida se verá así:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1

14

Hay un nuevo DecisionTreeClassifiermétodo decision_path, en el 0.18.0 versión . Los desarrolladores proporcionan un tutorial extenso (bien documentado) .

La primera sección de código en el tutorial que imprime la estructura de árbol parece estar bien. Sin embargo, modifiqué el código en la segunda sección para interrogar una muestra. Mis cambios denotados con# <--

Editar Los cambios marcados # <--en el código a continuación se han actualizado en el enlace de recorrido después de que se señalaron los errores en las solicitudes de extracción # 8653 y # 10951 . Es mucho más fácil seguirlo ahora.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Cambie sample_idpara ver las rutas de decisión para otras muestras. No he preguntado a los desarrolladores acerca de estos cambios, simplemente me pareció más intuitivo al trabajar con el ejemplo.


tu mi amigo eres una leyenda! ¿Alguna idea de cómo trazar el árbol de decisión para esa muestra específica? se agradece mucha ayuda

1
Gracias Victor, probablemente sea mejor hacer esto como una pregunta separada ya que los requisitos de trazado pueden ser específicos para las necesidades de un usuario. Probablemente obtendrá una buena respuesta si proporciona una idea de cómo desea que se vea la salida.
Kevin

Hola Kevin, creé la pregunta stackoverflow.com/questions/48888893/…

sería tan amable de echar un vistazo a: stackoverflow.com/questions/52654280/…
Alexander Chervov

¿Puede explicar la parte llamada node_index, sin obtener esa parte? ¿Qué hace?
Anindya Sankar Dey

12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Puedes ver un árbol de dígrafo. Entonces, clf.tree_.featurey clf.tree_.valueson una matriz de nodos que dividen la función y la matriz de valores de nodos respectivamente. Puede consultar más detalles de esta fuente de github .


1
Sí, sé cómo dibujar el árbol, pero necesito la versión más textual: las reglas. algo como: orange.biolab.si/docs/latest/reference/rst/…
Dror Hilman

4

Solo porque todos fueron muy útiles, solo agregaré una modificación a Zelazny7 y las hermosas soluciones de Daniele. Este es para Python 2.7, con pestañas para hacerlo más legible:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)

3

Los códigos a continuación son mi enfoque bajo anaconda python 2.7 más un nombre de paquete "pydot-ng" para hacer un archivo PDF con reglas de decisión. Espero que sea útil.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

un gráfico de árbol que se muestra aquí


3

He estado pasando por esto, pero necesitaba que las reglas se escribieran en este formato

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Así que adapté la respuesta de @paulkernfeld (gracias) que puedes personalizar según tus necesidades

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)

3

Aquí hay una manera de traducir todo el árbol en una sola expresión de Python (no necesariamente legible para humanos) utilizando la biblioteca SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')

3

Esto se basa en la respuesta de @paulkernfeld. Si tiene un marco de datos X con sus características y un marco de datos de destino y con sus resones y desea hacerse una idea de qué valor de y terminó en qué nodo (y también hormiga para trazarlo en consecuencia), puede hacer lo siguiente:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

no es la versión más elegante pero hace el trabajo ...


1
Este es un buen enfoque cuando desea devolver las líneas de código en lugar de simplemente imprimirlas.
Hajar Homayouni

3

Este es el código que necesitas

He modificado el código que más me gustó para sangrar en un jupyter notebook python 3 correctamente

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)

2

Aquí hay una función, imprimir reglas de un árbol de decisión de scikit-learn en python 3 y con compensaciones para bloques condicionales para hacer que la estructura sea más legible:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)

2

También puede hacerlo más informativo distinguiéndolo a qué clase pertenece o incluso mencionando su valor de salida.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

ingrese la descripción de la imagen aquí


2

Aquí está mi enfoque para extraer las reglas de decisión en una forma que se pueda usar directamente en sql, para que los datos se puedan agrupar por nodo. (Basado en los enfoques de los carteles anteriores).

El resultado serán CASEcláusulas posteriores que se pueden copiar en una instrucción sql, ej.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)

1

Ahora puede usar export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Un ejemplo completo de [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

0

Se modificó el código de Zelazny7 para obtener SQL del árbol de decisión.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'

0

Aparentemente, hace mucho tiempo, alguien ya decidió intentar agregar la siguiente función a las funciones de exportación del árbol de scikit oficial (que básicamente solo admite export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Aquí está su compromiso completo:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

No estoy seguro de lo que sucedió con este comentario. Pero también podría intentar usar esa función.

Creo que esto garantiza una solicitud de documentación seria a las buenas personas de scikit-learn para documentar adecuadamente la sklearn.tree.TreeAPI, que es la estructura de árbol subyacente que se DecisionTreeClassifierexpone como su atributo tree_.


0

Simplemente use la función de sklearn.tree como esta

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

Y luego busque en la carpeta de su proyecto el archivo tree.dot , copie TODO el contenido y péguelo aquí http://www.webgraphviz.com/ y genere su gráfico :)


0

Gracias por la maravillosa solución de @paulkerfeld. En la parte superior de su solución, para todos aquellos que quieren tener una versión serializada de árboles, sólo tiene que utilizar tree.threshold, tree.children_left, tree.children_right, tree.featurey tree.value. Dado que las hojas no tienen divisiones y, por lo tanto, no tienen nombres de características y elementos secundarios, su marcador de posición en tree.featurey tree.children_***son _tree.TREE_UNDEFINEDy _tree.TREE_LEAF. A cada división se le asigna un índice único por depth first search.
Tenga en cuenta que tree.valuees de forma[n, 1, 1]


0

Aquí hay una función que genera código Python a partir de un árbol de decisión al convertir la salida de export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Uso de la muestra:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Salida de muestra:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

El ejemplo anterior se genera con names = ['f'+str(j+1) for j in range(NUM_FEATURES)] .

Una característica útil es que puede generar un tamaño de archivo más pequeño con un espacio reducido. Solo establece spacing=2.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.