assertAlmostEqual en la prueba unitaria de Python para colecciones de flotantes


81

El método assertAlmostEqual (x, y) en el marco de pruebas unitarias de Python prueba si xy yson aproximadamente iguales asumiendo que son flotantes.

El problema assertAlmostEqual()es que solo funciona con flotadores. Estoy buscando un método assertAlmostEqual()que funcione en listas de flotantes, conjuntos de flotantes, diccionarios de flotantes, tuplas de flotantes, listas de tuplas de flotantes, conjuntos de listas de flotantes, etc.

Por ejemplo, vamos x = 0.1234567890, y = 0.1234567891. xy yson casi iguales porque coinciden en todos y cada uno de los dígitos excepto en el último. Por tanto self.assertAlmostEqual(x, y)es Trueporque assertAlmostEqual()funciona para flotadores.

Estoy buscando uno más genérico assertAlmostEquals()que también evalúe las siguientes llamadas a True:

  • self.assertAlmostEqual_generic([x, x, x], [y, y, y]).
  • self.assertAlmostEqual_generic({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}).
  • self.assertAlmostEqual_generic([(x,x)], [(y,y)]).

¿Existe tal método o tengo que implementarlo yo mismo?

Aclaraciones:

  • assertAlmostEquals()tiene un parámetro opcional llamado placesy los números se comparan calculando la diferencia redondeada al número de decimales places. Por defecto places=7, por self.assertAlmostEqual(0.5, 0.4)lo tanto, es Falso mientras que self.assertAlmostEqual(0.12345678, 0.12345679)es Verdadero. Mi especulativo assertAlmostEqual_generic()debería tener la misma funcionalidad.

  • Dos listas se consideran casi iguales si tienen números casi iguales en exactamente el mismo orden. formalmente, for i in range(n): self.assertAlmostEqual(list1[i], list2[i]).

  • De manera similar, dos conjuntos se consideran casi iguales si se pueden convertir en listas casi iguales (asignando un orden a cada conjunto).

  • De manera similar, dos diccionarios se consideran casi iguales si el conjunto de claves de cada diccionario es casi igual al conjunto de claves del otro diccionario, y para cada par de claves casi iguales hay un valor correspondiente casi igual.

  • En general: considero dos colecciones casi iguales si son iguales, excepto por algunos flotadores correspondientes que son casi iguales entre sí. En otras palabras, realmente me gustaría comparar objetos pero con una precisión baja (personalizada) al comparar flotantes en el camino.


¿Cuál es el punto de usar floatclaves en el diccionario? Dado que no puede estar seguro de obtener exactamente el mismo flotador, nunca encontrará sus artículos mediante la búsqueda. Y si no utiliza la función de búsqueda, ¿por qué no utilizar una lista de tuplas en lugar de un diccionario? El mismo argumento se aplica a los conjuntos.
máximo

Solo un enlace a la fuente de assertAlmostEqual.
djvg

Respuestas:


71

si no le importa usar NumPy (que viene con su Python (x, y)), es posible que desee ver el np.testingmódulo que define, entre otros, una assert_almost_equalfunción.

La firma es np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)

>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError: 
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)

4
Eso está cerca, pero numpy.testinglos métodos casi iguales funcionan solo en números, matrices, tuplas y listas. No funcionan con diccionarios, conjuntos y colecciones de colecciones.
snakile

De hecho, pero eso es un comienzo. Además, tiene acceso al código fuente que puede modificar para permitir la comparación de diccionarios, colecciones, etc. np.testing.assert_equalreconoce los diccionarios como argumentos, por ejemplo (incluso si la comparación la realiza un ==que no funcionará para usted).
Pierre GM

Por supuesto, aún tendrá problemas al comparar conjuntos, como mencionó @BrenBarn.
Pierre GM

Tenga en cuenta que la documentación actual de assert_array_almost_equalrecomienda usar assert_allclose, assert_array_almost_equal_nulpo en su assert_array_max_ulplugar.
phunehehe

10

