Genere un mapa de calor en MatPlotLib usando un conjunto de datos de dispersión


187

Tengo un conjunto de puntos de datos X, Y (aproximadamente 10k) que son fáciles de trazar como un diagrama de dispersión pero que me gustaría representar como un mapa de calor.

Miré a través de los ejemplos en MatPlotLib y parece que todos ya comienzan con valores de celdas de mapa de calor para generar la imagen.

¿Existe algún método que convierta un grupo de x, y, todos diferentes, en un mapa de calor (donde las zonas con mayor frecuencia de x, y serían "más cálidas")?


Respuestas:


182

Si no quieres hexágonos, puedes usar la histogram2dfunción de numpy :

import numpy as np
import numpy.random
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

heatmap, xedges, yedges = np.histogram2d(x, y, bins=50)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

plt.clf()
plt.imshow(heatmap.T, extent=extent, origin='lower')
plt.show()

Esto hace un mapa de calor de 50x50. Si quiere, digamos, 512x384, puede bins=(512, 384)llamar histogram2d.

Ejemplo: Ejemplo de mapa de calor Matplotlib


1
No me refiero a ser un idiota, pero ¿cómo puede realmente tener esta salida en un archivo PNG / PDF en lugar de mostrarla solo en una sesión interactiva de IPython? Estoy tratando de obtener esto como una especie de axesinstancia normal , donde puedo agregar un título, etiquetas de eje, etc. y luego hacer lo normal savefig()como lo haría para cualquier otro gráfico de matplotlib típico.
gotgenes

3
@gotgenes: ¿no plt.savefig('filename.png')funciona? Si desea obtener una instancia de ejes, use la interfaz orientada a objetos de Matplotlib:fig = plt.figure() ax = fig.gca() ax.imshow(...) fig.savefig(...)
ptomato

1
De hecho, gracias! Supongo que no entiendo completamente que imshow()está en la misma categoría de funciones que scatter(). Honestamente, no entiendo por qué imshow()convierte una matriz 2D de flotadores en bloques de color apropiado, mientras que sí entiendo lo que scatter()se supone que debe hacer con dicha matriz.
gotgenes

14
Una advertencia sobre el uso de imshow para trazar un histograma 2D de valores x / y como este: de forma predeterminada, imshow traza el origen en la esquina superior izquierda y transpone la imagen. Lo que haría para obtener la misma orientación que un diagrama de dispersión esplt.imshow(heatmap.T, extent=extent, origin = 'lower')
Jamie

77
Para aquellos que quieran hacer una barra de colores logarítmica, vea esta pregunta stackoverflow.com/questions/17201172/… y simplemente haga lo siguientefrom matplotlib.colors import LogNorm plt.imshow(heatmap, norm=LogNorm()) plt.colorbar()
tommy.carstensen

109

En el léxico Matplotlib , creo que quieres un diagrama de hexbin .

Si no está familiarizado con este tipo de diagrama, es solo un histograma bivariado en el que el plano xy está teselado por una cuadrícula regular de hexágonos.

Entonces, desde un histograma, puede contar el número de puntos que caen en cada hexágono, discretizar la región de trazado como un conjunto de ventanas , asignar cada punto a una de estas ventanas; finalmente, asigne las ventanas a una matriz de colores y obtendrá un diagrama hexbin.

Aunque se usa con menos frecuencia que, por ejemplo, círculos o cuadrados, los hexágonos son una mejor opción para la geometría del contenedor binning es intuitivo:

  • los hexágonos tienen simetría del vecino más cercano (por ejemplo, los contenedores cuadrados no, por ejemplo, la distancia desde un punto en el borde de un cuadrado hasta un punto dentro de ese cuadrado no es igual en todas partes)

  • El hexágono es el n-polígono más alto que proporciona la teselación del plano regular (es decir, puede modelar de manera segura el piso de su cocina con mosaicos de forma hexagonal porque no tendrá ningún espacio vacío entre los mosaicos cuando haya terminado, lo cual no es cierto para todos los demás superiores-n, n> = 7, polígonos).

( Matplotlib usa el término gráfico de hexbin ; también lo hacen (AFAIK) todas las bibliotecas de trazado para R ; todavía no sé si este es el término generalmente aceptado para gráficos de este tipo, aunque sospecho que es probable dado que hexbin es corto para el agrupamiento hexagonal , que describe el paso esencial para preparar los datos para su visualización).


