Etiquetas en línea en Matplotlib


100

En Matplotlib, no es demasiado difícil hacer una leyenda ( example_legend(), a continuación), pero creo que es mejor estilo poner etiquetas directamente en las curvas que se trazan (como en example_inline(), a continuación). Esto puede ser muy complicado, porque tengo que especificar las coordenadas a mano y, si vuelvo a formatear el gráfico, probablemente tenga que reposicionar las etiquetas. ¿Hay alguna forma de generar etiquetas automáticamente en curvas en Matplotlib? Puntos extra por poder orientar el texto en un ángulo correspondiente al ángulo de la curva.

import numpy as np
import matplotlib.pyplot as plt

def example_legend():
    plt.clf()
    x = np.linspace(0, 1, 101)
    y1 = np.sin(x * np.pi / 2)
    y2 = np.cos(x * np.pi / 2)
    plt.plot(x, y1, label='sin')
    plt.plot(x, y2, label='cos')
    plt.legend()

Figura con leyenda

def example_inline():
    plt.clf()
    x = np.linspace(0, 1, 101)
    y1 = np.sin(x * np.pi / 2)
    y2 = np.cos(x * np.pi / 2)
    plt.plot(x, y1, label='sin')
    plt.plot(x, y2, label='cos')
    plt.text(0.08, 0.2, 'sin')
    plt.text(0.9, 0.2, 'cos')

Figura con etiquetas en línea

Respuestas:


28

Buena pregunta, hace un tiempo experimenté un poco con esto, pero no lo he usado mucho porque todavía no es a prueba de balas. Dividí el área de la parcela en una cuadrícula de 32x32 y calculé un 'campo potencial' para la mejor posición de una etiqueta para cada línea de acuerdo con las siguientes reglas:

  • el espacio en blanco es un buen lugar para una etiqueta
  • La etiqueta debe estar cerca de la línea correspondiente
  • La etiqueta debe estar alejada de las otras líneas

El código era algo como esto:

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage


def my_legend(axis = None):

    if axis == None:
        axis = plt.gca()

    N = 32
    Nlines = len(axis.lines)
    print Nlines

    xmin, xmax = axis.get_xlim()
    ymin, ymax = axis.get_ylim()

    # the 'point of presence' matrix
    pop = np.zeros((Nlines, N, N), dtype=np.float)    

    for l in range(Nlines):
        # get xy data and scale it to the NxN squares
        xy = axis.lines[l].get_xydata()
        xy = (xy - [xmin,ymin]) / ([xmax-xmin, ymax-ymin]) * N
        xy = xy.astype(np.int32)
        # mask stuff outside plot        
        mask = (xy[:,0] >= 0) & (xy[:,0] < N) & (xy[:,1] >= 0) & (xy[:,1] < N)
        xy = xy[mask]
        # add to pop
        for p in xy:
            pop[l][tuple(p)] = 1.0

    # find whitespace, nice place for labels
    ws = 1.0 - (np.sum(pop, axis=0) > 0) * 1.0 
    # don't use the borders
    ws[:,0]   = 0
    ws[:,N-1] = 0
    ws[0,:]   = 0  
    ws[N-1,:] = 0  

    # blur the pop's
    for l in range(Nlines):
        pop[l] = ndimage.gaussian_filter(pop[l], sigma=N/5)

    for l in range(Nlines):
        # positive weights for current line, negative weight for others....
        w = -0.3 * np.ones(Nlines, dtype=np.float)
        w[l] = 0.5

        # calculate a field         
        p = ws + np.sum(w[:, np.newaxis, np.newaxis] * pop, axis=0)
        plt.figure()
        plt.imshow(p, interpolation='nearest')
        plt.title(axis.lines[l].get_label())

        pos = np.argmax(p)  # note, argmax flattens the array first 
        best_x, best_y =  (pos / N, pos % N) 
        x = xmin + (xmax-xmin) * best_x / N       
        y = ymin + (ymax-ymin) * best_y / N       


        axis.text(x, y, axis.lines[l].get_label(), 
                  horizontalalignment='center',
                  verticalalignment='center')


plt.close('all')

