Utilizo Tensorflow, pero estoy escribiendo documentación para usuarios que generalmente varía según los marcos de aprendizaje profundo .
Cuando trabajo con conjuntos de datos que no se ajustan al sistema de archivos local (TB +), tomo muestras de datos de un almacén de datos remoto y escribo muestras localmente en un tfrecords
formato estándar de Tensorflow .
Durante la primera época de entrenamiento solo habré muestreado algunos valores, por lo tanto, una época de datos locales es muy pequeña, entreno con ella. En la época 2 vuelvo a examinar qué archivos de datos han sido producidos por mis subprocesos de muestreo (ahora más) y me entreno en el conjunto ampliado de archivos de datos locales para la próxima época. Repita el proceso cada época. De esta manera, construyo un caché local de muestras y puedo expulsar muestras antiguas a medida que lleno el almacenamiento local. El caché de muestras locales crece aproximadamente en el momento en que el modelo necesita más la varianza (hacia la última parte del entrenamiento).
En Python / Tensorflow es crucial que no deserialice los datos en el proceso de bucle de entrenamiento de Python porque Python GIL no puede admitir las velocidades de transferencia de datos (300-600 MB / seg, los datos son científicos sin comprimir) y, por lo tanto, el rendimiento de la GPU sufre cuando Python GIL no puede atender el ciclo de entrenamiento rápido.
Escribir las muestras en un tfrecords
archivo desde subprocesos (multiprocesamiento de python) permite que los nativos de tensorflow TFRecordsDataset
realicen deserialización fuera de Python y, por lo tanto, evitamos los problemas de Python GIL, y puedo saturar una GPU con altas tasas de datos de E / S.
Me gustaría saber cómo abordaría este problema en Pytorch. Estoy escribiendo sobre la estrategia de muestreo que se está utilizando y quiero proporcionar recomendaciones específicas a los usuarios de Tensorflow y PyTorch, pero no conozco el ecosistema de preprocesamiento de PyTorch lo suficientemente bien como para escribir con suficiente detalle.
Nota al margen: la única solución puramente basada en Python para admitir estas velocidades de transferencia de datos puede venir en Python 3.8 con memoria compartida y multiprocesamiento del Sistema V, pero aún no lo he intentado ya que el soporte no es suficiente (pronto será ) Las soluciones de multiprocesamiento existentes no son suficientes porque requieren deserialización en el proceso del ciclo de capacitación y, por lo tanto, bloquean el GIL durante la deserialización a altas tasas de E / S.
DataLoader
como en mi respuesta.