2016-07-15 17 views
6

Kiedy obliczyć momenty trzeciego rzędu z matrycy X z N wierszy i n kolumnach Zwykle używam einsum:Alternatywy NumPy einsum

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N 

Działa to zwykle w porządku, ale teraz pracuję z większymi wartościami, mianowicie n = 120 i N = 100000 i einsum zwraca następujący błąd:

ValueError: iterator is too large

alternatywą robić pętle zagnieżdżone 3 jest unfeasable, więc zastanawiam się, czy istnieje jakaś alternatywa.

Odpowiedz

4

Należy zauważyć, że obliczenia tego będzie trzeba zrobić co najmniej ~ n × N = 173 miliardów operacji (nie biorąc pod uwagę symetrię), więc będzie to powolny chyba numpy ma dostęp do GPU czy coś. Na nowoczesnym komputerze z procesorem ~ 3 GHz oczekuje się, że całe obliczenie zajmie około 60 sekund, przy założeniu braku przyspieszenia SIMD/równoległego.


Do testowania zacznijmy N = 1000. Będziemy wykorzystywać to do sprawdzenia poprawności i wykonanie:

#!/usr/bin/env python3 

import numpy 
import time 

numpy.random.seed(0) 

n = 120 
N = 1000 
X = numpy.random.random((N, n)) 

start_time = time.time() 

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X) 

end_time = time.time() 

print('check:', M3[2,4,6], '= 125.401852515?') 
print('check:', M3[4,2,6], '= 125.401852515?') 
print('check:', M3[6,4,2], '= 125.401852515?') 
print('check:', numpy.sum(M3), '= 218028826.631?') 
print('total time =', end_time - start_time) 

To trwa około 8 sekund. To jest linia podstawowa.

Zacznijmy z 3 zagnieżdżonych pętli jako alternatywę:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l]) 
# ~27 seconds 

To trwa około pół minuty, nie jest dobre! Jednym z powodów jest fakt, że są to właściwie cztery zagnieżdżone pętle: numpy.sum można również uznać za pętlę.

Zauważmy, że suma może być przekształcony w iloczynu skalarnego usunąć ten 4th pętlę:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     for l in range(n): 
      M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l] 
# 14 seconds 

teraz znacznie lepiej, ale wciąż powolne. Ale zauważamy, że produkt kropki można zmienić na mnożenie macierzy w celu usunięcia jednej pętli:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    for k in range(n): 
     M3[j,k] = X[:,j] * X[:,k] @ X 
# ~0.5 seconds 

Huh? Teraz jest to o wiele bardziej wydajne niż einsum! Możemy również sprawdzić, czy odpowiedź rzeczywiście powinna być poprawna.

Czy możemy pójść dalej? Tak! Możemy wyeliminować pętlę k przez:

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = numpy.repeat(X[:,j], n).reshape((N, n)) 
    M3[j] = (Y * X).T @ X 
# ~0.3 seconds 

możemy również użyć Broadcasting (tj a * [b,c] == [a*b, a*c] dla każdego rzędu X), aby uniknąć wykonując numpy.repeat (dzięki @Divakar):

M3 = numpy.zeros((n, n, n)) 
for j in range(n): 
    Y = X[:,j].reshape((N, 1)) 
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j] 
    M3[j] = (Y * X).T @ X 
# ~0.16 seconds 

Jeśli skalowanie to do N = 100000 program ma zająć 16 sekund, co mieści się w teoretycznym limicie, więc wyeliminowanie j może nie pomóc zbyt wiele (ale to może utrudnić zrozumienie kodu). Moglibyśmy zaakceptować to jako ostateczne rozwiązanie.


Uwaga: Jeśli używasz Python 2, a @ b jest równoważna a.dot(b).

+0

świetna odpowiedź, dziękuję! –

+0

Świetny pomysł naprawdę. Jeśli mogę dodać tu trochę transmisji, możemy uniknąć tworzenia "Y" i bezpośrednio uzyskać wynik iteratywny: '(X [:, None, j] * X) .T @ X'. To powinno nam zwiększyć wydajność. – Divakar

+0

@Divakar: Dzięki! Zaktualizowano. – kennytm