from matplotlib import pyplot as PLT
from matplotlib import cm as CM
from matplotlib import mlab as ML
import numpy as NP

n = 1e5
x = y = NP.linspace(-5, 5, 100)
X, Y = NP.meshgrid(x, y)
Z1 = ML.bivariate_normal(X, Y, 2, 2, 0, 0)
Z2 = ML.bivariate_normal(X, Y, 4, 1, 1, 1)
ZD = Z2 - Z1
x = X.ravel()
y = Y.ravel()
z = ZD.ravel()
gridsize=30
PLT.subplot(111)

# if 'bins=None', then color of each hexagon corresponds directly to its count
# 'C' is optional--it maps values to x-y coordinates; if 'C' is None (default) then 
# the result is a pure 2D histogram 

PLT.hexbin(x, y, C=z, gridsize=gridsize, cmap=CM.jet, bins=None)
PLT.axis([x.min(), x.max(), y.min(), y.max()])

cb = PLT.colorbar()
cb.set_label('mean value')
PLT.show()   

ingrese la descripción de la imagen aquí


¿Qué significa que "los hexágonos tienen simetría de vecino más cercano"? Usted dice que "la distancia desde un punto en el borde de un cuadrado y un punto dentro de ese cuadrado no es igual en todas partes", ¿pero la distancia a qué?
Jaan

9
Para un hexágono, la distancia del centro a un vértice que une dos lados también es más larga que del centro a la mitad de un lado, solo la relación es menor (2 / sqrt (3) ≈ 1.15 para el hexágono vs. sqrt (2) ≈ 1.41 por cuadrado). La única forma donde la distancia desde el centro a cada punto en el borde es igual es el círculo.
Jaan

55
@Jaan Para un hexágono, cada vecino está a la misma distancia. No hay ningún problema con 8 barrios o 4 barrios. Sin vecinos diagonales, solo un tipo de vecino.
isarandi

@doug ¿Cómo se elige el gridsize=parámetro? Me gustaría elegirlo así, para que los hexágonos se toquen sin superponerse. Noté que gridsize=100produciría hexágonos más pequeños, pero ¿cómo elegir el valor adecuado?
Alexander Cska

40

Editar: Para una mejor aproximación de la respuesta de Alejandro, vea a continuación.

Sé que esta es una vieja pregunta, pero quería agregar algo al comentario de Alejandro: si desea una buena imagen suavizada sin usar py-sphviewer, puede usar np.histogram2dy aplicar un filtro gaussiano (desde scipy.ndimage.filters) al mapa de calor:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.ndimage.filters import gaussian_filter


def myplot(x, y, s, bins=1000):
    heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins)
    heatmap = gaussian_filter(heatmap, sigma=s)

    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
    return heatmap.T, extent


fig, axs = plt.subplots(2, 2)

# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)

sigmas = [0, 16, 32, 64]

for ax, s in zip(axs.flatten(), sigmas):
    if s == 0:
        ax.plot(x, y, 'k.', markersize=5)
        ax.set_title("Scatter plot")
    else:
        img, extent = myplot(x, y, s)
        ax.imshow(img, extent=extent, origin='lower', cmap=cm.jet)
        ax.set_title("Smoothing with  $\sigma$ = %d" % s)

plt.show()

Produce:

Imágenes de salida

El diagrama de dispersión ys = 16 trazados uno encima del otro para Agape Gal'lo (haga clic para ver mejor):

Encima del otro


Una diferencia que noté con mi enfoque de filtro gaussiano y el enfoque de Alejandro fue que su método muestra estructuras locales mucho mejores que las mías. Por lo tanto, implementé un método vecino más cercano simple a nivel de píxel. Este método calcula para cada píxel la suma inversa de las distancias deln puntos más cercanos en los datos. Este método tiene una alta resolución bastante costoso computacionalmente y creo que hay una forma más rápida, así que avíseme si tiene alguna mejora.

Actualización: como sospechaba, hay un método mucho más rápido con Scipy's scipy.cKDTree. Vea la respuesta de Gabriel para la implementación.

De todos modos, aquí está mi código:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm


def data_coord2view_coord(p, vlen, pmin, pmax):
    dp = pmax - pmin
    dv = (p - pmin) / dp * vlen
    return dv


