2017-08-22 71 views
5

Potrzebuję wykonać integrację numeryczną w 6D w pythonie. Ponieważ funkcja scipy.integrate.nquad działa wolno, obecnie próbuję przyspieszyć działanie, definiując integand jako scipy.LowLevelCallable z Numba.W jaki sposób można wdrożyć wywołanie C z Numba w celu wydajnej integracji z NQAD?

I był w stanie w tym 1D z scipy.integrate.quad przez replikację przykład podany here:

import numpy as np 
from numba import cfunc 
from scipy import integrate 

def integrand(t): 
    return np.exp(-t)/t**2 

nb_integrand = cfunc("float64(float64)")(integrand) 

# regular integration 
%timeit integrate.quad(integrand, 1, np.inf) 

10000 pętle, najlepiej od 3: 128 mikrosekundy na pętli

# integration with compiled function 
%timeit integrate.quad(nb_integrand.ctypes, 1, np.inf) 

100000 pętli, najlepiej z 3: 7.08 μs na pętlę

Kiedy chcę to zrobić teraz z nquadem, dokumentacja nquad mówi:

If the user desires improved integration performance, then f may be a scipy.LowLevelCallable with one of the signatures:

double func(int n, double *xx) 
double func(int n, double *xx, void *user_data) 

where n is the number of extra parameters and args is an array of doubles of the additional parameters, the xx array contains the coordinates. The user_data is the data contained in the scipy.LowLevelCallable.

Ale następujący kod daje mi błąd:

from numba import cfunc 
import ctypes 

def func(n_arg,x): 
    xe = x[0] 
    xh = x[1] 
    return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh) 

nb_func = cfunc("float64(int64,CPointer(float64))")(func) 

integrate.nquad(nb_func.ctypes, [[0,1],[0,1]], full_output=True) 

błędzie: Quad: Pierwszy argument jest wskaźnikiem funkcji ctypes z nieprawidłowym podpisem

Czy to możliwe, aby skompilować funkcję z Numba że może być używany z nquad bezpośrednio w kodzie i bez definiowania funkcji w pliku zewnętrznym?

Dziękuję bardzo z góry!

Odpowiedz

3

Owijanie funkcji w scipy.LowLevelCallable sprawia nquad szczęśliwa:

si.nquad(sp.LowLevelCallable(nb_func.ctypes), [[0,1],[0,1]], full_output=True) 
# (-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323})