Inference Mode
How to perform inference on a causal generative model
Last updated
def inference_model(self, cpts, device="cuda"):
    """Sample one set of duel attributes and decode an image from them.

    Every categorical variable is drawn independently from the distribution
    that ``cpts`` supplies for it. Nothing is conditioned on the other values
    drawn inside this function. Per the original author's note, ``cpts``
    already holds the pre-computed intervention/condition distributions.
    The one-hot attributes are concatenated into a label vector ``ys``.
    ``ys`` is then decoded together with a standard-normal latent ``z`` by
    ``self.decoder``.

    Args:
        cpts: Mapping with the keys ``"character"``, ``"type"``, ``"action"``
            and ``"reaction"`` (fed to one-hot categoricals) and
            ``"strength"``, ``"defense"`` and ``"attack"`` (fed to
            integer-valued categoricals). Each value is passed as the first
            positional argument of the distribution constructor (``probs``
            for the torch/Pyro categorical families). TODO confirm this
            against the caller.
        device: Device that every sampled value and the latent prior
            parameters are placed on. The default ``"cuda"`` matches the
            previously hard-coded ``.cuda()`` calls. Pass ``"cpu"`` to run
            without a GPU.

    Returns:
        ``(loc_img, model_attrs)``. ``loc_img`` is the decoder output.
        ``model_attrs`` is a dict of every sampled attribute, plus the
        concatenated label vector under ``"ys"``.
    """

    def draw(site, family, key):
        # Sample at the named Pyro site, then move the value to ``device``.
        # NOTE(review): the value recorded in the Pyro trace is the one from
        # *before* the move; pass cpts already on ``device`` if trace values
        # must live there too.
        return pyro.sample(site, family(cpts[key])).to(device)

    # Which characters fight, and which type (1, 2 or 3; this selects a
    # different image of the Satyr/Golem) each one is.
    actor = draw("actor", dist.OneHotCategorical, "character")
    reactor = draw("reactor", dist.OneHotCategorical, "character")
    actor_type = draw("actor_type", dist.OneHotCategorical, "type")
    reactor_type = draw("reactor_type", dist.OneHotCategorical, "type")

    # Actor (the character who starts the fight): strength, defense and
    # attack are integer class indices (per the original notes: Low or
    # High). The actor's action is one-hot.
    actor_strength = draw("actor_strength", dist.Categorical, "strength")
    actor_defense = draw("actor_defense", dist.Categorical, "defense")
    actor_attack = draw("actor_attack", dist.Categorical, "attack")
    actor_action = draw("actor_action", dist.OneHotCategorical, "action")

    # Reactor (the character who responds to the actor's action): the same
    # three capabilities plus a one-hot reaction. Keep this statement order:
    # it fixes the order of random draws, so seeded runs stay reproducible.
    reactor_strength = draw("reactor_strength", dist.Categorical, "strength")
    reactor_defense = draw("reactor_defense", dist.Categorical, "defense")
    reactor_attack = draw("reactor_attack", dist.Categorical, "attack")
    reactor_reaction = draw(
        "reactor_reaction", dist.OneHotCategorical, "reaction"
    )

    # NOTE(review): an earlier version resized actor_type/reactor_type to the
    # original num_labels via modify_type_tensor(...); that step is disabled.

    # Label vector for the decoder. Every part is already on ``device``, so no
    # extra move is needed.
    ys = torch.cat(
        [actor, actor_type, actor_action, reactor, reactor_type, reactor_reaction],
        dim=-1,
    )

    # Standard-normal prior over one latent code (leading batch dim of 1).
    z_loc = torch.zeros(1, self.z_dim, dtype=torch.float32, device=device)
    z_scale = torch.ones(1, self.z_dim, dtype=torch.float32, device=device)
    z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # Decoder output; returned so the caller can visualise it.
    loc_img = self.decoder.forward(z, ys)

    model_attrs = {
        "actor": actor,
        "actor_type": actor_type,
        "action": actor_action,
        "reactor": reactor,
        "reactor_type": reactor_type,
        "reaction": reactor_reaction,
        "actor_attack": actor_attack,
        "actor_strength": actor_strength,
        "actor_defense": actor_defense,
        "reactor_attack": reactor_attack,
        "reactor_strength": reactor_strength,
        "reactor_defense": reactor_defense,
        "ys": ys,
    }
    return loc_img, model_attrs