Representation learning constructs low-dimensional representations that summarize the essential features of high-dimensional data such as images and text. Ideally, such a representation should capture non-spurious features of the data in an efficient way. It should also be disentangled, so that each of its dimensions can be manipulated independently. However, these desiderata are often defined only intuitively, which makes them hard to quantify or enforce.
In this work, we take a causal perspective on representation learning. We show how the desiderata of representation learning can be formalized using counterfactual notions, which in turn enables algorithms that target efficient, non-spurious, and disentangled representations of data. We discuss the theoretical underpinnings of these algorithms and illustrate their empirical performance in both supervised and unsupervised representation learning.
This is joint work with Michael Jordan.
[*] https://arxiv.org/abs/2109.03795