Diagramas de dispersión en Pandas / Pyplot: cómo trazar por categoría


89

Estoy tratando de hacer un diagrama de dispersión simple en pyplot usando un objeto Pandas DataFrame, pero quiero una forma eficiente de trazar dos variables pero que los símbolos estén dictados por una tercera columna (clave). He intentado varias formas usando df.groupby, pero no con éxito. A continuación se muestra un ejemplo de secuencia de comandos df. Esto colorea los marcadores de acuerdo con 'key1', pero me gustaría ver una leyenda con las categorías de 'key1'. Estoy cerca? Gracias.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()

Respuestas:


118

Puede usar scatterpara esto, pero eso requiere tener valores numéricos para usted key1, y no tendrá una leyenda, como notó.

Es mejor usarlo plotpara categorías discretas como esta. Por ejemplo:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

ingrese la descripción de la imagen aquí

Si desea que las cosas se vean como el pandasestilo predeterminado , simplemente actualice rcParamscon la hoja de estilo pandas y use su generador de color. (También estoy modificando ligeramente la leyenda):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

ingrese la descripción de la imagen aquí


¿Por qué en el ejemplo RGB anterior el símbolo se muestra dos veces en la leyenda? ¿Cómo mostrar solo una vez?
Steve Schulist

1
@SteveSchulist: se usa ax.legend(numpoints=1)para mostrar solo un marcador. Hay dos, como con a Line2D, a menudo hay una línea que conecta los dos marcadores.
Joe Kington

Este código solo funcionó para mí después de agregar plt.hold(True)después del ax.plot()comando. ¿Alguna idea de por qué?
Yuval Atzmon

set_color_cycle() quedó obsoleto en matplotlib 1.5. La hay set_prop_cycle(), ahora.
ale

52

Esto es simple de hacer con Seaborn ( pip install seaborn) como un delineador

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1") :

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

ingrese la descripción de la imagen aquí

Aquí está el marco de datos para referencia:

ingrese la descripción de la imagen aquí

Dado que tiene tres columnas variables en sus datos, es posible que desee trazar todas las dimensiones por pares con:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

ingrese la descripción de la imagen aquí

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ es otra opción.


19

Con plt.scatter, solo puedo pensar en uno: usar un artista proxy:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

Y el resultado es:

ingrese la descripción de la imagen aquí


10

Puede usar df.plot.scatter y pasar una matriz al argumento c = que define el color de cada punto:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

ingrese la descripción de la imagen aquí


4

También puede probar Altair o ggpot, que se centran en visualizaciones declarativas.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Código de Altair

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

ingrese la descripción de la imagen aquí

código ggplot

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

ingrese la descripción de la imagen aquí


3

Desde matplotlib 3.1 en adelante puede usar .legend_elements(). Se muestra un ejemplo en Creación automática de leyendas . La ventaja es que se puede utilizar una única llamada dispersa.

En este caso:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

ingrese la descripción de la imagen aquí

En caso de que las claves no se dieran directamente como números, se vería como

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

ingrese la descripción de la imagen aquí


Recibí un error que decía que el objeto 'PathCollection' no tiene el atributo 'legends_elements'. Mi código es el siguiente. fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel

1
@NandishPatel Verifique la primera oración de esta respuesta. También asegúrese de no confundir legends_elementsy legend_elements.
ImportanceOfBeingErnest

Si, gracias. Eso fue un error tipográfico (leyendas / leyenda). Estaba trabajando en algo desde las últimas 6 horas, así que no se me ocurrió la versión de Matplotlib. Pensé que estaba usando el último. Estaba confundido de que la documentación dice que existe tal método, pero el código estaba dando un error. Gracias de nuevo. Ahora puedo dormir.
Nandish Patel


1

seaborn tiene una función de envoltura scatterplotque lo hace de manera más eficiente.

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.