
Inference Mode

How to perform inference on the causal generative model


Last updated 4 years ago


Overview

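In inference mode, we use the trained decoder network together with the latent node to generate images that answer various probabilistic queries. Rather than approximating the posterior with sampling methods such as MCMC or HMC, we pre-compute the posterior distributions of the discrete nodes analytically with the gRain package in R. Exact computation like this is not possible in all cases: it works here because those nodes are discrete and the graph is small, whereas models with continuous or high-dimensional latent variables generally require approximate inference.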
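gRain answers such queries exactly using junction-tree propagation. As a rough illustration of what an exact query returns, the sketch below computes posteriors by brute-force enumeration instead (a simpler technique that only scales to tiny graphs). It uses a hypothetical three-node graph (strength S → action A, S → reaction R, A → R) with made-up tables and contrasts conditioning on an action with intervening on it. None of the node names or numbers come from the tutorial's model.

```python
from itertools import product

# Hypothetical toy CPTs (made-up numbers, not the tutorial's model).
# Graph: S -> A, S -> R, A -> R, so S confounds the effect of A on R.
P_S = {"Low": 0.5, "High": 0.5}
P_A_given_S = {
    "Low": {"Attack": 0.2, "Defend": 0.8},
    "High": {"Attack": 0.8, "Defend": 0.2},
}
P_R_given_SA = {
    ("Low", "Attack"): {"Hurt": 0.9, "Fine": 0.1},
    ("Low", "Defend"): {"Hurt": 0.3, "Fine": 0.7},
    ("High", "Attack"): {"Hurt": 0.5, "Fine": 0.5},
    ("High", "Defend"): {"Hurt": 0.1, "Fine": 0.9},
}
ACTIONS = ("Attack", "Defend")
REACTIONS = ("Hurt", "Fine")


def joint(s, a, r, do_a=None):
    """P(S=s, A=a, R=r); under do(A=do_a) the A-factor becomes an indicator."""
    if do_a is None:
        p_a = P_A_given_S[s][a]
    else:
        p_a = 1.0 if a == do_a else 0.0
    return P_S[s] * p_a * P_R_given_SA[(s, a)][r]


def query_reaction(observed_a=None, do_a=None):
    """Exact P(R | A=observed_a) or P(R | do(A=do_a)) by summing the joint."""
    mass = {r: 0.0 for r in REACTIONS}
    for s, a, r in product(P_S, ACTIONS, REACTIONS):
        if observed_a is not None and a != observed_a:
            continue  # drop worlds inconsistent with the evidence
        mass[r] += joint(s, a, r, do_a)
    total = sum(mass.values())
    return {r: m / total for r, m in mass.items()}


def rounded(d):
    return {k: round(v, 2) for k, v in d.items()}


print(rounded(query_reaction(observed_a="Attack")))  # {'Hurt': 0.58, 'Fine': 0.42}
print(rounded(query_reaction(do_a="Attack")))        # {'Hurt': 0.7, 'Fine': 0.3}
```

Observing Attack shifts belief toward S = High, which lowers the chance of Hurt. Setting do(A = Attack) leaves S at its prior. That gap is why a conditional query and an interventional query yield different pre-computed distributions.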

Inference Model

```python
import torch
import pyro
import pyro.distributions as dist


# Method of the causal generative model class; assumes self.decoder and
# self.z_dim were set up during training.
def inference_model(self, cpts):
    # cpts holds the pre-computed intervention/condition distributions
    # (probability vectors), keyed by node name.

    # Which characters fight: the actor starts the duel, the reactor responds.
    actor = pyro.sample("actor", dist.OneHotCategorical(cpts["character"])).cuda()
    reactor = pyro.sample("reactor", dist.OneHotCategorical(cpts["character"])).cuda()

    # Type of Satyr or Golem (type 1, 2 or 3); each type maps to a different image of that character.
    actor_type = pyro.sample("actor_type", dist.OneHotCategorical(cpts["type"])).cuda()
    reactor_type = pyro.sample("reactor_type", dist.OneHotCategorical(cpts["type"])).cuda()

    # Actor's strength, defense and attack (each Low or High), sampled as integer class indices.
    actor_strength = pyro.sample("actor_strength", dist.Categorical(cpts["strength"])).cuda()
    actor_defense = pyro.sample("actor_defense", dist.Categorical(cpts["defense"])).cuda()
    actor_attack = pyro.sample("actor_attack", dist.Categorical(cpts["attack"])).cuda()

    # Actor's action. In the causal graph it depends on the actor's strength,
    # defense and attack; the pre-computed cpts["action"] already accounts for that.
    actor_action = pyro.sample("actor_action", dist.OneHotCategorical(cpts["action"])).cuda()

    # Reactor's strength, defense and attack.
    reactor_strength = pyro.sample("reactor_strength", dist.Categorical(cpts["strength"])).cuda()
    reactor_defense = pyro.sample("reactor_defense", dist.Categorical(cpts["defense"])).cuda()
    reactor_attack = pyro.sample("reactor_attack", dist.Categorical(cpts["attack"])).cuda()

    # Reactor's reaction. In the causal graph it depends on the reactor's own
    # strength, defense and attack and on the actor's action; the pre-computed
    # cpts["reaction"] already accounts for that.
    reactor_reaction = pyro.sample("reactor_reaction", dist.OneHotCategorical(cpts["reaction"])).cuda()

    # Conditioning label vector for the decoder: the one-hot blocks concatenated.
    ys = torch.cat([actor, actor_type, actor_action, reactor, reactor_type, reactor_reaction], dim=-1).cuda()

    # Draw the latent code from its standard normal prior.
    z_loc = torch.zeros(1, self.z_dim, dtype=torch.float32).cuda()
    z_scale = torch.ones(1, self.z_dim, dtype=torch.float32).cuda()
    z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # Decode (z, ys) into an image. No observed images are scored in inference mode.
    loc_img = self.decoder(z, ys)

    # Return the image mean (so it can be visualized later) and every sampled attribute.
    model_attrs = {
        "actor": actor,
        "actor_type": actor_type,
        "action": actor_action,
        "reactor": reactor,
        "reactor_type": reactor_type,
        "reaction": reactor_reaction,
        "actor_attack": actor_attack,
        "actor_strength": actor_strength,
        "actor_defense": actor_defense,
        "reactor_attack": reactor_attack,
        "reactor_strength": reactor_strength,
        "reactor_defense": reactor_defense,
        "ys": ys,
    }
    return loc_img, model_attrs
```
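`inference_model` reads `cpts` as a dict of probability vectors keyed by node name. A minimal, hypothetical helper for assembling that dict from posterior marginals exported from R might look like the sketch below. The helper name `build_cpts` and the vector lengths in the usage lines are assumptions, not part of the tutorial.

```python
# Node names read by inference_model; each maps to one probability vector.
REQUIRED_KEYS = ("character", "type", "strength", "defense", "attack", "action", "reaction")


def build_cpts(marginals):
    """Hypothetical helper: check and normalize posterior marginals.

    `marginals` maps each name in REQUIRED_KEYS to a sequence of
    non-negative weights (e.g. exported from gRain). Returns plain
    Python lists that each sum to 1.
    """
    missing = [k for k in REQUIRED_KEYS if k not in marginals]
    if missing:
        raise KeyError(f"missing marginals for: {missing}")
    cpts = {}
    for key in REQUIRED_KEYS:
        probs = [float(p) for p in marginals[key]]
        if not probs or any(p < 0 for p in probs):
            raise ValueError(f"{key!r} must be a non-empty vector of non-negative weights")
        total = sum(probs)
        if total == 0:
            raise ValueError(f"{key!r} has zero total mass")
        cpts[key] = [p / total for p in probs]
    return cpts


# Usage with made-up weights (vector lengths are illustrative only).
cpts = build_cpts({
    "character": [1, 1],      # e.g. Satyr, Golem
    "type": [1, 1, 2],        # types 1, 2, 3
    "strength": [0.3, 0.7],   # Low, High
    "defense": [0.5, 0.5],
    "attack": [0.2, 0.8],
    "action": [1, 1, 1],
    "reaction": [2, 1, 1],
})
print(cpts["type"])  # [0.25, 0.25, 0.5]
```

The lists still need converting to tensors before they reach the model. For example, `{k: torch.tensor(v).cuda() for k, v in cpts.items()}` places them on the GPU, matching the `.cuda()` calls in `inference_model`. PyTorch's `Categorical` normalizes `probs` on its own, so the helper's main value is catching missing keys and negative weights before sampling.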