Podemos hacer esto de dos maneras simples . El primero es fácil de codificar, fácil de entender y razonablemente rápido. El segundo es un poco más complicado, pero mucho más eficiente para este tamaño de problema que el primer método u otros enfoques mencionados aquí.
Método 1 : rápido y sucio.
Para obtener una sola observación de la distribución de probabilidad de cada fila, simplemente podemos hacer lo siguiente.
# Q is the cumulative distribution of each row.
Q <- t(apply(P,1,cumsum))
# Get a sample with one observation from the distribution of each row.
X <- rowSums(runif(N) > Q) + 1
Esto produce la distribución acumulativa de cada fila de Py luego muestrea una observación de cada distribución. Tenga en cuenta que si podemos reutilizar P entonces podemos calcular Quna vez y guárdelo para su uso posterior. Sin embargo, la pregunta necesita algo que funcione para otroP en cada iteración
Si necesitas múltiples (n) observaciones de cada fila, luego reemplace la última línea con la siguiente.
# Returns an N x n matrix
X <- replicate(n, rowSums(runif(N) > Q)+1)
En general, esta no es una forma extremadamente eficiente de hacerlo, pero sí aprovecha las R
capacidades de vectorización, que generalmente es el principal determinante de la velocidad de ejecución. También es sencillo de entender.
Método 2 : Concatenar los cdfs.
Supongamos que tenemos una función que toma dos vectores, el segundo de los cuales se clasificó en orden monotónicamente no decreciente y encontró el índice en el segundo vector del límite inferior más grande de cada elemento en el primero. Entonces, podríamos usar esta función y un truco ingenioso: simplemente cree la suma acumulativa de los cdf de todas las filas. Esto da un vector monotónicamente creciente con elementos en el rango[0,N].
Aquí está el código.
i <- 0:(N-1)
# Cumulative function of the cdfs of each row of P.
Q <- cumsum(t(P))
# Find the interval and then back adjust
findInterval(runif(N)+i, Q)-i*K+1
Observe lo que hace la última línea, crea variables aleatorias distribuidas en (0,1),(1,2),…,(N−1,N)y luego llama findInterval
para encontrar el índice del límite inferior más grande de cada entrada. Entonces, esto nos dice que el primer elemento de runif(N)+i
se encontrará entre el índice 1 y el índiceK, el segundo se encontrará entre el índice K+1 y 2K, etc., cada uno según la distribución de la fila correspondiente de P. Luego, necesitamos volver a transformar para volver a colocar cada uno de los índices en el rango{1,…,K}.
Debido a que findInterval
es rápido tanto desde el punto de vista algorítmico como de implementación, este método resulta ser extremadamente eficiente.
Un punto de referencia
En mi vieja computadora portátil (MacBook Pro, 2.66 GHz, 8GB RAM), probé esto con N=10000 y K=100 y generando 5000 muestras de tamaño N, exactamente como se sugiere en la pregunta actualizada, para un total de 50 millones de variantes aleatorias.
El código para el Método 1 tardó casi exactamente 15 minutos en ejecutarse, o alrededor de 55,000 variantes aleatorias por segundo. El código para el Método 2 tardó aproximadamente cuatro minutos y medio en ejecutarse, o alrededor de 183 mil variantes aleatorias por segundo.
Aquí está el código por el bien de la reproducibilidad. (Tenga en cuenta que, como se indica en un comentario,Q se recalcula para cada una de las 5000 iteraciones para simular la situación del OP).
# Benchmark code
N <- 10000
K <- 100
set.seed(17)
P <- matrix(runif(N*K),N,K)
P <- P / rowSums(P)
method.one <- function(P)
{
Q <- t(apply(P,1,cumsum))
X <- rowSums(runif(nrow(P)) > Q) + 1
}
method.two <- function(P)
{
n <- nrow(P)
i <- 0:(n-1)
Q <- cumsum(t(P))
findInterval(runif(n)+i, Q)-i*ncol(P)+1
}
Aquí está la salida.
# Method 1: Timing
> system.time(replicate(5e3, method.one(P)))
user system elapsed
691.693 195.812 899.246
# Method 2: Timing
> system.time(replicate(5e3, method.two(P)))
user system elapsed
182.325 82.430 273.021
Postdata : Al observar el código findInterval
, podemos ver que realiza algunas verificaciones en la entrada para ver si hay NA
entradas o si el segundo argumento no está ordenado. Por lo tanto, si quisiéramos exprimir más el rendimiento de esto, podríamos crear nuestra propia versión modificada findInterval
que elimine estas comprobaciones que son innecesarias en nuestro caso.