2017-04-02 77 views
6

I skonstruowanej sekwencyjnego modelu Keras z 35000 próbek wejściowych i 20 predyktory rozkład klas wyjściowe dane testowe są:Keras wyjście przewidywania wieloklasowego jest ograniczony do jednej grupy

  • Class_0 = 5,6%
  • Class_1 = 7,7%
  • Class_2 = 35,6%
  • Class_3 = 45,7%
  • Class_4 = 5,4%

Po przekształceniu wyjść na binarną macierz klas wykorzystującą (np_utils.to_categoryical) dokładność treningu wynosi około 54%, gdy wykonuję model dopasowania z danymi testowymi (15000 próbek), wszystkie prognozy (100%) zdarzają się w tej samej klasie który jest najwyższym wystąpieniem class_3 w wynikach szkolenia, jaka jest przyczyna tego błędu i nie ma jednej prognozy dla innych klas? jak sprawić, by model był wrażliwy na przewidywanie mniejszej liczby klas i poprawić dokładność, szczególnie jeśli zbieżność w danych treningowych jest niska jak 1 - 3%.

model = Sequential() 
model.add(Dense(40, input_dim=20, activation='relu')) 
model.add(Dropout(0.2)) 
model.add(Dense(10, activation='relu')) 
model.add(Dense(5, activation='softmax')) 
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) 

model.fit(X, Y, epochs=500, verbose=1) 

Odpowiedz

3

Najprostszym sposobem, aby zrównoważyć swoje straty jest:

  1. Zdefiniuj class_weights dict:

    class_weights_dict = {0: 1.0/0_class_freq, 1: 1.0/1_class_freq, ... } 
    

    Gdzie i_class_freq jest częstotliwość i-th klasie.

  2. Zmień swoją funkcję fit do:

    model.fit(X, Y, epochs=500, verbose=1, class_weight=class_weights_dict) 
    

Model określone powyżej powinny być równoważne do modelu z Bayessian ponownemu równoważeniu klas.

0

Jednym ze sposobów rozwiązania jest nadmierne pobieranie próbek w ramach reprezentowanych przykładów klas. tj. Jeśli masz dwie klasy A (66,6%) i B (33,3%), to próbka B dwa razy w porównaniu do A. Aby uczynić ją jeszcze prostszą, możesz zmienić swój zestaw danych, duplikując B raz i tworząc zbiór danych jak A + 2 * B.

Można również zmodyfikować funkcję utraty, która zwiększa wagę, gdy błędnie klasyfikuje nie reprezentowane klasy.

+0

vikasreddy, ty za sugestie. Czy możesz rozwinąć więcej na temat modyfikacji funkcji straty, aby nadać większą wagę, i dla pierwszej części pytania, jakiekolwiek wyjaśnienie, dlaczego sieć ignoruje przewidywanie reszty (54,3%) innych klas w szczególności klasy_2, która ma dużą liczbę obserwacji (35 %)? – Ray

+0

Jednym z powodów, dla których mogę myśleć, jest to, że model nie jest wystarczająco złożony, aby całkowicie zminimalizować funkcję utraty, ponieważ zatrzymuje się w lokalnym minimum, co zdarza się przewidywać, że wszystkie przykłady znajdują się w klasie_3. – vikasreddy

+0

Jako implementacja funkcji ważonej straty patrz https://github.com/fchollet/keras/issues/2115#issuecomment-204060456 – vikasreddy