2017-01-29 38 views

Odpowiedz

21

Tak, jest to możliwe. Po prostu twórz obiekty samodzielnie, np.

import torch.utils.data as data_utils 

train = data_utils.TensorDataset(features, targets) 
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True) 

gdzie features i targets są tensory. features musi być 2-D, tj. Macierz, gdzie każda linia reprezentuje jedną próbkę szkoleniową, a targets może być 1-D lub 2-D, w zależności od tego, czy próbujesz przewidzieć skalar lub wektor.

Nadzieję, że pomaga!


EDIT: odpowiedź na pytanie @ Sarthak za

Zasadniczo tak. Jeśli utworzyć obiekt typu TensorData, wówczas konstruktor bada, czy pierwsze wymiary tensora funkcji (co jest faktycznie zwane data_tensor) oraz tensor docelową (tzw target_tensor) mają taką samą długość:

assert data_tensor.size(0) == target_tensor.size(0) 

jednak jeśli chcesz później wprowadzić te dane do sieci neuronowej, musisz zachować ostrożność. Podczas gdy warstwy splotu działają na dane takie jak twoje, (myślę), wszystkie inne typy warstw oczekują danych w formie macierzy. Tak więc, jeśli napotkasz taki problem, to prostym rozwiązaniem byłoby przekonwertowanie twojego zestawu danych 4D (podanego jako pewnego rodzaju tensor, np. FloatTensor) do macierzy przy użyciu metody view. Dla Państwa 5000xnxnx3 zbiorze, będzie to wyglądać tak:

2d_dataset = 4d_dataset.view(5000, -1) 

(. Wartość -1 mówi PyTorch aby dowiedzieć się długość drugim wymiarze automatycznie)

+0

Mam funkcje 3D: 2D dla obrazu i jeden dodatkowy wymiar dla kanałów kolorów. Czy nadal będzie działać, jeśli przekażę funkcje jako 5000xnxnx3. 5000 to liczba punktów danych nxnx3 to rozmiar obrazu – Sarthak

+0

Zasadniczo tak, ale sprawdź edycję mojej odpowiedzi. – pho7

+0

Zestaw danych 4d można przekazać jako funkcje, dlatego nie ma potrzeby stosowania instrukcji widoku. – Sarthak

5

Można łatwo zrobić to być rozszerzenie klasy data.Dataset . Zgodnie z API, wszystko co musisz zrobić, to zaimplementować dwie funkcje: __getitem__ i __len__.

Następnie można zawinąć zestaw danych za pomocą DataLoadera, jak pokazano w interfejsie API oraz w odpowiedzi @ pho7.

Myślę, że klasa ImageFolder jest odniesieniem. Zobacz kod here.