Predicción probabilística forestal aleatoria vs voto mayoritario


10

Scikit learn parece utilizar la predicción probabilística en lugar del voto mayoritario para la técnica de agregación del modelo sin una explicación de por qué (1.9.2.1. Bosques aleatorios).

¿Hay una explicación clara de por qué? Además, ¿hay un buen artículo o artículo de revisión para las diversas técnicas de agregación de modelos que se pueden usar para el ensacado de Random Forest?

¡Gracias!

Respuestas:


10

Estas preguntas siempre se responden mejor mirando el código, si habla Python con fluidez.

RandomForestClassifier.predict, Al menos en la versión actual 0.16.1, predice la clase con la más alta estimación de la probabilidad, según lo dado por predict_proba. ( esta línea )

La documentación para predict_probadice:

Las probabilidades de clase predichas de una muestra de entrada se calculan como las probabilidades de clase predichas medias de los árboles en el bosque. La probabilidad de clase de un solo árbol es la fracción de muestras de la misma clase en una hoja.

La diferencia con el método original probablemente sea solo para que las predictpredicciones sean consistentes predict_proba. El resultado a veces se llama "votación blanda", en lugar del voto mayoritario "duro" utilizado en el documento original de Breiman. En la búsqueda rápida no pude encontrar una comparación adecuada del rendimiento de los dos métodos, pero ambos parecen bastante razonables en esta situación.

La predictdocumentación es, en el mejor de los casos, bastante engañosa; He enviado una solicitud de extracción para solucionarlo.

Si quieres hacer predicciones de voto mayoritario, aquí tienes una función para hacerlo. Llámalo como predict_majvote(clf, X)más que como clf.predict(X). (Basado en predict_proba; solo ligeramente probado, pero creo que debería funcionar).

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.
    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce
    modes, counts = mode(all_preds, axis=0)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_.dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[:, k], axis=0)
        return preds

En el caso sintético tonto que probé, las predicciones coincidían siempre con el predictmétodo.


Gran respuesta, Dougal! Gracias por tomarse el tiempo para explicar esto cuidadosamente. Considere también ir a apilar el desbordamiento y responder a esta pregunta allí .
user1745038

1
También hay un documento, aquí , que aborda la predicción probabilística.
user1745038
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.