def nearest_neighbours(xs, ys, reso, n_neighbours):
    im = np.zeros([reso, reso])
    extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]

    xv = data_coord2view_coord(xs, reso, extent[0], extent[1])
    yv = data_coord2view_coord(ys, reso, extent[2], extent[3])
    for x in range(reso):
        for y in range(reso):
            xp = (xv - x)
            yp = (yv - y)

            d = np.sqrt(xp**2 + yp**2)

            im[y][x] = 1 / np.sum(d[np.argpartition(d.ravel(), n_neighbours)[:n_neighbours]])

    return im, extent


n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)
resolution = 250

fig, axes = plt.subplots(2, 2)

for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 64]):
    if neighbours == 0:
        ax.plot(xs, ys, 'k.', markersize=2)
        ax.set_aspect('equal')
        ax.set_title("Scatter Plot")
    else:
        im, extent = nearest_neighbours(xs, ys, resolution, neighbours)
        ax.imshow(im, origin='lower', extent=extent, cmap=cm.jet)
        ax.set_title("Smoothing over %d neighbours" % neighbours)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
plt.show()

Resultado:

Alisado vecino más cercano


1
Me gusta esto. Graph es tan bueno como la respuesta de Alejandro, pero no se requieren paquetes nuevos.
Nathan Clement

Muy agradable ! Pero genera un desplazamiento con este método. Puede ver esto comparando un gráfico de dispersión normal con el de color. ¿Podría agregar algo para corregirlo? ¿O simplemente para mover el gráfico por los valores x e y?
Agape Gal'lo

1
Ágape Gal'lo, ¿qué quieres decir con desplazamiento? Si los traza uno encima del otro, coinciden (vea la edición de mi publicación). Quizás te desanimes porque el ancho de la dispersión no coincide exactamente con los otros tres.
Jurgy

¡Muchas gracias por trazar el gráfico solo para mí! Entendí mi error: había modificado la "extensión" para definir los límites x e y. Ahora entiendo que modificó el origen del gráfico. Entonces, tengo una última pregunta: ¿cómo puedo expandir los límites del gráfico, incluso para el área donde no hay datos existentes? Por ejemplo, entre -5 y +5 para x e y.
Ágape Gal'lo

1
Digamos que desea que el eje x vaya de -5 a 5 y el eje y de -3 a 4; en la myplotfunción, añadir el rangeparámetro a np.histogram2d: np.histogram2d(x, y, bins=bins, range=[[-5, 5], [-3, 4]])y en el bucle para establecer la x e y lim del eje: ax.set_xlim([-5, 5]) ax.set_ylim([-3, 4]). Además, de forma predeterminada, imshowmantiene la relación de aspecto idéntica a la relación de sus ejes (en mi ejemplo, una relación de 10: 7), pero si desea que coincida con su ventana de trazado, agregue el parámetro aspect='auto'a imshow.
Jurgy

31

En lugar de usar np.hist2d, que en general produce histogramas bastante feos, me gustaría reciclar py-sphviewer , un paquete de python para representar simulaciones de partículas utilizando un núcleo de suavizado adaptativo y que se puede instalar fácilmente desde pip (consulte la documentación de la página web). Considere el siguiente código, que se basa en el ejemplo:

import numpy as np
import numpy.random
import matplotlib.pyplot as plt
import sphviewer as sph

def myplot(x, y, nb=32, xsize=500, ysize=500):   
    xmin = np.min(x)
    xmax = np.max(x)
    ymin = np.min(y)
    ymax = np.max(y)

    x0 = (xmin+xmax)/2.
    y0 = (ymin+ymax)/2.

    pos = np.zeros([3, len(x)])
    pos[0,:] = x
    pos[1,:] = y
    w = np.ones(len(x))

    P = sph.Particles(pos, w, nb=nb)
    S = sph.Scene(P)
    S.update_camera(r='infinity', x=x0, y=y0, z=0, 
                    xsize=xsize, ysize=ysize)
    R = sph.Render(S)
    R.set_logscale()
    img = R.get_image()
    extent = R.get_extent()
    for i, j in zip(xrange(4), [x0,x0,y0,y0]):
        extent[i] += j
    print extent
    return img, extent

fig = plt.figure(1, figsize=(10,10))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)


# Generate some test data
x = np.random.randn(1000)
y = np.random.randn(1000)

#Plotting a regular scatter plot
ax1.plot(x,y,'k.', markersize=5)
ax1.set_xlim(-3,3)
ax1.set_ylim(-3,3)

heatmap_16, extent_16 = myplot(x,y, nb=16)
heatmap_32, extent_32 = myplot(x,y, nb=32)
heatmap_64, extent_64 = myplot(x,y, nb=64)

