Tren estratificado / Test-split en scikit-learn


88

Necesito dividir mis datos en un conjunto de entrenamiento (75%) y un conjunto de prueba (25%). Actualmente hago eso con el siguiente código:

X, Xt, userInfo, userInfo_train = sklearn.cross_validation.train_test_split(X, userInfo)   

Sin embargo, me gustaría estratificar mi conjunto de datos de entrenamiento. ¿Cómo puedo hacer eso? He estado investigando el StratifiedKFoldmétodo, pero no me permite especificar la división 75% / 25% y solo estratificar el conjunto de datos de entrenamiento.

Respuestas:


153

[actualización para 0.17]

Ver los documentos de sklearn.model_selection.train_test_split:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    stratify=y, 
                                                    test_size=0.25)

[/ actualización para 0.17]

Hay una petición de atracción aquí . Pero puedes simplemente hacertrain, test = next(iter(StratifiedKFold(...))) y usar el tren y probar índices si lo desea.


1
@AndreasMueller ¿Existe una manera fácil de estratificar los datos de regresión?
Jordania

3
@Jordan no se implementa nada en scikit-learn. No conozco una forma estándar. Podríamos usar percentiles.
Andreas Mueller

@AndreasMueller ¿Alguna vez ha visto el comportamiento en el que este método es considerablemente más lento que el StratifiedShuffleSplit? Estaba usando el conjunto de datos MNIST.
snymkpr

@activatedgeek eso parece muy extraño, ya que train_test_split (... stratify =) solo está llamando a StratifiedShuffleSplit y tomando la primera división. No dude en abrir un problema en el rastreador con un ejemplo reproducible.
Andreas Mueller

@AndreasMueller En realidad no abrí un problema porque tengo la fuerte sensación de que estoy haciendo algo mal (aunque son solo 2 líneas). Pero si todavía puedo reproducirlo hoy varias veces, ¡lo haré!
snymkpr

29

TL; DR: Utilice StratifiedShuffleSplit contest_size=0.25

Scikit-learn proporciona dos módulos para la división estratificada:

  1. StratifiedKFold : este módulo es útil como operador directo de validación cruzada de k-fold: ya que configurará conjuntos de n_foldsentrenamiento / prueba de manera que las clases estén igualmente equilibradas en ambos.

Aquí hay un código (directamente de la documentación anterior)

>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) #2-fold cross validation
>>> len(skf)
2
>>> for train_index, test_index in skf:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
...    #fit and predict with X_train/test. Use accuracy metrics to check validation performance
  1. StratifiedShuffleSplit : este módulo crea un único conjunto de entrenamiento / prueba con clases igualmente equilibradas (estratificadas). Básicamente, esto es lo que quieres con n_iter=1. Puede mencionar el tamaño de la prueba aquí igual que entrain_test_split

Código:

>>> sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0)
>>> len(sss)
1
>>> for train_index, test_index in sss:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
>>> # fit and predict with your classifier using the above X/y train/test

5
Tenga en cuenta que a partir de 0.18.x, n_iterdebería ser n_splitspara StratifiedShuffleSplit , y que hay una API ligeramente diferente para ello: scikit-learn.org/stable/modules/generated/…
lollercoaster

2
Si yes una serie Pandas, usey.iloc[train_index], y.iloc[test_index]
Owlright

1
@Owlright Intenté usar un marco de datos de pandas y los índices que devuelve StratifiedShuffleSplit no son los índices del marco de datos. dataframe index: 2,3,5 the first split in sss:[(array([2, 1]), array([0]))]:(
Meghna Natraj

2
@tangy, ¿por qué es esto un bucle for? ¿no es el caso que cuando X_train, X_test = X[train_index], X[test_index]se invoca una línea anula X_trainy X_test? ¿Por qué entonces no solo uno next(sss)?
Bartek Wójcik

13

Aquí hay un ejemplo de datos continuos / de regresión (hasta que se resuelva este problema en GitHub ).

min = np.amin(y)
max = np.amax(y)

# 5 bins may be too few for larger datasets.
bins     = np.linspace(start=min, stop=max, num=5)
y_binned = np.digitize(y, bins, right=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, 
    y, 
    stratify=y_binned
)
  • Donde startes min ystop máximo de su objetivo continuo.
  • Si no lo configura right=True, más o menos hará que su valor máximo sea un contenedor separado y su división siempre fallará porque habrá muy pocas muestras en ese contenedor adicional.


6

Además de la respuesta aceptada por @Andreas Mueller, solo quiero agregar eso como @tangy mencionado anteriormente:

StratifiedShuffleSplit se parece más a train_test_split ( stratify = y) con características adicionales de:

  1. estratificar por defecto
  2. al especificar n_splits , divide repetidamente los datos

0
#train_size is 1 - tst_size - vld_size
tst_size=0.15
vld_size=0.15

X_train_test, X_valid, y_train_test, y_valid = train_test_split(df.drop(y, axis=1), df.y, test_size = vld_size, random_state=13903) 

X_train_test_V=pd.DataFrame(X_train_test)
X_valid=pd.DataFrame(X_valid)

X_train, X_test, y_train, y_test = train_test_split(X_train_test, y_train_test, test_size=tst_size, random_state=13903)

0

Actualizando la respuesta de @tangy desde arriba a la versión actual de scikit-learn: 0.23.2 ( documentación de StratifiedShuffleSplit ).

from sklearn.model_selection import StratifiedShuffleSplit

n_splits = 1  # We only want a single split in this case
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=0)

for train_index, test_index in sss.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.