W zeszłym tygodniu zadawałem pytania pokrewne dotyczące tego stosu, próbując wyizolować rzeczy, których nie rozumiałem, używając dekoratora @jit z Numba w Pythonie. Uderzam jednak w ścianę, więc po prostu napiszę cały problem.Odległość między segmentami za pomocą numba jit, Python
Problemem jest obliczenie minimalnej odległości między parami dużej liczby segmentów. Segmenty są reprezentowane przez ich początek i punkty końcowe w 3D. Matematycznie, każdy segment jest sparametryzowany jako [AB] = A + (B-A) * s, gdzie s w [0,1], i A i B są początkiem i punktami końcowymi segmentu. W przypadku dwóch takich segmentów można obliczyć minimalną odległość, a formułę podano here.
Już zdemaskowałem ten problem na innym thread, a otrzymana odpowiedź dotyczyła zastąpienia podwójnych pętli mojego kodu przez wektoryzację problemu, co jednak mogłoby spowodować problemy z pamięcią dla dużych zestawów segmentów. Dlatego postanowiłem trzymać się pętli i używać zamiast tego numby jit.
Ponieważ rozwiązanie tego problemu wymaga dużej ilości produktów dotowych, a produkt numpy's dotępny jest not supported by numba, zacząłem od wdrożenia mojego własnego produktu w postaci kropek 3D.
import numpy as np
from numba import jit, autojit, double, float64, float32, void, int32
def my_dot(a,b):
res = a[0]*b[0]+a[1]*b[1]+a[2]*b[2]
return res
dot_jit = jit(double(double[:], double[:]))(my_dot) #I know, it's not of much use here.
Funkcja obliczenie minimalnej odległości wszystkich par w segmentach N przyjmuje jako dane wejściowe tablicę NX6 współrzędnych (6)
def compute_stuff(array_to_compute):
N = len(array_to_compute)
con_mat = np.zeros((N,N))
for i in range(N):
for j in range(i+1,N):
p0 = array_to_compute[i,0:3]
p1 = array_to_compute[i,3:6]
q0 = array_to_compute[j,0:3]
q1 = array_to_compute[j,3:6]
s = (dot_jit((p1-p0),(q1-q0))*dot_jit((q1-q0),(p0-q0)) - dot_jit((q1-q0),(q1-q0))*dot_jit((p1-p0),(p0-q0)))/(dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(q1-q0)) - dot_jit((p1-p0),(q1-q0))**2)
t = (dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(p0-q0)) -dot_jit((p1-p0),(q1-q0))*dot_jit((p1-p0),(p0-q0)))/(dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(q1-q0)) - dot_jit((p1-p0),(q1-q0))**2)
con_mat[i,j] = np.sum((p0+(p1-p0)*s-(q0+(q1-q0)*t))**2)
return con_mat
fast_compute_stuff = jit(double[:,:](double[:,:]))(compute_stuff)
więc compute_stuff (Arg) wykonuje jako argument 2D np.array (double [:,:]), wykonuje kilka obsługiwanych przez numba (?) operacji i zwraca inne 2D np.array (double [:,:]). Jednak w przypadku każdej pętli otrzymuję 134 i 123 ms. Czy możesz rzucić trochę światła na to, dlaczego nie mogę przyspieszyć mojej funkcji? Wszelkie opinie będą mile widziane.
To jest * bardzo * mało prawdopodobne, że będziesz w stanie pokonać 'np.dot' używając kompilatora JIT firmy numba. 'np.dot' jest po prostu cienkim opakowaniem, które wywołuje funkcje BLAS' * gemm/* gemv', które są mocno zoptymalizowane i często wielowątkowe. Najlepszym wyjściem jest prawdopodobnie upewnienie się, że numpy jest połączona z najszybszą wielowątkową biblioteką BLAS, którą możesz zdobyć (prawdopodobnie albo MKL firmy Intel, albo OpenBLAS). –
Problem nie bije np.dot, problem polega na tym, że jeśli kompilator jit uruchamia się w wywołaniu np.dot, to nie może wywnioskować swojego typu zwracanego, a następnie nie przyspieszy mojej całej funkcji (i btw, dot_jit I kodowany jest szybszy niż np.dot dla produktów 3d skalarnych wektorowych) – Mathusalem
Czy wyposażyłeś swój oryginalny kod w linię? Podejrzewam, że większość czasu spędzasz wewnątrz 'n.dot' tak, więc nie powinieneś oczekiwać dużej wydajności, ponieważ JIT odciąga koszty od zagnieżdżonych pętli' for'. –