Putting things together

This section trains the model and puts together all the pieces described above.

Overview

The Causal Variational Autoencoder consists of a model and a guide stochastic function. The model uses the decoder network to produce the image, while the guide uses the encoder network to produce the latent representation that could have produced the image. In training mode, the learnable weights of the encoder and decoder are fine-tuned through backpropagation using the Adam optimizer.

In training mode, we use both the model and the guide. The model works in conjunction with the decoder, whereas the guide works with the encoder. In the model, we sample all the nodes of the DAG, conditioned on the labels and the observed images. In the guide, we sample only the unobserved attributes of the actor and the reactor, such as their strength, attack, and defense, using the conditional probability distributions we computed with gRain.

Neural Network Architecture

The architecture chosen to learn the distribution of the images is a variational autoencoder. It consists of an encoder, which maps an image to a latent space, and a decoder, which maps the latent space back to an image. Since we are working with images, the encoder uses convolution layers to capture the structure of the data; convolutional neural networks are a well-established choice for image data. The decoder takes the encoded representation of the image together with the observed labels and generates an image back.

Encoder

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim=1024, num_labels=17):
        super().__init__()
        self.cnn = get_cnn_encoder(image_channels=3)  # user-supplied CNN; currently only supports hidden_dim=1024
        # setup the two linear transformations used
        self.fc21 = nn.Linear(hidden_dim+num_labels, z_dim)
        self.fc22 = nn.Linear(hidden_dim+num_labels, z_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, x, y):
        '''
        Forward module of the encoder
        '''
        # compute the hidden units
        hidden = self.cnn(x)
        hidden = self.softplus(hidden)  # shape [batch_size, hidden_dim]
        # then return a mean vector and a (positive) square root covariance

        # each of size batch_size x z_dim
        hidden = torch.cat([hidden, y], dim=-1)
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

In the encoder architecture, we use multiple blocks of convolution, batch normalization, and ReLU activation units. Each block increases the channel depth while decreasing the image height and width, so that a 400×400 image with 3 channels is eventually represented as a tensor with 1024 elements. The get_cnn_encoder function can be as simple or as complex as the user's needs require; apart from that, the rest of the encoder can closely follow the code above.
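For concreteness, here is a minimal sketch of what get_cnn_encoder could look like. It is only one possible choice: the number of blocks, the channel widths, and the final pooling layer are assumptions, chosen so that a 3×400×400 input is flattened into a hidden_dim=1024 vector.

import torch.nn as nn

def get_cnn_encoder(image_channels=3):
    '''
    Stack of convolution -> batch norm -> ReLU blocks that downsamples
    a 3 x 400 x 400 image and flattens it into a 1024-element vector.
    '''
    def block(in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    return nn.Sequential(
        block(image_channels, 16),  # 3 x 400 x 400 -> 16 x 200 x 200
        block(16, 32),              # -> 32 x 100 x 100
        block(32, 64),              # -> 64 x 50 x 50
        block(64, 128),             # -> 128 x 25 x 25
        block(128, 256),            # -> 256 x 12 x 12
        nn.AdaptiveAvgPool2d(2),    # -> 256 x 2 x 2
        nn.Flatten(),               # -> 1024
    )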

Decoder

import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, num_labels=17):
        super().__init__()
        self.cnn_decoder = get_seq_decoder(hidden_dim, 3) # image_channels is 3
        # setup the two linear transformations used
        self.fc1 = nn.Linear(z_dim+num_labels, hidden_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, z, y):
        # define the forward computation on the latent z
        # first compute the hidden units
        concat_z = torch.cat([z, y], dim=-1)
        hidden = self.softplus(self.fc1(concat_z))
        # return the parameters of the output Bernoulli distribution,
        # one per pixel of the 3-channel output image
        loc_img = self.cnn_decoder(hidden)
        return loc_img

In the decoder architecture, we start with the 1024-element hidden tensor and apply transposed convolutions that decrease the channel depth and increase the height and width until the output matches the size of the original image. The kernel sizes and strides are chosen so that the image shapes stay consistent.
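As with the encoder, get_seq_decoder is user-supplied. The sketch below is one possible layout, and the specific layer sizes are assumptions: the hidden vector is first projected to a 128×25×25 feature map and then upsampled with transposed convolutions back to 3×400×400, with a final sigmoid so the outputs lie in [0, 1] and can parameterize the Bernoulli likelihood used in the model.

import torch.nn as nn

def get_seq_decoder(hidden_dim, image_channels=3):
    '''
    Projects the hidden vector to a small feature map, then upsamples it
    with transposed convolutions back to a 3 x 400 x 400 image.
    '''
    def up_block(in_ch, out_ch):
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    return nn.Sequential(
        nn.Linear(hidden_dim, 128 * 25 * 25),
        nn.Unflatten(1, (128, 25, 25)),  # -> 128 x 25 x 25
        up_block(128, 64),               # -> 64 x 50 x 50
        up_block(64, 32),                # -> 32 x 100 x 100
        up_block(32, 16),                # -> 16 x 200 x 200
        nn.ConvTranspose2d(16, image_channels, kernel_size=4, stride=2, padding=1),
        nn.Sigmoid(),                    # -> 3 x 400 x 400, values in [0, 1]
    )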

VAE

During training we observe data, so we use a separate model and guide. In the training model, we implement the DAG as described in the example in the DAG section.

See the page: Programming the Directed Causal Graphical Model

The only modification we make to the sample statements in the model function is that we condition them on the training data. One of the model statements is shown below.

actor = pyro.sample("actor", 
    dist.OneHotCategorical(self.cpts["character"]), obs=actorObs)

Here, we sample the actor from a one-hot-encoded categorical distribution with the prior probability described in the CPT section, and condition it on the observed label via the obs argument.
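As a toy illustration (the two-character prior and labels below are made up, not the values from the CPT section), this is how the obs argument ties a sample statement to an observed one-hot label:

import torch
import pyro
import pyro.distributions as dist

# Toy prior over two characters; the real prior comes from the CPTs
character_prior = torch.tensor([0.5, 0.5])

# Observed one-hot label for a batch of one image (say, the second character)
actorObs = torch.tensor([[0., 1.]])

# Inside the model, this statement does not draw a new value: it returns the
# observation and scores it against the prior when the ELBO is computed
actor = pyro.sample("actor", dist.OneHotCategorical(character_prior), obs=actorObs)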

See the page: Specifying the Non-Decoder/Encoder Parameters of the Model

In the guide function, we sample the unobserved nodes of the DAG. In our example, we do not observe the strength, attack, and defense attributes of the actor and the reactor in the image, so we use the guide function to learn their posterior distributions. We look up the conditional probability of each unobserved node, indexed by the values of its parent and child nodes. For actor_strength, the parents are the actor and the actor type, and the child is the action; all three of these nodes are observed in our training data.

One of the guide statements is shown below.

actor_strength = pyro.sample("actor_strength", 
    dist.Categorical(
        self.inverse_cpts["action_strength"][action, actor_type, actor]
    ))
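The lookup above relies on PyTorch advanced indexing: each index is a batch-length vector, so the CPT returns one probability row per example. A small standalone illustration (the tensor shapes below are assumptions, not the actual CPT sizes):

import torch
import pyro.distributions as dist

# Toy inverse CPT P(strength | action, actor_type, actor) with shape
# [n_actions, n_types, n_characters, n_strength_levels]
inverse_cpt = torch.rand(5, 3, 2, 2)
inverse_cpt = inverse_cpt / inverse_cpt.sum(-1, keepdim=True)

# One observed value per example in a batch of three
action = torch.tensor([0, 4, 2])
actor_type = torch.tensor([1, 0, 2])
actor = torch.tensor([0, 1, 1])

# Advanced indexing gives a [batch, n_strength_levels] probability matrix,
# from which we sample one strength value per example
probs = inverse_cpt[action, actor_type, actor]
strength = dist.Categorical(probs).sample()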

The full implementation of the VAE is given below.

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from .encoder import Encoder
from .decoder import Decoder
from utils.utils import return_cpts, return_values, return_inverse_cpts

class VAE(nn.Module):
    def __init__(self, z_dim=128, hidden_dim=1024, use_cuda=False, num_labels=17):
        super().__init__()
        self.output_size = num_labels
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim, num_labels)
        self.decoder = Decoder(z_dim, hidden_dim, num_labels) # 3 channel image.
        self.values = return_values()
        self.cpts = return_cpts()
        self.inverse_cpts = return_inverse_cpts()


        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x, y, actorObs, reactorObs, actor_typeObs, reactor_typeObs, actionObs, reactionObs):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        options = dict(dtype=x.dtype, device=x.device)
        with pyro.plate("data", x.shape[0]):
            '''
            Causal model: sample every node of the DAG. The observed nodes
            (actor, reactor, their types, the action and the reaction) are
            conditioned on the training labels via the obs argument; the
            sampled labels are later concatenated and passed to the decoder
            together with the latent code.
            '''

            actor = pyro.sample("actor", dist.OneHotCategorical(self.cpts["character"]), obs=actorObs).cuda()
            act_idx = actor[..., :].nonzero()[:, 1].cuda()

            reactor = pyro.sample("reactor", dist.OneHotCategorical(self.cpts["character"]), obs=reactorObs).cuda()
            rct_idx = reactor[..., :].nonzero()[:, 1].cuda()


            # To choose the type of Satyr or Golem (type 1, 2 or 3. This translates to different image of that character.)
            actor_type = pyro.sample("actor_type", dist.OneHotCategorical(self.cpts["type"][act_idx]), obs=actor_typeObs).cuda()
            act_typ_idx = actor_type[..., :].nonzero()[:, 1].cuda()

            reactor_type = pyro.sample("reactor_type", dist.OneHotCategorical(self.cpts["type"][rct_idx]), obs=reactor_typeObs).cuda()
            rct_typ_idx = reactor_type[..., :].nonzero()[:, 1].cuda()


            # To choose the strength, defense and attack based on the character and its type. Either Low or High
            actor_strength = pyro.sample("actor_strength", dist.Categorical(self.cpts["strength"][act_idx, act_typ_idx])).cuda()
            actor_defense = pyro.sample("actor_defense", dist.Categorical(self.cpts["defense"][act_idx, act_typ_idx])).cuda()
            actor_attack = pyro.sample("actor_attack", dist.Categorical(self.cpts["attack"][act_idx, act_typ_idx])).cuda()

            # To choose the character's(actor, who starts the fight) action based on the strength, defense and attack capabilities
            actor_action = pyro.sample("actor_action", dist.OneHotCategorical(self.cpts["action"][actor_strength, actor_defense, actor_attack]), obs=actionObs).cuda()

            # Converting onehot categorical to categorical value
            sampled_actor_action = actor_action[..., :].nonzero()[:, 1].cuda()
            # To choose the other character's strength, defense and attack based on the character and its type
            reactor_strength = pyro.sample("reactor_strength", dist.Categorical(self.cpts["strength"][rct_idx, rct_typ_idx])).cuda()
            reactor_defense = pyro.sample("reactor_defense", dist.Categorical(self.cpts["defense"][rct_idx, rct_typ_idx])).cuda()
            reactor_attack = pyro.sample("reactor_attack", dist.Categorical(self.cpts["attack"][rct_idx, rct_typ_idx])).cuda()

            # To choose the character's (reactor, who reacts to the actor's action in a duel) reaction based on its own strength, defense , attack and the other character's action.
            reactor_reaction = pyro.sample("reactor_reaction", dist.OneHotCategorical(self.cpts["reaction"][reactor_strength, reactor_defense, reactor_attack, sampled_actor_action]), obs=reactionObs).cuda()

            # Modifying actor/reactor type tensor sizes to match the original num_labels.

            #actor_type = modify_type_tensor(actor_type, act_idx)
            #reactor_type = modify_type_tensor(reactor_type, rct_idx)

            ys = torch.cat([actor, actor_type, actor_action, reactor, reactor_type, reactor_reaction], dim=-1).cuda()


            # setup the hyperparameters for the prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, **options)
            z_scale = torch.ones(x.shape[0], self.z_dim, **options)
            # sample from the prior (the value will be sampled by the guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            '''
            Important: the decoder produces an image from the latent code and
            the labels sampled above, which are conditioned on the observed data.
            '''
            loc_img = self.decoder.forward(z, ys)
            # score against the actual images; since it is a 3-channel image, we use to_event(3)
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(3), obs=x)
            # return the loc so we can visualize it later
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x, y, actorObs, reactorObs, actor_typeObs, reactor_typeObs, actionObs, reactionObs):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            act_idx = actorObs[..., :].nonzero()[:, 1].cuda()
            rct_idx = reactorObs[..., :].nonzero()[:, 1].cuda()
            act_typ_idx = actor_typeObs[..., :].nonzero()[:, 1].cuda()
            rct_typ_idx = reactor_typeObs[..., :].nonzero()[:, 1].cuda()
            action, reaction = torch.nonzero(actionObs)[:, 1].cuda(), torch.nonzero(reactionObs)[:, 1].cuda()
            # sample the unobserved attributes from the inverse CPTs, indexed by the observed parent/child nodes
            actor_strength = pyro.sample("actor_strength", dist.Categorical(self.inverse_cpts["action_strength"][action, act_typ_idx, act_idx])).cuda()
            actor_defense = pyro.sample("actor_defense", dist.Categorical(self.inverse_cpts["action_defense"][action, act_typ_idx, act_idx])).cuda()
            actor_attack = pyro.sample("actor_attack", dist.Categorical(self.inverse_cpts["action_attack"][action, act_typ_idx, act_idx])).cuda()

            reactor_strength = pyro.sample("reactor_strength", dist.Categorical(self.inverse_cpts["reaction_strength"][reaction, rct_typ_idx, rct_idx])).cuda()
            reactor_defense = pyro.sample("reactor_defense", dist.Categorical(self.inverse_cpts["reaction_defense"][reaction, rct_typ_idx, rct_idx])).cuda()
            reactor_attack = pyro.sample("reactor_attack", dist.Categorical(self.inverse_cpts["reaction_attack"][reaction, rct_typ_idx, rct_idx])).cuda()

            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x, y)  # y holds the action and reaction labels
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

Training Mode

We learn the network parameters with stochastic variational inference, using an exponentially decaying learning rate.

import torch
import pyro
import pyro.optim
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO

# setup the VAE
vae = VAE(use_cuda=True, num_labels=17)

# setup the exponential learning rate scheduler
optimizer = torch.optim.Adam
scheduler = pyro.optim.ExponentialLR({'optimizer': optimizer,
   'optim_args': {'lr': args.learning_rate}, 'gamma': 0.1})

# setup the inference algorithm
elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
svi = SVI(vae.model, vae.guide, scheduler, loss=elbo)

To compute the ELBO gradient and take an update step, we use the step function from SVI:

# do ELBO gradient step and accumulate the loss
epoch_loss += svi.step(x, y, actor, reactor,
   actor_type, reactor_type, action, reaction)
# the arguments match the signatures of model and guide

and to evaluate the loss, we use the evaluate_loss function:

# compute ELBO estimate and accumulate the loss
test_loss += svi.evaluate_loss(x, y, actor, reactor,
   actor_type, reactor_type, action, reaction)
# the arguments match the signatures of model and guide
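Putting these pieces together, a training loop could look like the sketch below. The train_loader name, the batch layout (x, y, actor, reactor, actor_type, reactor_type, action, reaction), and the args.num_epochs flag are assumptions alongside the args.learning_rate and args.jit already used above; scheduler.step() decays the learning rate once per epoch.

for epoch in range(args.num_epochs):
    epoch_loss = 0.0
    for x, y, actor, reactor, actor_type, reactor_type, action, reaction in train_loader:
        if vae.use_cuda:
            # move the batch to the GPU alongside the model parameters
            x, y = x.cuda(), y.cuda()
            actor, reactor = actor.cuda(), reactor.cuda()
            actor_type, reactor_type = actor_type.cuda(), reactor_type.cuda()
            action, reaction = action.cuda(), reaction.cuda()
        # one ELBO gradient step on this mini-batch
        epoch_loss += svi.step(x, y, actor, reactor,
            actor_type, reactor_type, action, reaction)
    # decay the learning rate once per epoch
    scheduler.step()
    print(f"epoch {epoch}: average loss {epoch_loss / len(train_loader.dataset):.4f}")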

For the full code and instructions, please refer to the accompanying GitHub repo.
