Pregunta Distinguir el sobreajuste frente a la buena predicción


Estas son preguntas sobre cómo calcular y reducir el sobreajuste en el aprendizaje automático. Creo que muchas personas nuevas en el aprendizaje automático tendrán las mismas preguntas, así que traté de ser claro con mis ejemplos y preguntas con la esperanza de que las respuestas aquí puedan ayudar a otros.

Tengo una muestra muy pequeña de textos y estoy tratando de predecir los valores asociados con ellos. He utilizado sklearn para calcular tf-idf, e insertarlos en un modelo de regresión para la predicción. Esto me da 26 muestras con 6323 funciones, no mucho ... Lo sé:

>> count_vectorizer = CountVectorizer(min_n=1, max_n=1)
>> term_freq = count_vectorizer.fit_transform(texts)
>> transformer = TfidfTransformer()
>> X = transformer.fit_transform(term_freq) 
>> print X.shape

(26, 6323)

Insertar esas 26 muestras de características 6323 (X) y puntajes asociados (y), en un LinearRegression modelo, da buenas predicciones. Estos se obtienen utilizando la validación cruzada de dejar uno fuera, desde cross_validation.LeaveOneOut(X.shape[0], indices=True) :

using ngrams (n=1):
     human  machine  points-off  %error
      8.67    8.27    0.40       1.98
      8.00    7.33    0.67       3.34
      ...     ...     ...        ...
      5.00    6.61    1.61       8.06
      9.00    7.50    1.50       7.50
mean: 7.59    7.64    1.29       6.47
std : 1.94    0.56    1.38       6.91

¡Bastante bueno! Usando ngrams (n = 300) en lugar de unigrams (n = 1), se producen resultados similares, lo que obviamente no es correcto. No se producen 300 palabras en ninguno de los textos, por lo que la predicción debería fallar, pero no es así:

using ngrams (n=300):
      human  machine  points-off  %error
       8.67    7.55    1.12       5.60
       8.00    7.57    0.43       2.13
       ...     ...     ...        ...
mean:  7.59    7.59    1.52       7.59
std :  1.94    0.08    1.32       6.61

Pregunta 1:


32
2017-09-03 19:32


origen


Respuestas:


¿Cómo se suele decir que el modelo es demasiado ajustado?

Una regla práctica útil es que puede estar sobreajustado cuando el rendimiento de su modelo en su propio conjunto de entrenamiento es mucho mejor que en su conjunto de validación retenido o en una configuración de validación cruzada. Sin embargo, eso no es todo lo que hay que hacer.

La entrada del blog a la que he vinculado describe un procedimiento para probar el sobreajuste: trazar el conjunto de entrenamiento y validar el conjunto de errores como una función del tamaño del conjunto de entrenamiento. Si muestran una brecha estable en el extremo derecho de la trama, probablemente estés sobreajustado.

¿Cuál es la mejor manera de prevenir el ajuste excesivo (en esta situación) para asegurarse de que los resultados de predicción son buenos o no?

Usar una conjunto de prueba retenida. Solo haga la evaluación en este conjunto cuando haya terminado completamente con la selección del modelo (ajuste de hiperparámetros); no entrene en él, no lo use en la validación (cruzada). El puntaje que obtiene en el conjunto de prueba es la evaluación final del modelo. Esto debería mostrar si accidentalmente sobrepasaste los conjuntos de validación.

[Las conferencias de aprendizaje automático a veces se configuran como una competencia, donde el conjunto de pruebas no se entrega a los investigadores hasta después han entregado su modelo final a los organizadores. Mientras tanto, pueden usar el conjunto de entrenamiento a su gusto, p. probando modelos usando validación cruzada. Kaggle hace algo similar.]

Si LeaveOneOut se usa la validación cruzada, ¿cómo es posible que el modelo se sobrepase con buenos resultados?

Debido a que puede ajustar el modelo tanto como desee en esta configuración de validación cruzada, hasta que se comporte casi perfectamente en CV.

Como un ejemplo extremo, suponga que ha implementado un estimador que es esencialmente un generador de números aleatorios. Puede seguir probando semillas aleatorias hasta que llegue a un "modelo" que produce un error muy bajo en la validación cruzada, pero eso no le da al modelo correcto. Significa que no has superado la validación cruzada.

Ver también este interesante warstory.


34
2017-09-03 22:06