tgan.trainer module
GAN Models.
- class tgan.trainer.GANTrainer(model, input_queue)
Bases: tensorpack.train.tower.TowerTrainer
GanTrainer model.
We need to set tower_func() because it's a TowerTrainer, and only TowerTrainer supports automatic graph creation for inference during training.
If we don't care about inference during training, using tower_func() is not needed. Just calling model.build_graph() directly is OK.
- Parameters
input_queue (tensorpack.input_source.QueueInput) – Data input.
model (tgan.GAN.GANModelDesc) – Model to train.
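The remark about tower_func() can be illustrated with a toy, framework-free sketch. Nothing below is tensorpack or tgan code: ToyTowerTrainer, register_tower_func, make_predictor and this build_graph are hypothetical stand-ins. They only assume that a trainer which keeps a handle on the graph-building function can call it a second time to create an inference copy of the graph.

```python
class ToyTowerTrainer:
    """Hypothetical stand-in for a tower-based trainer (not tensorpack code)."""

    def __init__(self):
        self._tower_func = None
        self.built = []  # records each call to the graph-building function

    def register_tower_func(self, func):
        # Keep the graph-building function instead of calling it once and losing it.
        self._tower_func = func

    def build_training(self):
        self.built.append("train")
        return self._tower_func(training=True)

    def make_predictor(self):
        # Possible only because the trainer still holds the function.
        if self._tower_func is None:
            raise RuntimeError("no tower function registered")
        self.built.append("inference")
        return self._tower_func(training=False)


def build_graph(training):
    # Stand-in for model.build_graph(): returns a description, not real ops.
    return {"mode": "train" if training else "inference", "dropout": training}


trainer = ToyTowerTrainer()
trainer.register_tower_func(build_graph)
train_graph = trainer.build_training()
infer_graph = trainer.make_predictor()
```

If build_graph were called directly and never registered, make_predictor would have nothing to call, which is the trade-off the docstring describes.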
- class tgan.trainer.MultiGPUGANTrainer(nr_gpu, input, model)
Bases: tensorpack.train.tower.TowerTrainer
A replacement for GANTrainer (optimize d and g one by one) with multi-GPU support.
- Parameters
nr_gpu (int) – Number of GPUs to use.
input (tensorpack.input_source.QueueInput) – Data input.
model (tgan.GAN.GANModelDesc) – Model to train.
- class tgan.trainer.SeparateGANTrainer(input, model, d_period=1, g_period=1)
Bases: tensorpack.train.tower.TowerTrainer
A GAN trainer which runs two optimization ops with a certain ratio.
- Parameters
input (tensorpack.input_source.QueueInput) – Data input.
model (tgan.GAN.GANModelDesc) – Model to train.
d_period (int) – period of each d_opt run
g_period (int) – period of each g_opt run
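As a rough illustration of how d_period and g_period could set the ratio between the two ops, here is a minimal sketch in plain Python. It assumes each op runs whenever the step counter is divisible by its period and that the counter starts at 0. Neither assumption is confirmed by this page. gan_schedule is a hypothetical helper, not part of tgan or tensorpack.

```python
def gan_schedule(d_period, g_period, num_steps):
    """Return, for each step, the names of the ops that would run."""
    if min(d_period, g_period) < 1:
        raise ValueError("periods must be positive integers")
    plan = []
    for step in range(num_steps):  # assumption: the step counter starts at 0
        ops = []
        if step % d_period == 0:  # assumption: run d_opt every d_period steps
            ops.append("d_opt")
        if step % g_period == 0:  # assumption: run g_opt every g_period steps
            ops.append("g_opt")
        plan.append(ops)
    return plan


plan = gan_schedule(d_period=1, g_period=3, num_steps=6)
for step, ops in enumerate(plan):
    print(step, ops)
```

Under these assumptions, d_period=1 and g_period=3 runs the discriminator op on every step and the generator op on every third step.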