x = np.linspace(0, 1, 101)
y1 = np.sin(x * np.pi / 2)
y2 = np.cos(x * np.pi / 2)
y3 = x * x
plt.plot(x, y1, 'b', label='blue')
plt.plot(x, y2, 'r', label='red')
plt.plot(x, y3, 'g', label='green')
my_legend()
plt.show()

Y la trama resultante: ingrese la descripción de la imagen aquí


Muy agradable. Sin embargo, tengo un ejemplo que no funciona del todo: plt.plot(x2, 3*x2**2, label="3x*x"); plt.plot(x2, 2*x2**2, label="2x*x"); plt.plot(x2, 0.5*x2**2, label="0.5x*x"); plt.plot(x2, -1*x2**2, label="-x*x"); plt.plot(x2, -2.5*x2**2, label="-2.5*x*x"); my_legend();esto coloca una de las etiquetas en la esquina superior izquierda. ¿Alguna idea sobre cómo solucionar este problema? Parece que el problema puede ser que las líneas estén demasiado juntas.
egpbos

Lo siento, lo olvidé x2 = np.linspace(0,0.5,100).
egpbos

¿Hay alguna forma de usar esto sin scipy? En mi sistema actual es una molestia instalarlo.
AnnanFay

Esto no me funciona en Python 3.6.4, Matplotlib 2.1.2 y Scipy 1.0.0. Después de actualizar el printcomando, se ejecuta y crea 4 gráficos, 3 de los cuales parecen ser un galimatías pixelado (probablemente algo relacionado con el 32x32), y el cuarto con etiquetas en lugares extraños.
Y Davis

84

Actualización: el usuario cphyc ha creado amablemente un repositorio de Github para el código de esta respuesta (ver aquí ) y ha empaquetado el código en un paquete que puede instalarse usando pip install matplotlib-label-lines.


Bonita imagen:

etiquetado de parcelas semiautomático

En matplotlibque es bastante fácil de parcelas etiqueta de contorno (ya sea de forma automática o manualmente mediante la colocación de etiquetas con clics del ratón). ¡No parece (todavía) haber ninguna capacidad equivalente para etiquetar series de datos de esta manera! Puede haber alguna razón semántica para no incluir esta característica que me falta.

Independientemente, he escrito el siguiente módulo que admite cualquier etiquetado de trazado semiautomático. Requiere solo numpyun par de funciones de la mathbiblioteca estándar .

Descripción

El comportamiento predeterminado de la labelLinesfunción es espaciar las etiquetas de manera uniforme a lo largo del xeje (colocándolas automáticamente en el yvalor correcto, por supuesto). Si lo desea, puede simplemente pasar una matriz de las coordenadas x de cada una de las etiquetas. Incluso puede modificar la ubicación de una etiqueta (como se muestra en el gráfico inferior derecho) y espaciar el resto de manera uniforme si lo desea.

Además, la label_linesfunción no tiene en cuenta las líneas que no han tenido una etiqueta asignada en el plotcomando (o más exactamente si la etiqueta contiene '_line').

Los argumentos de palabra clave pasados labelLineso labelLinese pasan a la textllamada de función (algunos argumentos de palabra clave se establecen si el código de llamada elige no especificar).

Cuestiones

  • Los cuadros delimitadores de anotaciones a veces interfieren de manera no deseada con otras curvas. Como se muestra en las anotaciones 1y 10en el gráfico superior izquierdo. Ni siquiera estoy seguro de que esto pueda evitarse.
  • En su lugar, sería bueno especificar una yposición a veces.
  • Sigue siendo un proceso iterativo obtener anotaciones en la ubicación correcta.
  • Solo funciona cuando los xvalores de -axis son floats

Gotchas

  • De forma predeterminada, la labelLinesfunción asume que todas las series de datos abarcan el rango especificado por los límites del eje. Eche un vistazo a la curva azul en el gráfico superior izquierdo de la bonita imagen. Si solo hubiera datos disponibles para el xrango 0.5, 1entonces no podríamos colocar una etiqueta en la ubicación deseada (que es un poco menos que 0.2). Consulte esta pregunta para ver un ejemplo particularmente desagradable. En este momento, el código no identifica inteligentemente este escenario y reorganiza las etiquetas, sin embargo, hay una solución razonable. La función labelLines toma el xvalsargumento; una lista de xvalores especificados por el usuario en lugar de la distribución lineal predeterminada a lo ancho. Para que el usuario pueda decidir quéx-valores que se utilizarán para la ubicación de la etiqueta de cada serie de datos.