A partir de Python 3.5, puede comparar usando

math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)

Como se describe en pep-0485 . La implementación debe ser equivalente a

abs(a-b) <= max( rel_tol * max(abs(a), abs(b)), abs_tol )

7
¿Cómo ayuda esto a comparar los contenedores con los flotadores, sobre lo que se estaba preguntando?
máximo

9

Así es como he implementado una is_almost_equal(first, second)función genérica :

Primero, duplique los objetos que necesita comparar ( firsty second), pero no haga una copia exacta: corte los dígitos decimales insignificantes de cualquier flotador que encuentre dentro del objeto.

Ahora que tiene copias de firsty secondpara las cuales desaparecieron los dígitos decimales insignificantes, simplemente compare firsty seconduse el ==operador.

Supongamos que tenemos una cut_insignificant_digits_recursively(obj, places)función que se duplica objpero deja solo los placesdígitos decimales más significativos de cada flotante en el original obj. Aquí hay una implementación funcional de is_almost_equals(first, second, places):

from insignificant_digit_cutter import cut_insignificant_digits_recursively

def is_almost_equal(first, second, places):
    '''returns True if first and second equal. 
    returns true if first and second aren't equal but have exactly the same
    structure and values except for a bunch of floats which are just almost
    equal (floats are almost equal if they're equal when we consider only the
    [places] most significant digits of each).'''
    if first == second: return True
    cut_first = cut_insignificant_digits_recursively(first, places)
    cut_second = cut_insignificant_digits_recursively(second, places)
    return cut_first == cut_second

Y aquí hay una implementación funcional de cut_insignificant_digits_recursively(obj, places):

def cut_insignificant_digits(number, places):
    '''cut the least significant decimal digits of a number, 
    leave only [places] decimal digits'''
    if  type(number) != float: return number
    number_as_str = str(number)
    end_of_number = number_as_str.find('.')+places+1
    if end_of_number > len(number_as_str): return number
    return float(number_as_str[:end_of_number])

def cut_insignificant_digits_lazy(iterable, places):
    for obj in iterable:
        yield cut_insignificant_digits_recursively(obj, places)

def cut_insignificant_digits_recursively(obj, places):
    '''return a copy of obj except that every float loses its least significant 
    decimal digits remaining only [places] decimal digits'''
    t = type(obj)
    if t == float: return cut_insignificant_digits(obj, places)
    if t in (list, tuple, set):
        return t(cut_insignificant_digits_lazy(obj, places))
    if t == dict:
        return {cut_insignificant_digits_recursively(key, places):
                cut_insignificant_digits_recursively(val, places)
                for key,val in obj.items()}
    return obj

El código y sus pruebas unitarias están disponibles aquí: https://github.com/snakile/approximate_comparator . Doy la bienvenida a cualquier mejora y corrección de errores.


En lugar de comparar flotadores, ¿está comparando cuerdas? De acuerdo ... Pero entonces, ¿no sería más fácil establecer un formato común? ¿Me gusta fmt="{{0:{0}f}}".format(decimals)y usa este fmtformato para "encadenar" tus flotadores?
Pierre GM

1
Esto se ve bien, pero un pequeño punto: placesda el número de lugares decimales, no el número de cifras significativas. Por ejemplo, comparar 1024.123y 1023.999con 3 significantes debería devolver igual, pero con 3 decimales no lo son.
Rodney Richardson

1
@pir, la licencia no está definida. Vea la respuesta de snalile en este número en el que dice que no tiene tiempo para elegir / agregar una licencia, pero otorga permisos de uso / modificación. Gracias por compartir esto, por cierto.
Jérôme

1
@RodneyRichardson, sí, estos son lugares decimales, como en assertAlmostEqual : "Tenga en cuenta que estos métodos redondean los valores al número dado de lugares decimales (es decir, como la función round ()) y dígitos no significativos".
Jérôme

2
@ Jérôme, gracias por el comentario. Acabo de agregar una licencia del MIT.
Snakile

5

