2016-01-26 14 views
6

Korzystanie z funkcji StratifiedKFold sklearn, czy ktoś może mi pomóc zrozumieć błąd tutaj?StratifiedKFold: IndexError: zbyt wiele indeksów dla tablicy

Domyślam się, że ma to coś wspólnego z moją tablicą wejściową etykiet. Zauważam, że kiedy je drukuję (pierwsze 16 w tym przykładzie) indeksowanie wynosi od 0 do 15, ale nad tekstem nadrukowane jest dodatkowe 0 Nie spodziewałem się. Może jestem po prostu Pythonem noobem, ale to wygląda dziwnie.

Ktoś tu widzi ten głupek?

Dokumentacja: http://scikit-learn.org...StratifiedKFold.html

Kod:

import nltk 
import sklearn 

print('The nltk version is {}.'.format(nltk.__version__)) 
print('The scikit-learn version is {}.'.format(sklearn.__version__)) 

print type(skew_gendata_targets.values), skew_gendata_targets.values.shape 
print skew_gendata_targets.head(16) 

skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121) 

Wynik

The nltk version is 3.1. 
The scikit-learn version is 0.17. 
<type 'numpy.ndarray'> (500L, 1L) 
    0 
0 0 
1 0 
2 0 
3 0 
4 0 
5 0 
6 0 
7 0 
8 0 
9 0 
10 0 
11 0 
12 0 
13 0 
14 1 
15 0 
--------------------------------------------------------------------------- 
IndexError        Traceback (most recent call last) 
<ipython-input-373-653b6010b806> in <module>() 
     8 print skew_gendata_targets.head(16) 
     9 
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121) 
    11 
    12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')' 

d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state) 
    531   for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)): 
    532    for label, (_, test_split) in zip(unique_labels, per_label_splits): 
--> 533     label_test_folds = test_folds[y == label] 
    534     # the test split can be too big because we used 
    535     # KFold(max(c, self.n_folds), self.n_folds) instead of 

IndexError: too many indices for array 

Odpowiedz

11

Sprawdź kształt skew_gendata_targets.values. Zobaczysz, że nie jest to tablica 1d (kształt (500,)), tak jak oczekiwał StratifiedKFold, ale raczej tablica (500,1). SKlearn traktuje je osobno, zamiast je zmuszać, aby były takie same. Daj mi znać, jeśli to pomoże

+0

Ten wydruk znajduje się na wyjściu w pytaniu: typ wydruku (skew_gendata_targets.values), skew_gendata_targets.values.shape, to tablica (500,1) numpy. Jestem ćpunem matlabowym wrzuconym do dołu pytonów i nie znam różnicy pomiędzy matrycą 500x1 a matrycą 500xnada/tablicą/rzeczą. Przynajmniej w świecie matlabów nie ma różnicy. –

+2

Tak, to niefortunne i nieco mylące. Różnica jest ważna podczas wykonywania operacji takich jak "*". W jednym przypadku Panda/numpy dokona mnożenia elementarnego, podczas gdy wykona mnożenie macierzy na drugim. Mam nadzieję, że operacja StratifiedKFold zadziałała po przymuszeniu go do macierzy (500,). – Brian

+1

Widzę, zmiana kształtu matricies jest czymś, co Matlaber może zrozumieć, wydaje się, że to naprawiło: np.reshape (skew_gendata_targets.values, [500,]), dzięki! –