Además, creo que esta es la primera respuesta para completar el objetivo adicional de alinear las etiquetas con la curva en la que se encuentran. :)

label_lines.py:

from math import atan2,degrees
import numpy as np

#Label line with line2D label data
def labelLine(line,x,label=None,align=True,**kwargs):

    ax = line.axes
    xdata = line.get_xdata()
    ydata = line.get_ydata()

    if (x < xdata[0]) or (x > xdata[-1]):
        print('x label location is outside data range!')
        return

    #Find corresponding y co-ordinate and angle of the line
    ip = 1
    for i in range(len(xdata)):
        if x < xdata[i]:
            ip = i
            break

    y = ydata[ip-1] + (ydata[ip]-ydata[ip-1])*(x-xdata[ip-1])/(xdata[ip]-xdata[ip-1])

    if not label:
        label = line.get_label()

    if align:
        #Compute the slope
        dx = xdata[ip] - xdata[ip-1]
        dy = ydata[ip] - ydata[ip-1]
        ang = degrees(atan2(dy,dx))

        #Transform to screen co-ordinates
        pt = np.array([x,y]).reshape((1,2))
        trans_angle = ax.transData.transform_angles(np.array((ang,)),pt)[0]

    else:
        trans_angle = 0

    #Set a bunch of keyword arguments
    if 'color' not in kwargs:
        kwargs['color'] = line.get_color()

    if ('horizontalalignment' not in kwargs) and ('ha' not in kwargs):
        kwargs['ha'] = 'center'

    if ('verticalalignment' not in kwargs) and ('va' not in kwargs):
        kwargs['va'] = 'center'

    if 'backgroundcolor' not in kwargs:
        kwargs['backgroundcolor'] = ax.get_facecolor()

    if 'clip_on' not in kwargs:
        kwargs['clip_on'] = True

    if 'zorder' not in kwargs:
        kwargs['zorder'] = 2.5

    ax.text(x,y,label,rotation=trans_angle,**kwargs)

def labelLines(lines,align=True,xvals=None,**kwargs):

    ax = lines[0].axes
    labLines = []
    labels = []

    #Take only the lines which have labels other than the default ones
    for line in lines:
        label = line.get_label()
        if "_line" not in label:
            labLines.append(line)
            labels.append(label)

    if xvals is None:
        xmin,xmax = ax.get_xlim()
        xvals = np.linspace(xmin,xmax,len(labLines)+2)[1:-1]

    for line,x,label in zip(labLines,xvals,labels):
        labelLine(line,x,label,align,**kwargs)

Pruebe el código para generar la bonita imagen de arriba:

from matplotlib import pyplot as plt
from scipy.stats import loglaplace,chi2

from labellines import *

X = np.linspace(0,1,500)
A = [1,2,5,10,20]
funcs = [np.arctan,np.sin,loglaplace(4).pdf,chi2(5).pdf]

plt.subplot(221)
for a in A:
    plt.plot(X,np.arctan(a*X),label=str(a))

labelLines(plt.gca().get_lines(),zorder=2.5)

plt.subplot(222)
for a in A:
    plt.plot(X,np.sin(a*X),label=str(a))

labelLines(plt.gca().get_lines(),align=False,fontsize=14)

plt.subplot(223)
for a in A:
    plt.plot(X,loglaplace(4).pdf(a*X),label=str(a))

xvals = [0.8,0.55,0.22,0.104,0.045]
labelLines(plt.gca().get_lines(),align=False,xvals=xvals,color='k')

plt.subplot(224)
for a in A:
    plt.plot(X,chi2(5).pdf(a*X),label=str(a))

lines = plt.gca().get_lines()
l1=lines[-1]
labelLine(l1,0.6,label=r'$Re=${}'.format(l1.get_label()),ha='left',va='bottom',align = False)
labelLines(lines[:-1],align=False)