Si no le importa usar el numpypaquete, entonces numpy.testingtiene el assert_array_almost_equalmétodo.

Esto funciona para array_likeobjetos, por lo que está bien para matrices, listas y tuplas de flotantes, pero no funciona para conjuntos y diccionarios.

La documentación está aquí .


4

No existe tal método, tendría que hacerlo usted mismo.

Para listas y tuplas, la definición es obvia, pero tenga en cuenta que los otros casos que menciona no son obvios, por lo que no es de extrañar que no se proporcione dicha función. Por ejemplo, es {1.00001: 1.00002}casi igual a {1.00002: 1.00001}? Manejar tales casos requiere tomar una decisión sobre si la cercanía depende de claves, valores o ambos. En el caso de conjuntos, es poco probable que encuentre una definición significativa, ya que los conjuntos no están ordenados, por lo que no existe la noción de elementos "correspondientes".


BrenBarn: He añadido aclaraciones a la pregunta. La respuesta a su pregunta es que {1.00001: 1.00002}casi es igual {1.00002: 1.00001}si y solo si 1.00001 casi es igual a 1.00002. De forma predeterminada, no son casi iguales (porque la precisión predeterminada es de 7 lugares decimales), pero para un valor lo suficientemente pequeño places, son casi iguales.
snakile

1
@BrenBarn: En mi opinión, el uso de claves de tipo floatin dict debería desalentarse (y tal vez incluso no permitirse) por razones obvias. La igualdad aproximada de dict debe basarse únicamente en valores; el marco de prueba no necesita preocuparse por el uso incorrecto de las floatclaves for. En el caso de los conjuntos, se pueden ordenar antes de la comparación y se pueden comparar las listas ordenadas.
máximo

2

Es posible que tenga que implementarlo usted mismo, aunque es cierto que la lista y los conjuntos se pueden iterar de la misma manera, los diccionarios son una historia diferente, usted itera sus claves, no sus valores, y el tercer ejemplo me parece un poco ambiguo, ¿quiere decir compare cada valor dentro del conjunto, o cada valor de cada conjunto.

aquí hay un fragmento de código simple.

def almost_equal(value_1, value_2, accuracy = 10**-8):
    return abs(value_1 - value_2) < accuracy

x = [1,2,3,4]
y = [1,2,4,5]
assert all(almost_equal(*values) for values in zip(x, y))

Gracias, la solución es correcta para listas y tuplas, pero no para otros tipos de colecciones (o colecciones anidadas). Vea las aclaraciones que agregué a la pregunta. Espero que mi intención esté clara ahora. Dos conjuntos son casi iguales si se hubieran considerado iguales en un mundo donde los números no se miden con mucha precisión.
snakile

0

Ninguna de estas respuestas me funciona. El siguiente código debería funcionar para colecciones, clases, clases de datos y tuplas con nombre de Python. Puede que haya olvidado algo, pero hasta ahora esto funciona para mí.

import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any


def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
    """
    Compares two objects by recursively walking them trough. Equality is as usual except for floats.
    Floats are compared according to the two measures defined below.

    :param o1: The first object.
    :param o2: The second object.
    :param max_abs_ratio_diff: The maximum allowed absolute value of the difference.
    `abs(1 - (o1 / o2)` and vice-versa if o2 == 0.0. Ignored if < 0.
    :param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if < 0.
    :return: Whether the two objects are almost equal.
    """
    if type(o1) != type(o2):
        return False

    composite_type_passed = False

    if hasattr(o1, '__slots__'):
        if len(o1.__slots__) != len(o2.__slots__):
            return False
        if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
                                    max_abs_ratio_diff, max_abs_diff)
            for s1, s2 in zip(sorted(o1.__slots__), sorted(o2.__slots__))):
            return False
        else:
            composite_type_passed = True

    if hasattr(o1, '__dict__'):
        if len(o1.__dict__) != len(o2.__dict__):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2))
            in zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
            if not k1.startswith('__')):  # avoid infinite loops
            return False
        else:
            composite_type_passed = True

    if isinstance(o1, dict):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2)) in zip(sorted(o1.items()), sorted(o2.items()))):
            return False

    elif any(issubclass(o1.__class__, c) for c in (list, tuple, set)):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for v1, v2 in zip(o1, o2)):
            return False

    elif isinstance(o1, float):
        if o1 == o2:
            return True
        else:
            if max_abs_ratio_diff > 0:  # if max_abs_ratio_diff < 0, max_abs_ratio_diff is ignored
                if o2 != 0:
                    if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
                        return False
                else:  # if both == 0, we already returned True
                    if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
                        return False
            if 0 < max_abs_diff < abs(o1 - o2):  # if max_abs_diff < 0, max_abs_diff is ignored
                return False
            return True

    else:
        if not composite_type_passed:
            return o1 == o2

    return True


