The Causal Variational Autoencoder consists of a model and a guide stochastic function: the model uses the decoder network to produce the image, and the guide uses the encoder network to produce the latent representation that could have produced the image. In training mode, the learnable weights of the encoder and decoder are fine-tuned through backpropagation using the Adam optimizer.
In training mode, we use both the model and the guide. The model works in conjunction with the decoder, whereas the guide works with the encoder. In the model, we sample all the nodes of the DAG conditioned on the labels and observed images. In the guide, we sample only the unobserved attributes of the actor and the reactor, such as their strength, attack, and defense, by making use of the conditional probability distributions we obtained using gRain.
Neural Network Architecture
The architecture selected to learn the distribution of the images is a variational autoencoder. The autoencoder architecture consists of an encoder that converts an image to a latent space and a decoder that converts the latent space back to an image. Since we are working with images, we use convolution layers in the encoder to capture the distribution of the data; using convolutional neural networks for images is a well-established practice at this point. The decoder takes in the encoded representation of the image along with the observed labels and generates an image back.
Encoder
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim=1024, num_labels=17):
        super().__init__()
        self.cnn = get_cnn_encoder(image_channels=3)  # Currently this returns only for 1024 hidden dimensions. Need to change that
        # setup the two linear transformations used
        self.fc21 = nn.Linear(hidden_dim + num_labels, z_dim)
        self.fc22 = nn.Linear(hidden_dim + num_labels, z_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, x, y):
        ''' Forward module of the encoder '''
        # compute the hidden units
        hidden = self.cnn(x)
        hidden = self.softplus(hidden)  # This should return a [1, 1024] vector.
        # then return a mean vector and a (positive) square root covariance,
        # each of size batch_size x z_dim
        hidden = torch.cat([hidden, y], dim=-1)
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale
In the encoder architecture, we use multiple blocks of convolution, batch normalization, and ReLU activation units. Each block works on a certain number of image channels: we keep increasing the channel depth while decreasing the image height and width, so that a 400x400 image with 3 channels ends up represented as a tensor with 1024 elements. The get_cnn_encoder function can be as simple or as complex as the user needs; apart from that, the rest can stay very similar to what is given above.
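The get_cnn_encoder function itself is not spelled out above. A minimal sketch that matches the description (strided convolution blocks with batch normalization and ReLU, mapping a 3x400x400 image to a 1024-element vector) could look like the following; the specific kernel sizes, strides, and channel counts here are assumptions and can be swapped out as needed.

import torch.nn as nn

def get_cnn_encoder(image_channels=3, hidden_dim=1024):
    ''' A possible convolutional encoder: each block increases the channel depth
        while shrinking the spatial size, ending in a flat hidden_dim vector. '''
    def block(in_ch, out_ch, kernel, stride, padding=0):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel, stride, padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
    return nn.Sequential(
        block(image_channels, 16, 4, 2, 1),  # 3 x 400 x 400 -> 16 x 200 x 200
        block(16, 32, 4, 2, 1),              # -> 32 x 100 x 100
        block(32, 64, 4, 2, 1),              # -> 64 x 50 x 50
        block(64, 128, 4, 2, 1),             # -> 128 x 25 x 25
        block(128, 256, 5, 5),               # -> 256 x 5 x 5
        nn.Conv2d(256, hidden_dim, 5),       # -> 1024 x 1 x 1
        nn.Flatten(),                        # -> [batch_size, 1024]
    )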
Decoder
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, num_labels=17):
        super().__init__()
        self.cnn_decoder = get_seq_decoder(hidden_dim, 3)  # image_channels is 3
        # setup the two linear transformations used
        self.fc1 = nn.Linear(z_dim + num_labels, hidden_dim)
        # self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # self.fc21 = nn.Linear(hidden_dim, 400)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, z, y):
        # define the forward computation on the latent z
        # first compute the hidden units
        concat_z = torch.cat([z, y], dim=-1)
        hidden = self.softplus(self.fc1(concat_z))
        # hidden = self.softplus(self.fc2(hidden))
        # return the parameters for the output Bernoulli distribution over the pixels
        loc_img = self.cnn_decoder(hidden)
        return loc_img
In the decoder architecture, we start with the 1024-element tensor and apply transposed convolutions to decrease the channel depth and increase the height and width back to the size of the original image. The kernel sizes and strides are chosen so that the image shapes stay consistent.
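Similarly, get_seq_decoder is left open. One way to mirror the encoder sketch above is a stack of transposed-convolution blocks that reshapes the 1024-element hidden vector into a 1024x1x1 tensor and upsamples it back to 3x400x400; the layer sizes below are illustrative assumptions, and the final Sigmoid keeps pixel values in [0, 1] for the Bernoulli likelihood used later.

import torch.nn as nn

def get_seq_decoder(hidden_dim=1024, image_channels=3):
    ''' A possible transposed-convolution decoder mirroring the encoder sketch. '''
    def block(in_ch, out_ch, kernel, stride, padding=0):
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
    return nn.Sequential(
        nn.Unflatten(1, (hidden_dim, 1, 1)),              # [batch_size, 1024] -> 1024 x 1 x 1
        block(hidden_dim, 256, 5, 1),                     # -> 256 x 5 x 5
        block(256, 128, 5, 5),                            # -> 128 x 25 x 25
        block(128, 64, 4, 2, 1),                          # -> 64 x 50 x 50
        block(64, 32, 4, 2, 1),                           # -> 32 x 100 x 100
        block(32, 16, 4, 2, 1),                           # -> 16 x 200 x 200
        nn.ConvTranspose2d(16, image_channels, 4, 2, 1),  # -> 3 x 400 x 400
        nn.Sigmoid(),                                     # pixel values in [0, 1]
    )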
VAE
We have a separate model and guide for training because we observe data during training. In our training model, we implement the DAG described in the DAG section.
The only modification we make to the sample statements in the model function is that we condition them on the training data. One of the model's sample statements is shown below:
actor = pyro.sample("actor", dist.OneHotCategorical(self.cpts["character"]), obs=actorObs)
Here, we sample the actor from a one-hot-encoded categorical distribution with the prior probability given in the CPT section, and condition it on the observed label using the obs argument.
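To make the conditioning mechanics concrete, here is a small, self-contained illustration of how obs works on its own; the prior probabilities and the two-character setup below are made-up values for illustration, whereas in the actual model the priors come from the CPTs built with gRain.

import torch
import pyro
import pyro.distributions as dist

# hypothetical prior over the two characters (e.g. Satyr vs. Golem)
character_prior = torch.tensor([0.5, 0.5])
# one-hot training label saying which character the actor is
actorObs = torch.tensor([0., 1.])
# with obs=..., the sample statement does not draw a new value: it returns the
# observation and, during inference, scores it against the prior
actor = pyro.sample("actor", dist.OneHotCategorical(character_prior), obs=actorObs)
print(actor)  # tensor([0., 1.])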
In the guide function, we sample the unobserved nodes of the DAG. In our example, we do not observe the strength, attack, and defense attributes of the actor and the reactor in the image, so we use the guide function to learn their posterior distributions. We compute the conditional probability of each unobserved node, indexed by the values of its parent and child nodes. For actor_strength, the parents of the node are the actor and the actor type, and the child is the action; these three entities are observed in our training data.
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from .encoder import Encoder
from .decoder import Decoder
from utils.utils import return_cpts, return_values, return_inverse_cpts

class VAE(nn.Module):
    def __init__(self, z_dim=128, hidden_dim=1024, use_cuda=False, num_labels=17):
        super().__init__()
        self.output_size = num_labels
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim, num_labels)
        self.decoder = Decoder(z_dim, hidden_dim, num_labels)  # 3 channel image.
        self.values = return_values()
        self.cpts = return_cpts()
        self.inverse_cpts = return_inverse_cpts()
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x, y, actorObs, reactorObs, actor_typeObs, reactor_typeObs, actionObs, reactionObs):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        options = dict(dtype=x.dtype, device=x.device)
        with pyro.plate("data", x.shape[0]):
            # The labels are supervised: when a label is observed we score it against
            # its constant prior; otherwise we sample it from that prior.
            ''' Causal Model '''
            actor = pyro.sample("actor", dist.OneHotCategorical(self.cpts["character"]), obs=actorObs).cuda()
            act_idx = actor[..., :].nonzero()[:, 1].cuda()
            reactor = pyro.sample("reactor", dist.OneHotCategorical(self.cpts["character"]), obs=reactorObs).cuda()
            rct_idx = reactor[..., :].nonzero()[:, 1].cuda()
            # To choose the type of Satyr or Golem (type 1, 2 or 3; this translates to a different image of that character.)
            actor_type = pyro.sample("actor_type", dist.OneHotCategorical(self.cpts["type"][act_idx]), obs=actor_typeObs).cuda()
            act_typ_idx = actor_type[..., :].nonzero()[:, 1].cuda()
            reactor_type = pyro.sample("reactor_type", dist.OneHotCategorical(self.cpts["type"][rct_idx]), obs=reactor_typeObs).cuda()
            rct_typ_idx = reactor_type[..., :].nonzero()[:, 1].cuda()
            # To choose the strength, defense and attack (either Low or High) based on the character and its type.
            actor_strength = pyro.sample("actor_strength", dist.Categorical(self.cpts["strength"][act_idx, act_typ_idx])).cuda()
            actor_defense = pyro.sample("actor_defense", dist.Categorical(self.cpts["defense"][act_idx, act_typ_idx])).cuda()
            actor_attack = pyro.sample("actor_attack", dist.Categorical(self.cpts["attack"][act_idx, act_typ_idx])).cuda()
            # To choose the character's (actor, who starts the fight) action based on the strength, defense and attack capabilities
            actor_action = pyro.sample("actor_action", dist.OneHotCategorical(self.cpts["action"][actor_strength, actor_defense, actor_attack]), obs=actionObs).cuda()
            # Converting the one-hot categorical sample to a categorical value
            sampled_actor_action = actor_action[..., :].nonzero()[:, 1].cuda()
            # To choose the other character's strength, defense and attack based on the character and its type
            reactor_strength = pyro.sample("reactor_strength", dist.Categorical(self.cpts["strength"][rct_idx, rct_typ_idx])).cuda()
            reactor_defense = pyro.sample("reactor_defense", dist.Categorical(self.cpts["defense"][rct_idx, rct_typ_idx])).cuda()
            reactor_attack = pyro.sample("reactor_attack", dist.Categorical(self.cpts["attack"][rct_idx, rct_typ_idx])).cuda()
            # To choose the character's (reactor, who reacts to the actor's action in a duel) reaction based on its own
            # strength, defense, attack and the other character's action.
            reactor_reaction = pyro.sample("reactor_reaction", dist.OneHotCategorical(self.cpts["reaction"][reactor_strength, reactor_defense, reactor_attack, sampled_actor_action]), obs=reactionObs).cuda()
            # Modifying actor/reactor type tensor sizes to match the original num_labels.
            # actor_type = modify_type_tensor(actor_type, act_idx)
            # reactor_type = modify_type_tensor(reactor_type, rct_idx)
            # ys is basically a concatenation of the actor's labels (including its action)
            # and the reactor's labels (including its reaction)
            ys = torch.cat([actor, actor_type, actor_action, reactor, reactor_type, reactor_reaction], dim=-1).cuda()
            # setup hyperparameters for the prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            # sample from the prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # Important: the decoder produces an image from the latent input and the sampled labels conditioned on observed data
            loc_img = self.decoder.forward(z, ys)
            # score against the actual images. Since it is a 3 channel image, we do to_event(3)
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(3), obs=x)
            # return the loc so we can visualize it later
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x, y, actorObs, reactorObs, actor_typeObs, reactor_typeObs, actionObs, reactionObs):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            act_idx = actorObs[..., :].nonzero()[:, 1].cuda()
            rct_idx = reactorObs[..., :].nonzero()[:, 1].cuda()
            act_typ_idx = actor_typeObs[..., :].nonzero()[:, 1].cuda()
            rct_typ_idx = reactor_typeObs[..., :].nonzero()[:, 1].cuda()
            action, reaction = torch.nonzero(actionObs)[:, 1].cuda(), torch.nonzero(reactionObs)[:, 1].cuda()
            # sample the unobserved attributes from their inverse CPTs, indexed by the observed parents and children
            actor_strength = pyro.sample("actor_strength", dist.Categorical(self.inverse_cpts["action_strength"][action, act_typ_idx, act_idx])).cuda()
            actor_defense = pyro.sample("actor_defense", dist.Categorical(self.inverse_cpts["action_defense"][action, act_typ_idx, act_idx])).cuda()
            actor_attack = pyro.sample("actor_attack", dist.Categorical(self.inverse_cpts["action_attack"][action, act_typ_idx, act_idx])).cuda()
            reactor_strength = pyro.sample("reactor_strength", dist.Categorical(self.inverse_cpts["reaction_strength"][reaction, rct_typ_idx, rct_idx])).cuda()
            reactor_defense = pyro.sample("reactor_defense", dist.Categorical(self.inverse_cpts["reaction_defense"][reaction, rct_typ_idx, rct_idx])).cuda()
            reactor_attack = pyro.sample("reactor_attack", dist.Categorical(self.inverse_cpts["reaction_attack"][reaction, rct_typ_idx, rct_idx])).cuda()
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x, y)  # y -> action and reaction
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
Training Mode
We use stochastic variational inference (SVI) with an exponentially decaying learning rate to learn our network parameters.
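One way to set this up in Pyro is to wrap a PyTorch Adam optimizer in an exponentially decaying learning-rate scheduler and hand it to SVI; the learning rate and decay factor below are illustrative assumptions, not the exact values used in the project.

import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ExponentialLR

vae = VAE(use_cuda=True)

# Adam wrapped in an exponentially decaying learning-rate schedule
scheduler = ExponentialLR({
    "optimizer": torch.optim.Adam,
    "optim_args": {"lr": 1e-3},   # illustrative initial learning rate
    "gamma": 0.995,               # decay multiplier applied at each scheduler.step()
})

svi = SVI(vae.model, vae.guide, scheduler, loss=Trace_ELBO())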
To compute the ELBO gradient and perform an update step, we use the step function from SVI:
# do ELBO gradient step and accumulate loss
epoch_loss += svi.step(x, y, actor, reactor, actor_type, reactor_type, action, reaction)
# The function signature should be similar to that of the model and guide
and to evaluate the loss we use the evaluate_loss function:
# compute ELBO estimate and accumulate loss
test_loss += svi.evaluate_loss(x, y, actor, reactor, actor_type, reactor_type, action, reaction)
# The function signature should be similar to that of the model and guide
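Putting the two calls together, a training loop could look roughly like the sketch below; train_loader, test_loader, num_epochs, and the way each batch unpacks into images and labels are assumptions about how the data is prepared rather than details taken from the project code.

for epoch in range(num_epochs):
    # training pass: take an ELBO gradient step on every batch
    epoch_loss = 0.0
    for x, y, actor, reactor, actor_type, reactor_type, action, reaction in train_loader:
        epoch_loss += svi.step(x, y, actor, reactor, actor_type, reactor_type, action, reaction)
    scheduler.step()  # apply the exponential learning-rate decay once per epoch
    # evaluation pass: estimate the ELBO without updating the parameters
    test_loss = 0.0
    for x, y, actor, reactor, actor_type, reactor_type, action, reaction in test_loader:
        test_loss += svi.evaluate_loss(x, y, actor, reactor, actor_type, reactor_type, action, reaction)
    print(f"epoch {epoch}: train loss {epoch_loss:.2f}, test loss {test_loss:.2f}")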
For the full code and instructions, please refer to the following GitHub repo.