Inference Mode
How to perform inference on the causal generative model
Last updated
In inference mode, we use the trained decoder network in conjunction with the latent node to generate images that answer various probabilistic queries (conditioning and interventions). Instead of running approximate inference with MCMC or HMC, we pre-compute the required posterior (or interventional) distributions analytically with the gRain package in R, and the model then samples directly from those pre-computed distributions. This analytic pre-computation is not possible for every model or query.
def inference_model(self, cpts):
    """Generate an image from pre-computed query distributions.

    Rather than running MCMC/HMC, every discrete node of the causal model is
    sampled directly from a probability vector that was pre-computed
    analytically for the query of interest (a conditioning or an
    intervention). The sampled one-hot labels are concatenated into a
    conditioning vector ``ys``. That vector is decoded together with a latent
    ``z`` drawn from a standard-normal prior.

    Args:
        cpts: Mapping holding the pre-computed probability vectors. Keys read
            here: ``"character"``, ``"type"``, ``"strength"``, ``"defense"``,
            ``"attack"``, ``"action"`` and ``"reaction"``. Each value is passed
            as the ``probs`` argument of a ``Categorical`` /
            ``OneHotCategorical`` distribution.

    Returns:
        tuple: ``(loc_img, model_attrs)``. ``loc_img`` is the decoder output.
        ``model_attrs`` maps attribute names to the sampled values and holds
        the concatenated label vector under ``"ys"``.
    """
    # NOTE(review): actor and reactor draw from the *same* pre-computed vectors
    # ("character", "type", "strength", ...). A query whose answer differs
    # between the two roles therefore cannot be expressed with this cpts
    # layout. Confirm that the pre-computation accounts for this.
    actor = pyro.sample("actor", dist.OneHotCategorical(cpts["character"])).cuda()
    reactor = pyro.sample("reactor", dist.OneHotCategorical(cpts["character"])).cuda()

    # Type of the Satyr or Golem (type 1, 2 or 3); selects which image
    # variant of that character is generated.
    actor_type = pyro.sample("actor_type", dist.OneHotCategorical(cpts["type"])).cuda()
    reactor_type = pyro.sample("reactor_type", dist.OneHotCategorical(cpts["type"])).cuda()

    # Actor's strength, defense and attack (Low or High). In inference mode
    # they are drawn independently from their pre-computed vectors rather than
    # from their parents. They are returned for inspection but are not part of
    # the decoder's conditioning vector ``ys``.
    actor_strength = pyro.sample("actor_strength", dist.Categorical(cpts["strength"])).cuda()
    actor_defense = pyro.sample("actor_defense", dist.Categorical(cpts["defense"])).cuda()
    actor_attack = pyro.sample("actor_attack", dist.Categorical(cpts["attack"])).cuda()

    # Action of the actor (the character who starts the fight).
    actor_action = pyro.sample("actor_action", dist.OneHotCategorical(cpts["action"])).cuda()

    # Reactor's strength, defense and attack, sampled the same way as the actor's.
    reactor_strength = pyro.sample("reactor_strength", dist.Categorical(cpts["strength"])).cuda()
    reactor_defense = pyro.sample("reactor_defense", dist.Categorical(cpts["defense"])).cuda()
    reactor_attack = pyro.sample("reactor_attack", dist.Categorical(cpts["attack"])).cuda()

    # Reaction of the reactor (the character responding to the actor's action).
    reactor_reaction = pyro.sample("reactor_reaction", dist.OneHotCategorical(cpts["reaction"])).cuda()

    # Conditioning vector for the decoder: only the one-hot labels. Every input
    # is already on the GPU, so the concatenation stays there without a copy.
    ys = torch.cat(
        [actor, actor_type, actor_action, reactor, reactor_type, reactor_reaction],
        dim=-1,
    )

    # Standard-normal prior over the latent code, allocated directly on the GPU
    # instead of being built on the CPU and copied.
    z_loc = torch.zeros(1, self.z_dim, dtype=torch.float32, device="cuda")
    z_scale = torch.ones(1, self.z_dim, dtype=torch.float32, device="cuda")
    z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # Invoke the module itself rather than ``.forward`` so that any registered
    # forward hooks still run.
    loc_img = self.decoder(z, ys)

    # Return the decoded image together with every sampled attribute so the
    # caller can visualize and label the result.
    model_attrs = {
        "actor": actor,
        "actor_type": actor_type,
        "action": actor_action,
        "reactor": reactor,
        "reactor_type": reactor_type,
        "reaction": reactor_reaction,
        "actor_attack": actor_attack,
        "actor_strength": actor_strength,
        "actor_defense": actor_defense,
        "reactor_attack": reactor_attack,
        "reactor_strength": reactor_strength,
        "reactor_defense": reactor_defense,
        "ys": ys,
    }
    return loc_img, model_attrs