Adversarial infill
The problem
I’ve had a lot of time on my hands recently, because my internet has been broken. During that time, I wanted to become familiar with PyTorch, because it’s neat. I’m also very interested in GANs and adversarial learning in general. I always have trouble getting adversarial learning to work; it’s quite fiddly, with entire pages devoted to tips and tricks, including the fact that the GAN objective everyone writes about is not the one they actually use. So, with that in mind, I wanted to take the complexity down a notch by writing an “adversarial infiller”: a network that takes an image with a part missing and fills in that missing part.
The code
You can check out the code on github.
I use luigi to manage the task structure. The code is also presented a little cleaner here than in the repo - that’s just because I played around with a few ideas before settling on this approach. Happy reading!
The network
The generator
The generator takes an image with a hole in it, and tries to fill in the hole:
INPUT
+--------------+
|              |
|    +----+    |        +----+
|    |????|    |  -->   |    |
|    +----+    |        +----+
|              |
+--------------+
The image is a four-channel image. The fourth channel is set to 1 in the missing section, and 0 elsewhere.
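For concreteness, here is one way that input could be assembled. This is only a sketch; the helper name and the patch handling are mine, not the repo’s.

import torch

def make_generator_input(image, top, left, size):
    # image is a (3, H, W) tensor. Zero out the square patch and append a
    # fourth "missing" channel that is 1 inside the hole and 0 elsewhere.
    _, height, width = image.shape
    missing = torch.zeros(1, height, width)
    missing[:, top:top + size, left:left + size] = 1.0

    holed = image.clone()
    holed[:, top:top + size, left:left + size] = 0.0

    return torch.cat([holed, missing], dim=0)  # shape (4, H, W)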
The network is quite small and simple. The x is the image, and the missing is the fourth channel mentioned above.
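Something along these lines, as a minimal sketch - the actual layer sizes and patch size in the repo may differ:

import torch
import torch.nn as nn

class Generator(nn.Module):
    # Takes the four-channel input (RGB plus the missing mask) and predicts
    # the RGB contents of the hole as a patch_size x patch_size patch.
    def __init__(self, patch_size=16):
        super().__init__()
        self.patch_size = patch_size
        self.features = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),  # fixed-size summary of the context
        )
        self.to_patch = nn.Linear(64 * 4 * 4, 3 * patch_size * patch_size)

    def forward(self, x):
        # x has shape (N, 4, H, W); x[:, 3:] is the missing mask described above.
        h = self.features(x)
        h = h.view(h.size(0), -1)
        patch = torch.sigmoid(self.to_patch(h))  # pixel values in [0, 1]
        return patch.view(-1, 3, self.patch_size, self.patch_size)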
The discriminator
The discriminator simply takes the entire image (both the patch and the context it came from), and predicts whether it is real or generated (1 or 0).
It’s quite simple as well:
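A sketch of the shape of it (again, not the repo’s exact layers):

class Discriminator(nn.Module):
    # Maps a full RGB image (the patch pasted back into its context) to a
    # single probability: 1 for real, 0 for generated.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classify = nn.Linear(64, 1)

    def forward(self, image):
        h = self.features(image)
        h = h.view(h.size(0), -1)
        return torch.sigmoid(self.classify(h))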
Side note on PyTorch
I’ve found PyTorch to be easier to use than tensorflow. Because the graph is built dynamically, it’s a lot easier to debug and fiddle with. I can add print statements to parts of the code, for example. Never underestimate the debugging power of a well placed print statement.
One thing that did trip me up: PyTorch totally lets you do numpy-style assignment like this:
mat[4:6] = [1,2]
however, this will not propagate gradients. Instead, you need to concatenate the data:
mat_new = torch.cat([mat[:4], torch.tensor([1., 2.]), mat[6:]], 0)
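For example, this quick check (with a made-up 8-element tensor, using the current tensor API) shows that gradients do flow back through the concatenated result:

import torch

mat = torch.randn(8, requires_grad=True)
patch = torch.tensor([1.0, 2.0])  # the new values being spliced in
mat_new = torch.cat([mat[:4], patch, mat[6:]], 0)

mat_new.sum().backward()
print(mat.grad)  # nonzero for the entries of mat that made it into mat_new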
Training
To train the networks, I basically built two optimizers:
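Roughly like this - the optimizer choice and learning rates here are my own guesses, not necessarily what the repo uses:

import torch.nn as nn
import torch.optim as optim

generator = Generator()
discriminator = Discriminator()

# One optimizer per network, so each can be stepped independently.
g_optimizer = optim.Adam(generator.parameters(), lr=2e-4)
d_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4)

bce = nn.BCELoss()  # binary cross entropy, used for both updates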
And then optimized the binary cross entropy with the appropriate ones and zeros (a rough sketch of one training step follows the list):
- When training the generator, train with all ones. We want the generator to make good samples, so we teach it to trick the discriminator.
- When training the discriminator, train with ones for the real cases and zeros for the generated ones, because we’re teaching it to discriminate between generated and real values.
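Putting that together, one training step might look like this. It is only a sketch: paste_patch, which puts the generated patch back into its surrounding context, is a hypothetical helper, not something from the repo.

def training_step(real_images, holed_images, patch_coords):
    # real_images: (N, 3, H, W); holed_images: (N, 4, H, W) generator input.
    batch = real_images.size(0)
    ones = torch.ones(batch, 1)
    zeros = torch.zeros(batch, 1)

    # Discriminator: real images labelled 1, generated images labelled 0.
    d_optimizer.zero_grad()
    fake_patch = generator(holed_images).detach()
    fake_images = paste_patch(holed_images[:, :3], fake_patch, patch_coords)
    d_loss = (bce(discriminator(real_images), ones)
              + bce(discriminator(fake_images), zeros))
    d_loss.backward()
    d_optimizer.step()

    # Generator: labelled all ones, so it learns to fool the discriminator.
    g_optimizer.zero_grad()
    fake_patch = generator(holed_images)
    fake_images = paste_patch(holed_images[:, :3], fake_patch, patch_coords)
    g_loss = bce(discriminator(fake_images), ones)
    g_loss.backward()
    g_optimizer.step()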
Results
It works better than I expected! I trained it on a dataset of images of flowers.
Here are some samples from the start, where the system does poorly:
And some from the end, at training batch 4000:
The main complaint I have with these results is that they’re a little bit blurry.
And sometimes it just goes nuts and gets everything wrong. Here’s one I found around batch 4000:
Conclusion
This approach worked quite well! I’m going to try scaling it up. I don’t have access to a big GPU, which limits my ability to do these sorts of experiments, but maybe it’s time for a new GPU :). Thanks for reading!