ax2.imshow(heatmap_16, extent=extent_16, origin='lower', aspect='auto')
ax2.set_title("Smoothing over 16 neighbors")

ax3.imshow(heatmap_32, extent=extent_32, origin='lower', aspect='auto')
ax3.set_title("Smoothing over 32 neighbors")

#Make the heatmap using a smoothing over 64 neighbors
ax4.imshow(heatmap_64, extent=extent_64, origin='lower', aspect='auto')
ax4.set_title("Smoothing over 64 neighbors")

plt.show()

que produce la siguiente imagen:

ingrese la descripción de la imagen aquí

Como puede ver, las imágenes se ven muy bien y podemos identificar diferentes subestructuras en ellas. Estas imágenes se construyen distribuyendo un peso dado para cada punto dentro de un determinado dominio, definido por la longitud de suavizado, que a su vez viene dada por la distancia al vecino nb más cercano (he elegido 16, 32 y 64 para los ejemplos). Por lo tanto, las regiones de mayor densidad generalmente se extienden sobre regiones más pequeñas en comparación con las regiones de menor densidad.

La función myplot es solo una función muy simple que he escrito para dar los datos x, y a py-sphviewer para hacer la magia.


2
Un comentario para cualquiera que intente instalar py-sphviewer en OSX: tuve muchas dificultades, consulte: github.com/alejandrobll/py-sphviewer/issues/3
Sam Finnigan

Lástima que no funciona con python3. Se instala, pero luego se bloquea cuando intentas usarlo ...
Fábio Dias

1
@Fabio Dias, la última versión (1.1.x) ahora funciona con Python 3.
Alejandro

29

Si está utilizando 1.2.x

import numpy as np
import matplotlib.pyplot as plt

x = np.random.randn(100000)
y = np.random.randn(100000)
plt.hist2d(x,y,bins=100)
plt.show()

gaussian_2d_heat_map


17

Seaborn ahora tiene la función plot conjunta que debería funcionar bien aquí:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Generate some test data
x = np.random.randn(8873)
y = np.random.randn(8873)

sns.jointplot(x=x, y=y, kind='hex')
plt.show()

imagen de demostración


Simple, bonito y analíticamente útil.
ryanjdillon

@wordsforthewise, ¿cómo se puede hacer que los datos de 600k sean legibles visualmente con esto? (cómo cambiar el tamaño)
nrmb

No estoy muy seguro de lo que quieres decir; tal vez sea mejor que hagas una pregunta por separado y la vincules aquí. ¿Quieres decir cambiar el tamaño de todo el higo? Primero haga la figura con fig = plt.figure(figsize=(12, 12)), luego obtenga el eje actual con ax=plt.gca(), luego agregue el argumento ax=axa la jointplotfunción.
wordsforthewise

@wordsforthewise podría responder esta pregunta: stackoverflow.com/questions/50997662/… gracias
ebrahimi

4

y la pregunta inicial fue ... ¿cómo convertir los valores de dispersión en valores de cuadrícula, ¿verdad? histogram2dsí cuenta la frecuencia por celda, sin embargo, si tiene otros datos por celda además de la frecuencia, necesitará un trabajo adicional para hacer.

x = data_x # between -10 and 4, log-gamma of an svc
y = data_y # between -4 and 11, log-C of an svc
z = data_z #between 0 and 0.78, f1-values from a difficult dataset

Entonces, tengo un conjunto de datos con resultados Z para las coordenadas X e Y. Sin embargo, estaba calculando algunos puntos fuera del área de interés (grandes brechas) y montones de puntos en una pequeña área de interés.

Sí, aquí se vuelve más difícil pero también más divertido. Algunas bibliotecas (lo siento):

from matplotlib import pyplot as plt
from matplotlib import cm
import numpy as np
from scipy.interpolate import griddata

Pyplot es mi motor gráfico hoy, cm es una gama de mapas de colores con algunas opciones interesantes. numpy para los cálculos y datos de cuadrícula para adjuntar valores a una cuadrícula fija.

El último es importante, especialmente porque la frecuencia de los puntos xy no se distribuye por igual en mis datos. Primero, comencemos con algunos límites que se ajustan a mis datos y un tamaño de cuadrícula arbitrario. Los datos originales tienen puntos de datos también fuera de esos límites x e y.

#determine grid boundaries
gridsize = 500
x_min = -8
x_max = 2.5
y_min = -2
y_max = 7

