2017-08-03 78 views
11

Załóżmy, że mam scipy.sparse.csr_matrix reprezentujących wartości poniżejscipy Rzadki cumSum

[[0 0 1 2 0 3 0 4] 
[1 0 0 2 0 3 4 0]] 

Chcę obliczyć skumulowaną sumę wartości niezerowych w miejscu, które mogłyby zmienić tablicę:

[[0 0 1 3 0 6 0 10] 
[1 0 0 3 0 6 10 0]] 

Rzeczywiste wartości to nie 1, 2, 3, ...

Liczba niezerowych wartości w każdym rzędzie jest mało prawdopodobna.

Jak zrobić to szybko?

Aktualny Program:

import scipy.sparse 
import numpy as np 

# sparse data 
a = scipy.sparse.csr_matrix(
    [[0,0,1,2,0,3,0,4], 
    [1,0,0,2,0,3,4,0]], 
    dtype=int) 

# method 
indptr = a.indptr 
data = a.data 
for i in range(a.shape[0]): 
    st = indptr[i] 
    en = indptr[i + 1] 
    np.cumsum(data[st:en], out=data[st:en]) 

# print result 
print(a.todense()) 

Wynik:

[[ 0 0 1 3 0 6 0 10] 
[ 1 0 0 3 0 6 10 0]] 
+0

dla kodu pracy, należy dodawać do https://codereview.stackexchange.com/ – Alexander

+6

Istnieje wiele bardziej 'NumPy/scipy' na oczy TAK niż na CR. Szybkie pytania na temat działającego kodu są cały czas odbierane przez SO, zwłaszcza, że ​​pakiety kodu są nieco wyspecjalizowane. – hpaulj

+6

'@r xu', to, co pokazujesz, wygląda dobrze. Zastosowanie wiersza 'cumsum 'po wierszu jest naprawdę tylko drogą do zrobienia. A twoje użycie 'out' jest sprytne. Istnieje też "itptr", który może trochę poprawić prędkość. – hpaulj

Odpowiedz

1

Jak ow ten sposób zamiast

a = np.array([[0,0,1,2,0,3,0,4], 
       [1,0,0,2,0,3,4,0]], dtype=int) 

b = a.copy() 
b[b > 0] = 1 
z = np.cumsum(a,axis=1) 
print(z*b) 

Daje

array([[ 0, 0, 1, 3, 0, 6, 0, 10], 
    [ 1, 0, 0, 3, 0, 6, 10, 0]]) 

Doing rzadki

def sparse(a): 
    a = scipy.sparse.csr_matrix(a) 

    indptr = a.indptr 
    data = a.data 
    for i in range(a.shape[0]): 
     st = indptr[i] 
     en = indptr[i + 1] 
     np.cumsum(data[st:en], out=data[st:en]) 


In[1]: %timeit sparse(a) 
10000 loops, best of 3: 167 µs per loop 

Korzystanie mnożenie

def mult(a): 
    b = a.copy() 
    b[b > 0] = 1 
    z = np.cumsum(a, axis=1) 
    z * b 

In[2]: %timeit mult(a) 
100000 loops, best of 3: 5.93 µs per loop 
+0

Sparse to ten sam kod, który podałeś, a mnożenie to kod, który podałem – DJK

+0

To oznacza ... metodę działającą na rzadkiej macierzy zamiast gęstej macierzy. Twoja funkcja "mult" działa obecnie na gęstej macierzy. –