class EqualityTest(unittest.TestCase):

    def test_floats(self) -> None:
        o1 = ('hi', 3, 3.4)
        o2 = ('hi', 3, 3.400001)
        self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
        self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))

    def test_ratio_only(self):
        o1 = ['hey', 10000, 123.12]
        o2 = ['hey', 10000, 123.80]
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))

    def test_diff_only(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 1234567890.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, 1))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))

    def test_both_ignored(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 0.80]
        o3 = ['hi', 10000, 0.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, -1))
        self.assertFalse(are_almost_equal(o1, o3, -1, -1))

    def test_different_lengths(self):
        o1 = ['hey', 1234567890.12, 10000]
        o2 = ['hey', 1234567890.80]
        self.assertFalse(are_almost_equal(o1, o2, 1, 1))

    def test_classes(self):
        class A:
            d = 12.3

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
        self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))

        o2.hello = 'hello'
        self.assertFalse(are_almost_equal(o1, o2, -1, -1))

    def test_namedtuples(self):
        B = namedtuple('B', ['x', 'y'])
        o1 = B(3.3, 4.4)
        o2 = B(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))

    def test_classes_with_slots(self):
        class C(object):
            __slots__ = ['a', 'b']

            def __init__(self, a, b):
                self.a = a
                self.b = b

        o1 = C(3.3, 4.4)
        o2 = C(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))

    def test_dataclasses(self):
        @dataclass
        class D:
            s: str
            i: int
            f: float

        @dataclass
        class E:
            f2: float
            f4: str
            d: D

        o1 = E(12.3, 'hi', D('hello', 34, 20.01))
        o2 = E(12.1, 'hi', D('hello', 34, 20.0))
        self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))

        o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
        self.assertFalse(are_almost_equal(o2, o3, -1, -1))

    def test_ordereddict(self):
        o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))

0

Todavía lo usaría self.assertEqual()porque sigue siendo el más informativo cuando la mierda golpea el ventilador. Puede hacerlo redondeando, por ejemplo.

self.assertEqual(round_tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))

donde round_tupleesta

def round_tuple(t: tuple, ndigits: int) -> tuple:
    return tuple(round(e, ndigits=ndigits) for e in t)

def round_list(l: list, ndigits: int) -> list:
    return [round(e, ndigits=ndigits) for e in l]

De acuerdo con los documentos de Python (ver https://stackoverflow.com/a/41407651/1031191 ) puede salirse con la suya con problemas de redondeo como 13.94999999, porque 13.94999999 == 13.95es True.


-1

Un enfoque alternativo es convertir sus datos en una forma comparable, por ejemplo, convirtiendo cada flotante en una cadena con precisión fija.

def comparable(data):
    """Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
    if isinstance(data, (int, str)):
        return data
    if isinstance(data, float):
        return '{:.4f}'.format(data)
    if isinstance(data, list):
        return [comparable(el) for el in data]
    if isinstance(data, tuple):
        return tuple([comparable(el) for el in data])
    if isinstance(data, dict):
        return {k: comparable(v) for k, v in data.items()}

Entonces tú puedes:

self.assertEquals(comparable(value1), comparable(value2))
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.