Así que hemos definido una cuadrícula con 500 píxeles entre los valores mínimo y máximo de x e y.

En mis datos, hay muchos más de los 500 valores disponibles en el área de alto interés; Considerando que en el área de bajo interés, ni siquiera hay 200 valores en la cuadrícula total; entre los límites gráficos de x_miny x_maxhay aún menos.

Entonces, para obtener una buena imagen, la tarea es obtener un promedio de los valores de alto interés y llenar los vacíos en otros lugares.

Defino mi grilla ahora. Para cada par xx-yy, quiero tener un color.

xx = np.linspace(x_min, x_max, gridsize) # array of x values
yy = np.linspace(y_min, y_max, gridsize) # array of y values
grid = np.array(np.meshgrid(xx, yy.T))
grid = grid.reshape(2, grid.shape[1]*grid.shape[2]).T

¿Por qué la forma extraña? scipy.griddata quiere una forma de (n, D).

Griddata calcula un valor por punto en la cuadrícula, por un método predefinido. Elijo "más cercano": los puntos de cuadrícula vacíos se rellenarán con valores del vecino más cercano. Esto parece que las áreas con menos información tienen celdas más grandes (incluso si no es el caso). Se podría elegir interpolar "lineal", luego las áreas con menos información se ven menos nítidas. Cuestión de gustos, de verdad.

points = np.array([x, y]).T # because griddata wants it that way
z_grid2 = griddata(points, z, grid, method='nearest')
# you get a 1D vector as result. Reshape to picture format!
z_grid2 = z_grid2.reshape(xx.shape[0], yy.shape[0])

Y saltamos, pasamos a matplotlib para mostrar la trama

fig = plt.figure(1, figsize=(10, 10))
ax1 = fig.add_subplot(111)
ax1.imshow(z_grid2, extent=[x_min, x_max,y_min, y_max,  ],
            origin='lower', cmap=cm.magma)
ax1.set_title("SVC: empty spots filled by nearest neighbours")
ax1.set_xlabel('log gamma')
ax1.set_ylabel('log C')
plt.show()

Alrededor de la parte puntiaguda de la forma de V, ves que hice muchos cálculos durante mi búsqueda del punto óptimo, mientras que las partes menos interesantes en casi todos los demás tienen una resolución más baja.

Mapa de calor de un SVC en alta resolución


¿Puedes mejorar tu respuesta para tener un código completo y ejecutable? Este es un método interesante que ha proporcionado. Estoy tratando de entenderlo mejor en este momento. Tampoco entiendo por qué hay una forma de V. Gracias.
ldmtwo

La forma de V proviene de mis datos. Es el valor f1 para un SVM entrenado: esto va un poco en la teoría de SVM. Si tiene una C alta, incluye todos sus puntos en el cálculo, lo que permite que funcione un rango gamma más amplio. Gamma es la rigidez de la curva que separa lo bueno y lo malo. Esos dos valores tienen que ser dados al SVM (X e Y en mi gráfico); entonces obtienes un resultado (Z en mi gráfico). En la mejor área, con suerte, llegarás a alturas significativas.
Anderas

segundo intento: la forma de V está en mis datos. Es el valor f1 para un SVM: si tiene una C alta, incluye todos sus puntos en el cálculo, permitiendo que funcione un rango gamma más amplio, pero haciendo que el cálculo sea lento. Gamma es la rigidez de la curva que separa lo bueno y lo malo. Esos dos valores tienen que ser dados al SVM (X e Y en mi gráfico); entonces obtienes un resultado (Z en mi gráfico). En el área optimizada obtienes valores altos, en otros lugares valores bajos. Lo que mostré aquí es útil si tiene valores Z para algunos (X, Y) y muchas lagunas en otros lugares. Si tiene puntos de datos (X, Y, Z), puede usar mi código.
Anderas

4

Aquí está el enfoque de vecino más cercano de Jurgy, pero implementado usando scipy.cKDTree . En mis pruebas es aproximadamente 100 veces más rápido.

ingrese la descripción de la imagen aquí

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.spatial import cKDTree


def data_coord2view_coord(p, resolution, pmin, pmax):
    dp = pmax - pmin
    dv = (p - pmin) / dp * resolution
    return dv


n = 1000
xs = np.random.randn(n)
ys = np.random.randn(n)

resolution = 250

extent = [np.min(xs), np.max(xs), np.min(ys), np.max(ys)]
xv = data_coord2view_coord(xs, resolution, extent[0], extent[1])
yv = data_coord2view_coord(ys, resolution, extent[2], extent[3])