plt.show()

1
@blujay Me alegro de que hayas podido adaptarlo a tus necesidades. Agregaré esa restricción como un problema.
NauticalMile

1
@Liza Lea mi Gotcha que acabo de agregar para explicar por qué está sucediendo esto. Para su caso (supongo que es como el de esta pregunta ) a menos que desee crear manualmente una lista de xvals, es posible que desee modificar labelLinesun poco el código: cambie el código bajo el if xvals is None:alcance para crear una lista basada en otros criterios. Podría comenzar conxvals = [(np.min(l.get_xdata())+np.max(l.get_xdata()))/2 for l in lines]
NauticalMile

1
@Liza Sin embargo, tu gráfico me intriga. El problema es que sus datos no están distribuidos de manera uniforme en la gráfica y tiene muchas curvas que están casi una encima de la otra. Con mi solución, podría ser muy difícil diferenciar las etiquetas en muchos casos. Creo que la mejor solución es tener bloques de etiquetas apiladas en diferentes partes vacías de su parcela. Vea este gráfico para ver un ejemplo con dos bloques de etiquetas apiladas (un bloque con 1 etiqueta y otro bloque con 4). Implementar esto sería bastante trabajo de campo, podría hacerlo en algún momento en el futuro.
NauticalMile

1
Nota: desde Matplotlib 2.0, .get_axes()y .get_axis_bgcolor()han quedado obsoletos. Reemplace con .axesy .get_facecolor(), resp.
Jiāgěng

1
Otra cosa asombrosa labellineses que las propiedades se relacionan con él plt.texto se ax.textaplican a él. Lo que significa que puede configurar fontsizey bboxparámetros en la labelLines()función.
tionichm

52

La respuesta de @Jan Kuiken es ciertamente bien pensada y completa, pero hay algunas advertencias:

  • no funciona en todos los casos
  • requiere una buena cantidad de código adicional
  • puede variar considerablemente de una parcela a la siguiente

Un enfoque mucho más simple es anotar el último punto de cada gráfico. El punto también se puede encerrar en un círculo para enfatizarlo. Esto se puede lograr con una línea adicional:

from matplotlib import pyplot as plt

for i, (x, y) in enumerate(samples):
    plt.plot(x, y)
    plt.text(x[-1], y[-1], 'sample {i}'.format(i=i))

Una variante sería utilizar ax.annotate.


1
+1! Parece una solución sencilla y agradable. Perdón por la pereza, pero ¿cómo se vería esto? ¿Estaría el texto dentro de la gráfica o encima del eje y derecho?
rocarvaj

1
@rocarvaj Depende de otras configuraciones. Es posible que las etiquetas sobresalgan fuera del cuadro de trazado. Dos formas de evitar este comportamiento son: 1) usar un índice diferente de -1, 2) establecer límites de eje apropiados para dejar espacio para las etiquetas.
Ioannis Filippidis

1
También se convierte en un lío, si los gráficos se concentran en algún valor de y, los puntos finales se vuelven demasiado cercanos para que el texto se vea bien
LazyCat

@LazyCat: Eso es cierto. Para solucionar esto, uno puede hacer que las anotaciones se puedan arrastrar. Supongo que es un poco doloroso, pero funcionaría.
PlacidLush

1

Un enfoque más simple como el que hace Ioannis Filippidis:

import matplotlib.pyplot as plt
import numpy as np

# evenly sampled time at 200ms intervals
tMin=-1 ;tMax=10
t = np.arange(tMin, tMax, 0.1)

# red dashes, blue points default
plt.plot(t, 22*t, 'r--', t, t**2, 'b')

factor=3/4 ;offset=20  # text position in view  
textPosition=[(tMax+tMin)*factor,22*(tMax+tMin)*factor]
plt.text(textPosition[0],textPosition[1]+offset,'22  t',color='red',fontsize=20)
textPosition=[(tMax+tMin)*factor,((tMax+tMin)*factor)**2+20]
plt.text(textPosition[0],textPosition[1]+offset, 't^2', bbox=dict(facecolor='blue', alpha=0.5),fontsize=20)
plt.show()

código python 3 en sageCell

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.