ROC promedio para validación cruzada repetida 10 veces con estimaciones de probabilidad


15

Estoy planeando usar validación cruzada estratificada repetida (10 veces) en aproximadamente 10,000 casos usando el algoritmo de aprendizaje automático. Cada vez que la repetición se realizará con diferentes semillas al azar.

En este proceso, creo 10 instancias de estimaciones de probabilidad para cada caso. 1 instancia de estimación de probabilidad para cada una de las 10 repeticiones de la validación cruzada 10 veces

¿Puedo promediar 10 probabilidades para cada caso y luego crear una nueva curva ROC promedio (que represente los resultados de 10 CV repetidos), que se puede comparar con otras curvas ROC mediante comparaciones pareadas?

Respuestas:


13

Según su descripción, parece tener mucho sentido: no solo puede calcular la curva ROC media, sino también la varianza a su alrededor para generar intervalos de confianza. Debería darle la idea de cuán estable es su modelo.

Por ejemplo, así:

ingrese la descripción de la imagen aquí

Aquí pongo curvas ROC individuales, así como la curva media y los intervalos de confianza. Hay áreas donde las curvas están de acuerdo, por lo que tenemos menos varianza, y hay áreas en las que no están de acuerdo.

Para CV repetido, puede repetirlo varias veces y obtener el promedio total en todos los pliegues individuales:

ingrese la descripción de la imagen aquí

Es bastante similar a la imagen anterior, pero ofrece estimaciones más estables (es decir, confiables) de la media y la varianza.

Aquí está el código para obtener la trama:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

Para CV repetido:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Fuente de inspiración: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html


3

No es correcto promediar las probabilidades porque eso no representaría las predicciones que está tratando de validar e implica contaminación en las muestras de validación.

Tenga en cuenta que se pueden requerir 100 repeticiones de validación cruzada 10 veces para lograr una precisión adecuada. O use el arranque de optimismo de Efron-Gong que requiere menos iteraciones para la misma precisión (ver, por ejemplo rms, validatefunciones del paquete R ).

C


¿Podría por favor dar más detalles sobre por qué el promedio no es correcto?
DataD'oh

Ya indicado Debe validar la medida que utilizará en el campo.
Frank Harrell
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.