def kNN2DDens(xv, yv, resolution, neighbours, dim=2):
    """
    """
    # Create the tree
    tree = cKDTree(np.array([xv, yv]).T)
    # Find the closest nnmax-1 neighbors (first entry is the point itself)
    grid = np.mgrid[0:resolution, 0:resolution].T.reshape(resolution**2, dim)
    dists = tree.query(grid, neighbours)
    # Inverse of the sum of distances to each grid point.
    inv_sum_dists = 1. / dists[0].sum(1)

    # Reshape
    im = inv_sum_dists.reshape(resolution, resolution)
    return im


fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for ax, neighbours in zip(axes.flatten(), [0, 16, 32, 63]):

    if neighbours == 0:
        ax.plot(xs, ys, 'k.', markersize=5)
        ax.set_aspect('equal')
        ax.set_title("Scatter Plot")
    else:

        im = kNN2DDens(xv, yv, resolution, neighbours)

        ax.imshow(im, origin='lower', extent=extent, cmap=cm.Blues)
        ax.set_title("Smoothing over %d neighbours" % neighbours)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])

plt.savefig('new.png', dpi=150, bbox_inches='tight')

1
Sabía que mi implementación era muy ineficiente, pero no sabía sobre cKDTree. ¡Bien hecho! Te haré referencia en mi respuesta.
Jurgy

2

Haga una matriz bidimensional que corresponda a las celdas en su imagen final, llamada say heatmap_cells y ejemplifique como todos los ceros.

Elija dos factores de escala que definan la diferencia entre cada elemento de la matriz en unidades reales, para cada dimensión, digamos x_scaleyy_scale . Elija estos de modo que todos sus puntos de datos se encuentren dentro de los límites de la matriz de mapas de calor.

Para cada punto de datos sin procesar con x_valuey y_value:

heatmap_cells[floor(x_value/x_scale),floor(y_value/y_scale)]+=1


1

ingrese la descripción de la imagen aquí

Aquí hay uno que hice en un conjunto de 1 millón de puntos con 3 categorías (color rojo, verde y azul). Aquí hay un enlace al repositorio si desea probar la función. Repo de Github

histplot(
    X,
    Y,
    labels,
    bins=2000,
    range=((-3,3),(-3,3)),
    normalize_each_label=True,
    colors = [
        [1,0,0],
        [0,1,0],
        [0,0,1]],
    gain=50)

0

Muy similar a la respuesta de @ Piti , pero usando 1 llamada en lugar de 2 para generar los puntos:

import numpy as np
import matplotlib.pyplot as plt

pts = 1000000
mean = [0.0, 0.0]
cov = [[1.0,0.0],[0.0,1.0]]

x,y = np.random.multivariate_normal(mean, cov, pts).T
plt.hist2d(x, y, bins=50, cmap=plt.cm.jet)
plt.show()

Salida:

2d_gaussian_heatmap


0

Me temo que llego un poco tarde a la fiesta, pero tuve una pregunta similar hace un tiempo. La respuesta aceptada (por @ptomato) me ayudó, pero también me gustaría publicar esto en caso de que sea útil para alguien.


''' I wanted to create a heatmap resembling a football pitch which would show the different actions performed '''

import numpy as np
import matplotlib.pyplot as plt
import random

#fixing random state for reproducibility
np.random.seed(1234324)

fig = plt.figure(12)
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

#Ratio of the pitch with respect to UEFA standards 
hmap= np.full((6, 10), 0)
#print(hmap)

xlist = np.random.uniform(low=0.0, high=100.0, size=(20))
ylist = np.random.uniform(low=0.0, high =100.0, size =(20))

#UEFA Pitch Standards are 105m x 68m
xlist = (xlist/100)*10.5
ylist = (ylist/100)*6.5

ax1.scatter(xlist,ylist)

#int of the co-ordinates to populate the array
xlist_int = xlist.astype (int)
ylist_int = ylist.astype (int)

#print(xlist_int, ylist_int)

for i, j in zip(xlist_int, ylist_int):
    #this populates the array according to the x,y co-ordinate values it encounters 
    hmap[j][i]= hmap[j][i] + 1   

#Reversing the rows is necessary 
hmap = hmap[::-1]

#print(hmap)
im = ax2.imshow(hmap)

Aquí está el resultado. ingrese la descripción de la imagen aquí

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.