Programming the Directed Causal Graphical Model

What is a DAG? How is a causal model represented as a DAG? How do we implement it and run inference on it in Pyro?

As mentioned in previous chapters, the relationships between the different entities in the image are described by a DAG, and probabilistic queries can be answered by running inference algorithms on that DAG. Below, we work through an example and show how to answer a few causal queries in Pyro before presenting the DAG for our use case.

Sample DAG and Answering Causal Queries Using Pyro

A survey was taken among people, capturing each respondent's Age (A), Sex (S), Education (E), Occupation (O), Residence (R), and mode of Travel (T). A DAG indicating the probabilistic relationships between these variables is drawn above.

The conditional independencies are encoded implicitly in the structure of the DAG. Age and Sex cannot be affected by any other variable, at least from the survey's point of view, so they have no parents. Education is affected by Age and Sex. Education in turn determines a person's occupation and the size of the city they live in. The mode of travel is affected by Occupation and Residence. The joint probability therefore factorizes according to the DAG as the product of each variable's probability conditioned on its parents.
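Reading each node's parents off the dependencies just described gives the following factorization of the joint distribution:

```latex
P(A, S, E, O, R, T) = P(A)\,P(S)\,P(E \mid A, S)\,P(O \mid E)\,P(R \mid E)\,P(T \mid O, R)
```

Each factor corresponds to one of the conditional probability tables below.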

Conditional Probability Tables

Initial conditional probabilities can be computed from data or a prior can be assumed.

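import torch

# Category labels; each list position is the integer code used in the tensors
A_alias = ['adult','old','young']
S_alias = ['F','M']
E_alias = ['high','uni']
O_alias = ['emp','self']
R_alias = ['big','small']
T_alias = ['car','other','train']

# Priors for the root nodes: P(A) and P(S)
A_prob = torch.tensor([0.36,0.16,0.48])
S_prob = torch.tensor([0.55, 0.45])

# P(E | S, A), indexed as E_prob[s][a][e]
E_prob = torch.tensor([[[0.64, 0.36], [0.84, 0.16], [0.16, 0.84]],
                       [[0.72, 0.28], [0.89, 0.11], [0.81, 0.19]]])
# P(O | E), indexed as O_prob[e][o]
O_prob = torch.tensor([[0.98, 0.02], [0.97, 0.03]])
# P(R | E), indexed as R_prob[e][r]
R_prob = torch.tensor([[0.72, 0.28], [0.94, 0.06]])
# P(T | R, O), indexed as T_prob[r][o][t]
T_prob = torch.tensor([[[0.71, 0.14, 0.15], [0.68, 0.16, 0.16]],
                       [[0.55, 0.08, 0.37], [0.73, 0.25, 0.02]]])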

Let's say these probabilities were computed from data. The tables are easy to interpret. For the variable Age, the probability of being an adult is 0.36, of being old is 0.16, and of being young is 0.48. The variable R is conditioned on E. If Education is high school, index 0 of R_prob is accessed and the probability of living in a big city is 0.72, whereas if Education is university, the probability of living in a big city is 0.94.
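Because each CPT row is looked up by integer index, a quick sanity check on the tables can catch layout mistakes early. The sketch below assumes the tensor layouts given in the comments above, namely [s][a][e] for E_prob and [r][o][t] for T_prob. It verifies that every row is a proper distribution and then reads P(R = big | E = uni) by index.

```python
import torch

# CPTs copied from the section above; the comments give the assumed index layout
A_prob = torch.tensor([0.36, 0.16, 0.48])
S_prob = torch.tensor([0.55, 0.45])
E_prob = torch.tensor([[[0.64, 0.36], [0.84, 0.16], [0.16, 0.84]],
                       [[0.72, 0.28], [0.89, 0.11], [0.81, 0.19]]])  # [s][a][e]
O_prob = torch.tensor([[0.98, 0.02], [0.97, 0.03]])                  # [e][o]
R_prob = torch.tensor([[0.72, 0.28], [0.94, 0.06]])                  # [e][r]
T_prob = torch.tensor([[[0.71, 0.14, 0.15], [0.68, 0.16, 0.16]],
                       [[0.55, 0.08, 0.37], [0.73, 0.25, 0.02]]])    # [r][o][t]

E_alias = ['high', 'uni']
R_alias = ['big', 'small']

# Every row of a CPT is a distribution over the child, so the last axis must sum to 1
for name, cpt in [("A", A_prob), ("S", S_prob), ("E", E_prob),
                  ("O", O_prob), ("R", R_prob), ("T", T_prob)]:
    assert torch.allclose(cpt.sum(dim=-1), torch.ones(cpt.shape[:-1])), name

# Reading one entry: P(R = big | E = uni) is row 1, column 0 of R_prob
e = E_alias.index('uni')
p_big = R_prob[e][R_alias.index('big')].item()
print(f"P(R = big | E = uni) = {p_big:.2f}")
```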

These probabilities reflect the current state of the relationships between these entities and can be obtained via data.
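The text does not show how the tables are estimated. One common approach is to count co-occurrences and normalize each row, which gives the maximum-likelihood estimate, optionally with add-alpha smoothing. The sketch below assumes each survey column has already been integer-coded with the alias lists above. The helper `estimate_cpt` and the toy records are hypothetical and are not part of the original code or the survey data.

```python
import torch

def estimate_cpt(child, parents, child_card, parent_cards, alpha=0.0):
    """Hypothetical helper: estimate P(child | parents) by counting.

    child:        LongTensor [N], values in range(child_card)
    parents:      list of LongTensors [N], one per parent (may be empty)
    parent_cards: number of categories of each parent, in the same order
    alpha:        add-alpha smoothing; 0.0 gives the maximum-likelihood estimate
    Returns a tensor of shape (*parent_cards, child_card), indexed [p1][p2]...[child].
    """
    counts = torch.full((*parent_cards, child_card), float(alpha))
    # Add 1 to the (parents..., child) cell for every record
    counts.index_put_((*parents, child), torch.ones(child.shape[0]), accumulate=True)
    # Normalize over the child's categories; a parent configuration with
    # no records and alpha=0 produces NaN
    return counts / counts.sum(dim=-1, keepdim=True)

# Hypothetical toy records (NOT the survey data): Education and Residence codes
E_obs = torch.tensor([0, 0, 1, 1, 1])
R_obs = torch.tensor([0, 1, 0, 0, 0])

R_cpt = estimate_cpt(R_obs, [E_obs], child_card=2, parent_cards=[2])
R_cpt_smoothed = estimate_cpt(R_obs, [E_obs], child_card=2, parent_cards=[2], alpha=1.0)
print(R_cpt)
print(R_cpt_smoothed)
```

Passing the parents in the order [S, A] would reproduce the [s][a][e] layout used for E_prob.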

Model

All probabilistic programs are built by composing primitive stochastic functions with deterministic computation. In Pyro, a model is defined as an ordinary Python function.

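import pyro
import pyro.distributions as dist

def model():
    # Root nodes: Age and Sex have no parents
    A = pyro.sample("A", dist.Categorical(probs=A_prob))
    S = pyro.sample("S", dist.Categorical(probs=S_prob))
    # Education depends on Sex and Age; E_prob is indexed [S][A]
    E = pyro.sample("E", dist.Categorical(probs=E_prob[S][A]))
    # Occupation and Residence each depend only on Education
    O = pyro.sample("O", dist.Categorical(probs=O_prob[E]))
    R = pyro.sample("R", dist.Categorical(probs=R_prob[E]))
    # Travel depends on Residence and Occupation; T_prob is indexed [R][O]
    T = pyro.sample("T", dist.Categorical(probs=T_prob[R][O]))
    return {'A': A, 'S': S, 'E': E, 'O': O, 'R': R, 'T': T}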

The function above encodes the DAG. Each pyro.sample statement draws one variable from its CPT, and the row is selected by the already-sampled values of that variable's parents. Pyro identifies each random variable by the name passed as the first argument to pyro.sample. During inference, Pyro transforms the program and calls this function many times. Each call produces a trace: a record of the sampled value of every named variable. The distribution over these samples changes depending on the evidence provided.

Guide

If the model has learnable parameters, we need a second stochastic function, called the guide, to help learn them. Inference algorithms in Pyro, such as stochastic variational inference (SVI), use the guide as an approximate posterior distribution. A guide must satisfy two criteria to be a valid approximation of the model. First, every unobserved sample statement that appears in the model must also appear in the guide under the same name. Second, the guide must have the same signature as the model, i.e. it takes the same arguments. This mock example has no learnable parameters, so no guide is needed. The causal image generation model, however, does need a guide because our neural network has learnable weights.

Condition Queries

Let's take an example. You observe a person with a university degree. What is your prediction of this person's means of travel? To answer a query like this, we use a condition statement to condition on the evidence. The evidence here is that the person has a university degree (E = uni).

conditioned_model_uni_degree = pyro.condition(model, data={'E':torch.tensor(1)})

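We pass E = uni as torch.tensor(1) because 'uni' is at index 1 of E_alias. We then run an inference algorithm, importance sampling in this case, and plot the estimated posterior over T.

Other inference algorithms can be substituted. HMC (Hamiltonian Monte Carlo) can give a more accurate posterior, but it samples only continuous latent variables. It therefore does not apply directly to a model like this one, where every latent variable is discrete.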

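import numpy as np
import matplotlib.pyplot as plt
from pyro.infer import Importance, EmpiricalMarginal

# Importance sampling; with no guide, Pyro proposes from the model's prior
T_posterior = Importance(conditioned_model_uni_degree, num_samples=5000).run()
T_marginal = EmpiricalMarginal(T_posterior, "T")
# Draw from the weighted empirical marginal of T
T_samples = [T_marginal().item() for _ in range(5000)]
T_unique, T_counts = np.unique(T_samples, return_counts=True)

# Normalize the counts so the bars are probabilities, matching the y-axis label
plt.bar(T_unique, T_counts / len(T_samples), align='center', alpha=0.5)
plt.xticks(T_unique, [T_alias[i] for i in T_unique])
plt.ylabel('Posterior Probability')
plt.xlabel('T')
plt.title('P(T | E = Uni) - Importance Sampling')
plt.show()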

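When we condition on Education being uni, we restrict attention to people who have a university degree and ask how they travel. No other outcome for Education is considered. Because this is an observation rather than an action, it also updates our beliefs about the parents of E: among university graduates, Age and Sex are distributed differently than in the population as a whole.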
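One way to see how observing E affects its parents is to compute the answer exactly. The model is small enough to enumerate with plain torch instead of running Pyro inference. The sketch below compares the prior P(A) with P(A | E = uni), which is obtained by conditioning, and with P(A | do(E = uni)), which is obtained by dropping the factor P(E | S, A) as an intervention does.

```python
import torch

# CPTs copied from the section above (only the ones needed here)
A_prob = torch.tensor([0.36, 0.16, 0.48])
S_prob = torch.tensor([0.55, 0.45])
E_prob = torch.tensor([[[0.64, 0.36], [0.84, 0.16], [0.16, 0.84]],
                       [[0.72, 0.28], [0.89, 0.11], [0.81, 0.19]]])  # [s][a][e]
uni = 1  # index of 'uni' in E_alias

# Conditioning: P(A, E = uni) = sum_s P(A) P(S) P(E = uni | S, A), then renormalize
joint_A_uni = torch.einsum("a,s,sa->a", A_prob, S_prob, E_prob[:, :, uni])
p_A_given_uni = joint_A_uni / joint_A_uni.sum()

# Intervening: the factor P(E | S, A) is removed, so only P(A) P(S) remains
mutilated = torch.einsum("a,s->a", A_prob, S_prob)
p_A_do_uni = mutilated / mutilated.sum()

print("P(A)              ", A_prob)
print("P(A | E = uni)    ", p_A_given_uni)
print("P(A | do(E = uni))", p_A_do_uni)
```

Observing a university degree shifts belief towards 'young', because young respondents in this table are the most likely to hold a degree. Forcing the degree by intervention leaves the age distribution untouched.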

Intervention Queries

Let's compare the interventional distribution for the same query. In an intervention, the influence of the intervened node's parents is cut off, so Age and Sex no longer affect Education.
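In terms of the factorization of the joint distribution, intervening with do(E = uni) deletes the factor P(E | A, S) and fixes E = uni in every remaining factor. This standard result is known as the truncated factorization:

```latex
P(A, S, O, R, T \mid do(E = \text{uni})) = P(A)\,P(S)\,P(O \mid E = \text{uni})\,P(R \mid E = \text{uni})\,P(T \mid O, R)
```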

# do() cuts the edges from A and S into E and fixes E = uni
intervention_model_uni_degree = pyro.do(model, data={'E':torch.tensor(1)})
T_posterior = Importance(intervention_model_uni_degree, num_samples=5000).run()
T_marginal = EmpiricalMarginal(T_posterior, "T")
T_samples = [T_marginal().item() for _ in range(5000)]
T_unique, T_counts = np.unique(T_samples, return_counts=True)

# Normalize the counts and label only the categories that were sampled
plt.bar(T_unique, T_counts / len(T_samples), align='center', alpha=0.5)
plt.xticks(T_unique, [T_alias[i] for i in T_unique])
plt.ylabel('Posterior Probability')
plt.xlabel('T')
plt.title('P(T | do(E = Uni)) - Importance Sampling')
plt.show()

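The two plots look slightly different, but in this DAG the difference is only Monte Carlo noise. Age and Sex influence Travel only through Education, so there is no back-door path from E to T, and P(T | do(E = uni)) equals P(T | E = uni). Conditioning and intervening give different answers for T only when some variable affects both Education and Travel through separate paths. Where they do differ here is in the parents of E: conditioning changes our beliefs about Age and Sex, while intervening does not.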
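Both answers can be computed exactly to check how much of the gap between the two plots is sampling noise. The sketch below uses plain torch enumeration rather than Pyro inference. It builds the full joint from the DAG factorization, conditions on E = uni by slicing and renormalizing, and computes the interventional distribution with the truncated factorization. Because Age and Sex reach Travel only through Education in this DAG, the two results should coincide.

```python
import torch

# CPTs copied from the section above
A_prob = torch.tensor([0.36, 0.16, 0.48])
S_prob = torch.tensor([0.55, 0.45])
E_prob = torch.tensor([[[0.64, 0.36], [0.84, 0.16], [0.16, 0.84]],
                       [[0.72, 0.28], [0.89, 0.11], [0.81, 0.19]]])  # [s][a][e]
O_prob = torch.tensor([[0.98, 0.02], [0.97, 0.03]])                  # [e][o]
R_prob = torch.tensor([[0.72, 0.28], [0.94, 0.06]])                  # [e][r]
T_prob = torch.tensor([[[0.71, 0.14, 0.15], [0.68, 0.16, 0.16]],
                       [[0.55, 0.08, 0.37], [0.73, 0.25, 0.02]]])    # [r][o][t]
T_alias = ['car', 'other', 'train']
uni = 1  # index of 'uni' in E_alias

# Full joint P(A, S, E, O, R, T) from the DAG factorization, axes ordered a, s, e, o, r, t
joint = torch.einsum("a,s,sae,eo,er,rot->aseort",
                     A_prob, S_prob, E_prob, O_prob, R_prob, T_prob)

# Conditioning: keep the E = uni slice, sum out A, S, O, R, then renormalize
p_T_cond = joint[:, :, uni].sum(dim=(0, 1, 2, 3))
p_T_cond = p_T_cond / p_T_cond.sum()

# Intervening: truncated factorization, i.e. drop P(E | S, A) and fix E = uni
p_T_do = torch.einsum("a,s,o,r,rot->t", A_prob, S_prob, O_prob[uni], R_prob[uni], T_prob)

for label, pc, pd in zip(T_alias, p_T_cond.tolist(), p_T_do.tolist()):
    print(f"{label:>5}: P(T | E=uni) = {pc:.4f}   P(T | do(E=uni)) = {pd:.4f}")
```

Both columns come out at roughly 0.70 for car, 0.14 for other and 0.16 for train. The importance-sampling histograms above should be compared against these values.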

For more details and examples, please refer to the following GitHub repo
