A Cleverer Trick on top of the Reparametrization Trick
TL;DR - Implicit differentiation enables efficient computation of the gradients of reparametrized samples.
The famous reparametrization trick has been employed to estimate the gradients of samples from probability distributions, by replacing the sampling step with an equivalent estimator that is a deterministic, differentiable transformation of a simple distribution.
The paper expounds the requirements on probability distributions for which the reparametrization trick can be used. The probability distribution whose sample gradients are required must satisfy at least one of the following conditions -
- Has a location-scale parametrization
- Has a tractable inverse cumulative distribution function (CDF)
- Can be expressed as a deterministic differentiable transformation of other distributions satisfying the above two conditions.
Revisiting the Reparametrization Trick
The reparametrization trick is used mainly in estimating the gradient of an expectation of a differentiable function with respect to the parameters of the distribution $q(x; \theta)$ such that $x \sim q(x; \theta)$. In other words, the trick can be used to compute $\nabla_\theta \, \mathbb{E}_{q(x; \theta)}[f(x)]$.
Here, $q(x; \theta)$ is some complex distribution that satisfies at least one of the conditions mentioned above. Note that the main impediment to computing the gradient of the above expression directly is the non-differentiable step of sampling $x$ from $q(x; \theta)$.
Therefore, the trick is to rewrite the argument of the function $f$ as $x = g(\varepsilon, \theta)$ such that the sample $\varepsilon$ is independent of the parameters of the distribution. In other words, since the sampling procedure is not differentiable, make the sampling procedure independent of the parameters so that no gradient through the sampling step is required. By rewriting the argument, the parameters get transferred to the function $f$ through $g$. When written as $f(g(\varepsilon, \theta))$, the sample $\varepsilon \sim q(\varepsilon)$ is independent of the parameters of the distribution, and hence no gradient of $\varepsilon$ with respect to $\theta$ is required.
In general, if a sample can be written as a deterministic differentiable expression $x = g(\varepsilon, \theta)$, where $\varepsilon \sim q(\varepsilon)$ is a sample that is independent of the parameters, then
$$\nabla_\theta \, \mathbb{E}_{q(x; \theta)}[f(x)] = \nabla_\theta \, \mathbb{E}_{q(\varepsilon)}[f(g(\varepsilon, \theta))].$$
The gradient of the above expression can thus be computed (using the chain rule) as
$$\mathbb{E}_{q(\varepsilon)}\left[\nabla_x f(x) \, \nabla_\theta g(\varepsilon, \theta)\right].$$
Now, if the distribution has location and scale parameters $\theta = (\mu, \sigma)$ (like the Gaussian distribution), then $g$ can be a simple translation and scaling of the form
$$x = g(\varepsilon, \theta) = \mu + \sigma \varepsilon,$$
where $\varepsilon$ is a sample from the standard form of the distribution (e.g., $\mathcal{N}(0, 1)$ for the Gaussian).
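Here is a minimal sketch of the trick in this location-scale form. The quadratic objective $f(x) = x^2$ is a toy assumption, chosen because $\mathbb{E}[f(x)] = \mu^2 + \sigma^2$ under $\mathcal{N}(\mu, \sigma^2)$ gives closed-form gradients to compare against:

```python
import numpy as np

# Toy objective: f(x) = x^2, so E[f(x)] under N(mu, sigma^2) is mu^2 + sigma^2
# and the exact gradients are d/dmu = 2*mu and d/dsigma = 2*sigma.
def grad_f(x):
    return 2 * x  # f'(x) for f(x) = x^2

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8

eps = rng.standard_normal(100_000)  # eps ~ N(0, 1), independent of (mu, sigma)
x = mu + sigma * eps                # x = g(eps, theta) = mu + sigma * eps

# Chain rule: the mu-gradient uses dx/dmu = 1; the sigma-gradient uses dx/dsigma = eps.
grad_mu = np.mean(grad_f(x) * 1.0)
grad_sigma = np.mean(grad_f(x) * eps)
print(grad_mu, 2 * mu)        # estimate vs. exact: ~3.0 vs 3.0
print(grad_sigma, 2 * sigma)  # estimate vs. exact: ~1.6 vs 1.6
```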
If the distribution $q(x; \theta)$ has a tractable inverse CDF $F^{-1}(\varepsilon; \theta)$, then $x$ can be written as
$$x = g(\varepsilon, \theta) = F^{-1}(\varepsilon; \theta), \quad \varepsilon \sim \mathrm{Uniform}(0, 1).$$
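For instance, here is a sketch using the exponential distribution (an assumed example, picked because its inverse CDF is tractable):

```python
import numpy as np

# Exponential(rate) has CDF F(x; rate) = 1 - exp(-rate * x), so the inverse
# CDF is tractable: x = F^{-1}(eps; rate) = -log(1 - eps) / rate.
rng = np.random.default_rng(1)
rate = 2.0

eps = rng.uniform(size=100_000)  # eps ~ Uniform(0, 1), independent of rate
x = -np.log1p(-eps) / rate       # x = g(eps, rate) = F^{-1}(eps; rate)

# Differentiating g directly: dx/drate = log(1 - eps) / rate^2 = -x / rate.
# For f(x) = x, E[x] = 1/rate, so the exact gradient is -1/rate^2.
dx_drate = -x / rate
print(np.mean(dx_drate), -1 / rate**2)  # estimate vs. exact: ~-0.25 vs -0.25
```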
It is also possible to use both of the above transformations in tandem, justifying the conditions presented above.
However, distributions like the Gamma, Beta, and Dirichlet distributions, or even mixture distributions, do not satisfy the above conditions, and thus the reparametrization trick cannot be used. Other techniques addressing this limitation include approximating the intractable inverse CDF or using the score function (the gradient of the log-likelihood). However, these produce gradients with relatively large variance. Large variance in such estimates affects the convergence of the training algorithm, and therefore further variance-reduction techniques (like control variates) are required. Often, these variance-reduction techniques are problem-specific and cannot be used for a wide range of models.
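For comparison, here is a sketch of the score-function estimator on the same Gaussian toy problem as above (the quadratic objective is again only an illustrative assumption); both estimators are unbiased, but the per-sample variance of the score-function estimate is noticeably larger:

```python
import numpy as np

# Score-function estimator: grad_mu E[f(x)] = E[f(x) * d log q(x; mu, sigma) / dmu],
# with d log q / dmu = (x - mu) / sigma^2 for the Gaussian. No reparametrization
# is needed, but the per-sample gradient estimates are much noisier.
rng = np.random.default_rng(2)
mu, sigma = 1.5, 0.8
x = rng.normal(mu, sigma, 100_000)

score = (x - mu) / sigma**2
sf_samples = (x**2) * score  # score-function per-sample gradients (f(x) = x^2)
rp_samples = 2 * x           # reparametrized per-sample gradients (f'(x) * dx/dmu)
print(sf_samples.mean(), rp_samples.mean())  # both estimate 2 * mu = 3.0
print(sf_samples.var(), rp_samples.var())    # score-function variance is far larger
```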
Implicit Reparametrization
This paper proposes a clever technique for producing low-variance gradients using the reparametrization trick that is applicable to a large range of probability distributions. Firstly, the difficulty arises from computing the gradient $\nabla_\theta g(\varepsilon, \theta)$. For distributions like the Gamma, $g$ is the inverse CDF $F^{-1}(\varepsilon; \theta)$, which is intractable, so computing the gradient becomes a serious problem. The task now is to find an efficient way to compute this gradient even when $F^{-1}(\varepsilon; \theta)$ is intractable.
The key insight here is that the parameter-independent sample $\varepsilon$ can be written as
$$\varepsilon = g^{-1}(x, \theta) = F(x; \theta).$$
Now, we can apply the implicit differentiation technique to the above expression as follows -
$$\nabla^{\mathrm{TD}}_\theta F(x; \theta) = \nabla_x F(x; \theta) \, \nabla_\theta x + \nabla_\theta F(x; \theta) = \nabla_\theta \varepsilon = 0 \implies \nabla_\theta x = -\frac{\nabla_\theta F(x; \theta)}{\nabla_x F(x; \theta)}.$$
(Note that $\nabla^{\mathrm{TD}}_\theta$ represents the total gradient and $\nabla_\theta$ represents the gradient with respect to $\theta$ with $x$ held fixed.) Therefore, through implicit differentiation, it is possible to find the gradient of the reparametrized samples $x$.
Now, observe that $\nabla_x F(x; \theta)$ is simply the density $q(x; \theta)$, so that $\nabla_\theta x = -\nabla_\theta F(x; \theta) / q(x; \theta)$. Since implicit differentiation yields the same result as the usual differentiation, the overall results for easier distributions like the Gaussian are identical to those of the usual procedure. Furthermore, note that the above expression is only in terms of $F(x; \theta)$, which is the forward CDF even for complicated distributions like the Gamma distribution. In such cases, numeric differentiation can be used to find $\nabla_\theta F(x; \theta)$.
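As a quick consistency check (a standard calculation, not from the paper): for $x \sim \mathcal{N}(\mu, \sigma^2)$, the CDF is $F(x; \mu, \sigma) = \Phi\!\left(\frac{x - \mu}{\sigma}\right)$, and the implicit formula recovers exactly the location-scale gradients:
$$\nabla_\mu x = -\frac{\nabla_\mu F(x; \mu, \sigma)}{q(x; \mu, \sigma)} = -\frac{-\frac{1}{\sigma}\phi\!\left(\frac{x - \mu}{\sigma}\right)}{\frac{1}{\sigma}\phi\!\left(\frac{x - \mu}{\sigma}\right)} = 1, \qquad \nabla_\sigma x = -\frac{-\frac{x - \mu}{\sigma^2}\phi\!\left(\frac{x - \mu}{\sigma}\right)}{\frac{1}{\sigma}\phi\!\left(\frac{x - \mu}{\sigma}\right)} = \frac{x - \mu}{\sigma} = \varepsilon,$$
which matches differentiating $x = \mu + \sigma\varepsilon$ directly.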
In conclusion, using implicit differentiation, a generic method for finding the gradient of reparametrized samples can be derived. In cases where the inverse CDF is intractable, the gradient can be found directly by (numerically) differentiating the forward CDF, as opposed to inverting the CDF and then computing the gradient as in the usual reparametrization trick.
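To make this concrete, here is a sketch for the Gamma distribution, with $f(x) = x$ as an assumed toy objective (so that $\mathbb{E}[x] = \alpha$ gives an exact gradient of $1$ to check against); the derivative of the CDF with respect to the shape parameter is taken by a central finite difference, one possible form of the numeric differentiation mentioned above:

```python
import numpy as np
from scipy import special, stats

# Implicit reparametrization gradient for x ~ Gamma(alpha, 1):
#   dx/dalpha = - (dF/dalpha) / q(x; alpha),
# where F(x; alpha) = scipy.special.gammainc(alpha, x) is the (regularized)
# Gamma CDF and q is the Gamma density. dF/dalpha has no simple closed form,
# so we approximate it with a central finite difference.
def implicit_grad(x, alpha, h=1e-5):
    dF_dalpha = (special.gammainc(alpha + h, x)
                 - special.gammainc(alpha - h, x)) / (2 * h)
    return -dF_dalpha / stats.gamma.pdf(x, alpha)

rng = np.random.default_rng(3)
alpha = 2.5
x = rng.gamma(alpha, size=100_000)  # samples whose gradients we need

# Sanity check with f(x) = x: E[x] = alpha, so d E[x] / d alpha = 1.
print(np.mean(implicit_grad(x, alpha)))  # ~1.0
```

Note that no inverse CDF is ever computed: the samples can come from any off-the-shelf Gamma sampler, and only the forward CDF is differentiated.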
References
[1] Figurnov, Michael, Shakir Mohamed, and Andriy Mnih. “Implicit Reparameterization Gradients.” arXiv preprint arXiv:1805.08498 (2018).