Mam serwer z wieloma procesorami graficznymi i chcę w pełni z nich korzystać podczas wnioskowania modelu wewnątrz aplikacji java. Domyślnie tensorflow zajmuje wszystkie dostępne GPU, ale używa tylko pierwszego.Jądro Multi-GPU Tensorflow Java
mogę myśleć o trzech opcji rozwiązania tego problemu:
Ogranicz widoczność urządzenia na poziomie procesu, a mianowicie za pomocą
CUDA_VISIBLE_DEVICES
zmienną środowiskową.Wymagałoby to uruchomienia kilku wystąpień aplikacji Java i rozpowszechniania ruchu między nimi. Nie ten kuszący pomysł.
uruchomienie kilku sesji wewnątrz jednej aplikacji i spróbować przypisać jedno urządzenie do każdego z nich poprzez
ConfigProto
:public class DistributedPredictor { private Predictor[] nested; private int[] counters; // ... public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) { nested = new Predictor[numDevices]; counters = new int[numDevices]; for (int i = 0; i < nested.length; i++) { nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice); } } public Prediction predict(Data data) { int i = acquirePredictorIndex(); Prediction result = nested[i].predict(data); releasePredictorIndex(i); return result; } private synchronized int acquirePredictorIndex() { int i = argmin(counters); counters[i] += 1; return i; } private synchronized void releasePredictorIndex(int i) { counters[i] -= 1; } } public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { GPUOptions gpuOptions = GPUOptions.newBuilder() .setVisibleDeviceList("" + deviceIdx) .setAllowGrowth(true) .build(); ConfigProto config = ConfigProto.newBuilder() .setGpuOptions(gpuOptions) .setInterOpParallelismThreads(numDevices * numThreadsPerDevice) .build(); byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); Graph graph = new Graph(); graph.importGraphDef(graphDef); this.session = new Session(graph, config.toByteArray()); } public Prediction predict(Data data) { // ... } }
Takie podejście wydaje się działać dobrze na pierwszy rzut oka. Jednak sesje sporadycznie ignorują opcję
setVisibleDeviceList
i wszystkie idą na pierwsze urządzenie powodujące awarię z powodu braku pamięci.Zbuduj model w stylu wielopiętrowym w pythonie, używając specyfikacji
tf.device()
. Po stronie Java podaj różne wieże wewnątrz współdzielonej sesji.Czuje się uciążliwy i idiomatycznie zły dla mnie.
UPDATE: jako @ash zaproponowano tam jeszcze innej opcji:
przypisać odpowiednie urządzenie dla każdej operacji istniejącego wykresu poprzez zmianę określenia (
graphDef
).Aby to zrobić, można dostosować kod Metoda 2:
public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); graphDef = setGraphDefDevice(graphDef, deviceIdx) Graph graph = new Graph(); graph.importGraphDef(graphDef); ConfigProto config = ConfigProto.newBuilder() .setAllowSoftPlacement(true) .build(); this.session = new Session(graph, config.toByteArray()); } private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException { String deviceString = String.format("/gpu:%d", deviceIdx); GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder(); for (int i = 0; i < builder.getNodeCount(); i++) { builder.getNodeBuilder(i).setDevice(deviceString); } return builder.build().toByteArray(); } public Prediction predict(Data data) { // ... } }
Podobnie jak w innych wymienionych podejść, ten nie uwolnij mnie od ręcznie dystrybucji danych między urządzeniami. Ale przynajmniej działa stabilnie i jest porównywalnie łatwa do wdrożenia. Ogólnie rzecz biorąc, wygląda to jak (prawie) normalna technika.
Czy istnieje elegancki sposób na wykonanie tak podstawowej rzeczy za pomocą interfejsu API tensorflow java? Wszelkie pomysły będą mile widziane.