¿Cómo funciona Python numpy.where ()?


94

Estoy jugando numpyy buscando en la documentación y me he encontrado con algo de magia. A saber, estoy hablando de numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

¿Cómo logran internamente que puedas pasar algo parecido x > 5a un método? Supongo que tiene algo que ver, __gt__pero estoy buscando una explicación detallada.

Respuestas:


75

¿Cómo logran internamente que puedas pasar algo como x> 5 en un método?

La respuesta corta es que no lo hacen.

Cualquier tipo de operación lógica en una matriz numérica devuelve una matriz booleana. (es decir __gt__, __lt__etc., todos devuelven matrices booleanas donde la condición dada es verdadera).

P.ej

x = np.arange(9).reshape(3,3)
print x > 5

rinde:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

Esta es la misma razón por la que algo como if x > 5:genera un ValueError si xes una matriz numpy. Es una matriz de valores verdaderos / falsos, no un solo valor.

Además, las matrices numpy se pueden indexar mediante matrices booleanas. Por ejemplo , x[x>5]rendimientos [6 7 8], en este caso.

Honestamente, es bastante raro que realmente lo necesite, numpy.wherepero solo devuelve los índices donde está una matriz booleana True. Por lo general, puede hacer lo que necesite con una indexación booleana simple.


10
Solo para señalar que numpy.wheretienen 2 'modos operativos', primero devuelve el indices, dónde condition is Truey si los parámetros opcionales xy yestán presentes (¡la misma forma que condition, o ampliable a dicha forma!), Devolverá valores de xcuando de condition is Trueotra manera y. Por lo tanto, esto lo hace wheremás versátil y permite que se use con más frecuencia. Gracias
comer

1
También puede haber sobrecarga en algunos casos utilizando la __getitem__sintaxis de []over numpy.whereo de numpy.take. Dado __getitem__que también debe admitir el corte, hay algunos gastos generales. He visto diferencias de velocidad notables al trabajar con las estructuras de datos de Python Pandas e indexar lógicamente columnas muy grandes. En esos casos, si no necesita cortar, entonces takey en whererealidad son mejores.
Ely

24

Respuesta anterior es un poco confuso. Le da las UBICACIONES (todas) de donde su declaración es verdadera.

entonces:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Lo uso como alternativa a list.index (), pero también tiene muchos otros usos. Nunca lo he usado con matrices 2D.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Nueva respuesta Parece que la persona preguntaba algo más fundamental.

La pregunta era cómo USTED podría implementar algo que le permita a una función (como dónde) saber qué se solicitó.

Primero tenga en cuenta que llamar a cualquiera de los operadores de comparación es algo interesante.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Esto se hace sobrecargando el método "__gt__". Por ejemplo:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

Como puede ver, "a> 4" era un código válido.

Puede obtener una lista completa y documentación de todas las funciones sobrecargadas aquí: http://docs.python.org/reference/datamodel.html

Algo que es increíble es lo sencillo que es hacer esto. TODAS las operaciones en Python se realizan de esa manera. Decir a> b es equivalente a a. gt (b)!


3
Sin embargo, esta sobrecarga del operador de comparación no parece funcionar bien con expresiones lógicas más complejas; por ejemplo, no puedo hacerlo np.where(a > 30 and a < 50)o np.where(30 < a < 50)porque termina tratando de evaluar el Y lógico de dos matrices de valores booleanos, lo cual no tiene sentido. ¿Hay alguna forma de escribir tal condición np.where?
davidA

@meowsqueaknp.where((a > 30) & (a < 50))
tibalt

¿Por qué np.where () devuelve una lista en su ejemplo?
Andreas Yankopolus

0

np.wheredevuelve una tupla de longitud igual a la dimensión del ndarray numérico en el que se llama (en otras palabras ndim) y cada elemento de la tupla es un ndarray numérico de índices de todos esos valores en el ndarray inicial para el que la condición es Verdadera. (No confunda la dimensión con la forma)

Por ejemplo:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y es una tupla de longitud 2 porque x.ndimes 2. El primer elemento de la tupla contiene números de fila de todos los elementos mayores que 4 y el segundo elemento contiene números de columna de todos los elementos mayores que 4. Como puede ver, [1,2,2 , 2] corresponde a los números de fila de 5,6,7,8 y [2,0,1,2] corresponde a los números de columna de 5,6,7,8 Tenga en cuenta que el ndarray se atraviesa a lo largo de la primera dimensión (por filas ).

Similar,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


devolverá una tupla de longitud 3 porque x tiene 3 dimensiones.

Pero espera, ¡hay más en np.where!

cuando se agregan dos argumentos adicionales np.where; hará una operación de reemplazo para todas esas combinaciones de filas y columnas por pares que se obtienen mediante la tupla anterior.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.