Introduction
If you find yourself in a strange situation where you want your Neural Network to do several things at once — don't worry, you just have a Multi-Task Learning (MTL) problem. In this article, I will discuss the challenges of MTL, survey effective solutions to them, and propose minor improvements of my own.
Traditionally, the development of Multi-Task Learning aimed to improve the generalization of multiple task predictors by training them jointly, while allowing some sort of knowledge transfer between them (Caruana, 1997). If you, for example, train a surface normal prediction model and a depth prediction model together, they will definitely share mutually beneficial features (Eigen et al. 2015). This motivation is clearly inspired by natural intelligence — living creatures can, in a remarkable way, easily learn a task by leveraging knowledge from other tasks. A broader generalization of this idea is called Lifelong Learning, in which different tasks are not even learned simultaneously.
For engineers, there is more to it than just knowledge sharing between tasks. The reasons why we want to do Multi-Task Learning are:
To optimize multiple objectives at once. For instance, in GANs, it has been shown on various tasks that incorporating additional loss functions often yields much better results (Isola et al. 2017; Wang et al. 2018). A regularization term can also be considered an additional objective.
To reduce the cost of running multiple models. In mobile apps, we want to perform more intelligent tasks with lower hardware requirements. How can we further speed up the 5 models that are already optimized both in size and speed? Oh, yeah, we can merge all of them into a single Multi-Task model!
To improve the accuracy on each task. We hope that the tasks will share mutually beneficial features. This is an active area of research that is currently limited, unfortunately, by our understanding of neural representations.
The motto of this article is: simple as instant noodles, i.e. I will only describe methods that are easy to implement and effective as heck! This article also serves me as a lecture note, so here you will find more in-depth theoretical stuff (that normally only the full papers have) than a typical survey offers. For a more comprehensive survey with a bird's-eye view of the whole field of MTL, more focused on the mutually beneficial sharing aspect of Multi-Task Learning, it is recommended to read Ruder's (2017) paper.
1. Too many losses? MOO to the rescue!
The methods of Multi-Objective Optimization (MOO) can help you learn multiple objectives better (hereafter we will use the terms objective and task interchangeably). In this section, I will discuss the challenges of learning multiple objectives and describe a state-of-the-art solution to them.
1.1. Forming the formal formulation
Consider the input space \(\mathcal{X}\) and collection of task-specific output spaces \(\{ \mathcal{Y}^t \} _ {t \in [ T ]}\), where \(T\) is the number of tasks. Without loss of generality, we consider a Multi-Task Neural Network
$$ \begin{equation} f(x, \theta) = \left( f^1(x; \theta^{sh}, \theta^1), \dots, f^T(x; \theta^{sh}, \theta^T) \right)\,, \tag{1.1.1} \label{eq:mtnn} \end{equation} $$where \( \theta^{sh} \) are network parameters shared between tasks, and \( \theta^t \) are task-specific parameters. The task-specific output \( f^t(x; \theta^{sh}, \theta^t) \) maps inputs from \( \mathcal{X} \) to the task-specific output space \( \mathcal{Y}^t \). In the Multi-Task Learning literature, the following summation formulation of the problem is commonly used:
$$ \begin{equation} \begin{split} \text{minimize} \enspace \sum _ {t=1}^T {\lambda^t \hat{\mathcal{L}}^t(\theta^{sh}, \theta^{t})} \quad\quad \text{w.r.t.} \enspace \theta^{sh}, \theta^{1}, \dots, \theta^{T} \,, \end{split} \tag{1.1.2} \label{eq:mtloss} \end{equation} $$where \(\hat{\mathcal{L}}^t(\cdot)\) is the empirical task-specific loss for the \(t\)-th task, defined as the average loss across the whole dataset, \(\hat{\mathcal{L}}^t (\theta^{sh}, \theta^{t}) \triangleq \frac{1}{N} \sum_i \mathcal{L}^t ( f^t(x_i; \theta^{sh}, \theta^{t}), y_i^t )\), where \(y_i^t \in \mathcal{Y}^t\) is the ground truth of the \(t\)-th task that corresponds to the \(i\)-th sample in the dataset of \(N\) samples.
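To make the notation concrete, here is a minimal PyTorch sketch of the network \eqref{eq:mtnn} and the summation objective \eqref{eq:mtloss}. The layer sizes, task heads, and loss choices are placeholders of mine, not anything from the cited papers.

```python
import torch
import torch.nn as nn

class MultiTaskNet(nn.Module):
    """f(x) = (f^1(x; θ_sh, θ_1), ..., f^T(x; θ_sh, θ_T)), as in eq. (1.1.1)."""
    def __init__(self, in_dim=64, hidden=128, task_out_dims=(10, 1)):
        super().__init__()
        # θ_sh: parameters shared between all tasks
        self.shared = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        # θ_t: one task-specific head per task
        self.heads = nn.ModuleList([nn.Linear(hidden, d) for d in task_out_dims])

    def forward(self, x):
        z = self.shared(x)
        return [head(z) for head in self.heads]

# Summation objective (1.1.2): minimize Σ_t λ_t * L_t
model = MultiTaskNet()
criteria = [nn.CrossEntropyLoss(), nn.MSELoss()]  # one loss per task (placeholder choices)
lambdas = [1.0, 1.0]                              # the weights we will worry about below

x = torch.randn(32, 64)
y = [torch.randint(0, 10, (32,)), torch.randn(32, 1)]
outs = model(x)
total = sum(lam * crit(o, t) for lam, crit, o, t in zip(lambdas, criteria, outs, y))
total.backward()
```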
1.2. The \(\lambda_t\) Balancing Problem
The obvious question from a first glance at \eqref{eq:mtloss} is: how do we set the weight coefficient \(\lambda_t\) for the \(t\)-th task? Usually, setting \(\lambda_t\) to \(1\) is not a good idea: for different tasks, the magnitudes of the loss functions, as well as the magnitudes of the gradients, might be very different. In an unbalanced setting, the magnitude of the gradients of one task might be so large that it makes the gradients from other tasks insignificant — i.e. the model will only learn one task while ignoring the others. Even a brute-force approach (e.g. grid search) may not find optimal values of \(\lambda_t\), since it pre-sets the values at the beginning of training, while the optimal values may change over time.
If we allow \(\lambda_t\) to change dynamically during training, which is a desirable behaviour, additional challenges occur. A basic justification is that in this setting, it is not possible to define global optimality for the optimization objective \eqref{eq:mtloss}. Consider two sets of solutions \(\theta\) and \(\bar{\theta}\) such that
$$ \begin{equation} \mathcal{L}^{t_1} (\theta^{sh}, \theta^{t_1}) < \mathcal{L}^{t_1} (\bar{\theta}^{sh}, \bar{\theta}^{t_1}) \quad\text{and}\quad \mathcal{L}^{t_2} (\theta^{sh}, \theta^{t_2}) > \mathcal{L}^{t_2} (\bar{\theta}^{sh}, \bar{\theta}^{t_2}) \tag{1.2.1} \label{eq:mtwtf} \end{equation} $$for some tasks \( t_1 \) and \( t_2 \). In other words, solution \( \theta \) is better for task \( t_1 \) whereas \( \bar{\theta} \) is better for \( t_2 \). It is not possible to compare them without explicit pairwise importance of tasks, which is typically not available.
1.3. Instant Noodle in case of Multiple Losses is MOO
Recent works attack this problem by presenting a heuristic according to which the coefficients \(\lambda_t\) are chosen: Chen et al. (2017) manipulate them in such a way that the gradients are approximately normalized; Kendall et al. (2018) model the network output's homoscedastic uncertainty with a probabilistic model. These approaches are further discussed in Appendix A.1 and Appendix A.2. However, heuristics are too unreliable — there is no guarantee that the chosen weights will be of any good. A true instant noodle approach should be reliable. That's where the latest paper by Sener and Koltun (2018), presented at NeurIPS this year, comes to the rescue. This paper is very theory-heavy, so I will expose just enough of it to give you a glimpse of the core idea without delving too deep into the rigorous theoretical stuff.
Instead of optimizing the summation objective \eqref{eq:mtloss}, the idea is to look at the MTL problem from the perspective of multi-objective optimization: optimizing a collection of possibly conflicting objectives. The MTL objective is then specified using a vector-valued loss \({L}\):
$$ \begin{equation} {L}(\theta^{sh}, \theta^1,\ldots,\theta^T) = \left( \hat{\mathcal{L}}^1(\theta^{sh},\theta^1), \ldots, \hat{\mathcal{L}}^T(\theta^{sh},\theta^T) \right)^\intercal \tag{1.3.1} \label{eq:vecloss} \end{equation} $$The main motivation for this formulation is the conflict \eqref{eq:mtwtf}. This vector objective will not have a strong-order minimum, but we can still talk about a weaker sort of minimality — Pareto optimality.
Definition (Pareto optimality). A solution \(\theta\) dominates a solution \(\bar{\theta}\) if \(\hat{\mathcal{L}}^t(\theta^{sh},\theta^t) \leq \hat{\mathcal{L}}^t(\bar{\theta}^{sh},\bar{\theta}^t)\) for all tasks \(t\) and \(L(\theta^{sh}, \theta^1,\ldots,\theta^T) \neq L(\bar{\theta}^{sh}, \bar{\theta}^1,\ldots,\bar{\theta}^T)\); A solution \(\,\theta^\star\) is called Pareto optimal if there exists no solution \(\,\theta\) that dominates \(\,\theta^\star\).
The multi-objective optimization problem can be solved to local optimality (in the Pareto sense) via the Multiple Gradient Descent Algorithm (MGDA), thoroughly studied by Désidéri (2012). This algorithm leverages the Karush–Kuhn–Tucker (KKT) conditions, which are necessary for optimality.
Intuitively, the KKT conditions generalize the notion of stationarity to the Pareto-dominance formulation. They describe the situation where the gradients of task-specific parameters \(\nabla_{\theta^t} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^{t})\) are all \(0\), and we can find a convex combination in which the shared-parameter gradients \(\nabla_{\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^{t})\) cancel each other out. In this case, the KKT conditions for both shared and task-specific parameters are as follows:
Karush–Kuhn–Tucker (KKT) conditions
- There exists \(\lambda^1 \dots \lambda^T\) such that \(\sum _ {t=1}^T \lambda^t = 1\) and the convex combination of gradients with respect to shared parameters \(\sum _ {t=1}^T \lambda^t \nabla _ {\theta^{sh}} \hat{\mathcal{L}}^t(\theta^{sh},\theta^t) = 0\).
- For all tasks \(t\), the gradients with respect to task-specific parameters \(\nabla _ {\theta^t} \hat{\mathcal{L}}^t (\theta^{sh}, \theta^{t}) = 0\).
A solution satisfying these conditions is called a Pareto stationary point. It is worth noting that although every Pareto optimal point is Pareto stationary, the reverse may not be true. Now, we formulate the optimization problem for the coefficients \(\lambda^1, \ldots, \lambda^T\) as follows:
$$ \begin{equation} \begin{split} \text{minimize} \quad & \left\| \sum_{t=1}^T {\lambda^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh},\theta^t)} \right\| _ 2^2 \\ \text{subject to} \quad & \sum_{t=1}^T \lambda^t = 1, \enspace \lambda^t \ge 0 \end{split} \tag{1.3.2} \label{eq:lambdaopt} \end{equation} $$Denoting \(p^t = \nabla _ {\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh},\theta^t)\), this optimization problem with respect to \(\lambda^t\) is equivalent to finding a minimum-norm point in the convex hull of the set of input points \(p^t\). This problem arises naturally in computational geometry: it is equivalent to finding the closest point within a convex hull to a given query point. Basically, \eqref{eq:lambdaopt} is a convex quadratic problem with linear constraints. If you are like me, chances are you're also sick of the non-convex optimization problems appearing every day of your career! Having a convex problem pop out of nowhere like this is nothing short of joy. The Frank–Wolfe solver was used as the most suitable convex optimization algorithm here, in part because the two-task case admits an analytical solution that serves as its inner subroutine (more in the paper by Sener and Koltun). The following theorem highlights the nice properties of this optimization problem:
Theorem (Désidéri). If \( \lambda^1, \dots, \lambda^T \) is the solution of \eqref{eq:lambdaopt}, either of the following is true:
- \( \sum_{t=1}^T \lambda^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh},\theta^t) = 0 \) and the resulting \( \lambda^1, \ldots, \lambda^T \) satisfies the KKT conditions.
- \( \sum_{t=1}^T \lambda^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t (\theta^{sh},\theta^t) \) is a descent direction that decreases all objectives.
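Since we will reuse it in a moment, here is what the two-task analytical solution of \eqref{eq:lambdaopt} mentioned above can look like in code. This is a sketch of my own; the clipping and the tie-breaking for coinciding gradients are my implementation choices, not the authors' code.

```python
import torch

def min_norm_2task(g1: torch.Tensor, g2: torch.Tensor):
    """Solve min_λ ||λ g1 + (1-λ) g2||² for λ in [0, 1] (two-task case of eq. 1.3.2)."""
    diff = g1 - g2
    denom = diff.dot(diff)
    if denom.item() < 1e-12:        # the two gradients (almost) coincide
        lam = 0.5
    else:
        # unconstrained minimizer, then clipped onto the simplex segment [0, 1]
        lam = torch.clamp((g2 - g1).dot(g2) / denom, 0.0, 1.0).item()
    return lam, 1.0 - lam
```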
The gist of the approach is clear — the resulting MTL algorithm is to apply gradient descent on the task-specific parameters \(\{ \theta^t \} _ {t=1}^T\), followed by solving \eqref{eq:lambdaopt} and applying the solution \(\sum_{t=1}^T \lambda^t \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t(\theta^{sh},\theta^t)\) as a gradient update to the shared parameters \(\theta^{sh}\). This algorithm will work for almost any neural network that you can build — the definition in \eqref{eq:mtnn} is very broad.
It is easy to notice that in this case, we need to compute \(\nabla _ {\theta^{sh}} \hat{\mathcal{L}}^t(\theta^{sh},\theta^t)\) for each task \(t\), which requires a backward pass over the shared parameters for each task. Hence, the resulting gradient computation is one forward pass followed by \(T\) backward passes, which significantly increases the expected training time.
To address that, the authors (Sener and Koltun, 2018) also provide a clever upper-bound formulation of \eqref{eq:lambdaopt} for encoder-decoder architectures that doesn't require calculating \(\nabla _ {\theta^{sh}}\) for every task. Also, the Frank–Wolfe solver used to optimize \eqref{eq:lambdaopt} requires an efficient algorithm for the line search (a very common subroutine in convex optimization methods). These two topics involve rigorous proofs, so I will omit them here to keep the simplicity (i.e. noodleness) of this article. Advanced readers might want to read the paper (Sener and Koltun, 2018) for more information.
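To make the gist concrete, here is a rough PyTorch sketch of one MGDA training step for two tasks. The autograd plumbing, the flattening of gradients, and the reuse of the min_norm_2task sketch above are my own scaffolding, not code from the paper.

```python
import torch

def mgda_step(shared_params, task_params, task_losses, optimizer):
    """One MGDA update for two tasks: heads get plain gradients, shared weights
    get the min-norm convex combination of per-task gradients (Section 1.3)."""
    optimizer.zero_grad()
    # one backward pass per task, keeping only the shared-parameter gradients
    per_task = [torch.autograd.grad(loss, shared_params, retain_graph=True)
                for loss in task_losses]
    flat = [torch.cat([g.reshape(-1) for g in grads]) for grads in per_task]
    lam1, lam2 = min_norm_2task(flat[0], flat[1])   # analytic solution of (1.3.2)
    # shared parameters: apply Σ_t λ_t ∇_θsh L_t
    for p, g1, g2 in zip(shared_params, per_task[0], per_task[1]):
        p.grad = lam1 * g1 + lam2 * g2
    # task-specific parameters: plain gradient of their own task loss
    for loss, params in zip(task_losses, task_params):
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        for p, g in zip(params, grads):
            p.grad = g
    optimizer.step()
```

Note that this sketch does one backward pass per task, which is exactly the cost discussed next.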
1.4. Remarks and Modifications
This new method outperforms Chen et al. (2017) and Kendall et al. (2018) consistently and by a large margin! Heck, it even outperforms the single-task classifier in most of the benchmarks! Absolute insanity! This is by far the tastiest instant noodle in this survey!
Question: Can I use this algorithm effectively with other network architectures?
The upper-bound formulation of \eqref{eq:lambdaopt} by the authors, although designed for encoder-decoder architectures, can be generalized to the tree-like structures that will be described in Section 3.3. The extended proof is provided in Appendix B.1.
Question: What if the objectives \(\hat{\mathcal{L}} (\theta^{sh}, \theta^t)\) are not equally important?
This algorithm will not preserve Pareto optimality if your objectives are not equal in importance, i.e. if a collection of constraints \(\|\lambda^{t_1}\nabla _ {\theta^{sh}} \hat{\mathcal{L}}^{t_1}(\theta^{sh},\theta^{t_1})\| \ge \|\lambda^{t_2} \nabla _ {\theta^{sh}} \hat{\mathcal{L}}^{t_2}(\theta^{sh},\theta^{t_2})\|\) is added for pairs of tasks \(t_1\) and \(t_2\). Each such constraint is convex, and a combination of them is also convex, so we can still use the Frank–Wolfe solver here.
The difference shows up when \(\sum _ {t=1}^T \lambda^t \nabla _ {\theta^{sh}} \hat{\mathcal{L}}^t(\theta^{sh},\theta^t) = 0\), i.e. when the zero point is inside the convex hull of the gradient directions. The importance constraint in this case means that we should ignore the less important objective and minimize \eqref{eq:lambdaopt} with respect to the other objectives, as preserving Pareto optimality is impossible in this case. In Section 2.3, however, I will show a case where this is actually the desired behaviour — when we combine this Instant Noodle with another Instant Noodle to create the ultimate Instant Noodle.
Question: What if my network has a tree-like structure (i.e. multiple branches and sharing at multiple locations)?
The beauty of MGDA is that you can generalize it to any structure you can think of. The only thing you need to know is which tasks share which parameters. In the illustration below, \(\theta_i\) contributes to tasks \(t_1\), \(t_2\), and \(t_k\), so the MGDA gradients "flow" from those tasks to the shared parameters \(\theta_i\).
Let \(\mathcal{O}_{\mathcal{F}}(i)\) denote the set of task indices that the weights \(\theta_i\) of the model \(\mathcal{F}\) contribute to (i.e. the automatic differentiation graph of task \(t\) includes \(\theta_i\)). The most general form of the MGDA algorithm then looks like this:
$$ \begin{equation*} \min_{\alpha^t} \left\{ \left\| \sum_{t \in \mathcal{O}_{\mathcal{F}}(i)} \alpha^t \nabla_{\theta_i} \hat{\mathcal{L}}^t (\theta) \right\|_ 2^2 \enspace\middle|\enspace \sum_{t \in \mathcal{O}_{\mathcal{F}}(i)} \alpha^t = 1, \enspace \alpha^t \ge 0 \right\} \end{equation*} $$This allows us to use MGDA for any Multi-Task Learning architecture that you can think of.
2. Forgot something? Hallucinate it!
In this section, I will describe the problem of catastrophic forgetting that occurs when the tasks you are trying to learn are so different that you don't have ground truth labels for every task for every input (or, in the case of unsupervised learning, GANs, or reinforcement learning — you can't evaluate the model on all of its actions). Then, I will describe the ways to overcome it that were proposed at WACV 2018, and discuss how they can be improved in an industry setting with abundant resources.
2.1. Interference of Tasks and Forgetting Effect
Consider a Multi-Task Learning setting as in figure \((a)\), where a CNN has to learn both the Action (shown in orange) and the Caption (shown in blue) of an image. The data is incomplete — the Caption ground truth is not available for the Action data, and vice versa. Thus, the training pipeline, as illustrated in figure \((b)\), will alternate between Action tasks and Caption tasks, i.e. at each training step the model will only have data for one of the tasks. Obviously, this makes training the summation objective \eqref{eq:mtloss} or the vector objective \eqref{eq:vecloss} impossible.
A naive way to get around it is to ignore the losses of the other tasks while training on a sample of one task (Kokkinos, 2016; there is also a video of the talk). More specifically, on a training step \(s\) where only the inputs \(x^t\) and ground truths \(y^t\) of task \(t\) are available, we set \(\mathcal{L}^k(\cdot) = 0\) for all \(k \ne t\), i.e. we zero out the gradients of the tasks without ground truth. Kokkinos (2016) also suggests not using a fixed batch size, but rather accumulating gradients separately for the task-specific parameters \(\theta^t\) and the shared parameters \(\theta^{sh}\), and performing the gradient step once the number of accumulated samples exceeds a certain threshold (individual for each \(\theta^{sh}\) and \(\theta^t\)).
Unfortunately, there is a well-known issue with this simple method. When we train either branch with its dataset, the network's knowledge of the other tasks might be forgotten. This is because during training, the optimization path of \(\theta^{sh}\) can be different for each task. Accumulating batches for each part of the network, as in (Kokkinos, 2016), and then doing the gradient updates according to Section 1.3 might do the trick, but remember that the variance of the data presented to the shared part of the network is greater than for all of the task-specific parts combined (this is an open problem, thoroughly described by Kokkinos, 2016), so we still need some kind of augmentation.
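In code, the naive alternation could look roughly like this. The round-robin scheduling and the single optimizer are my own simplifications of the scheme; Kokkinos additionally accumulates gradients per part of the network before stepping.

```python
import itertools

def train_alternating(model, optimizer, task_loaders, task_criteria, steps=1000):
    """Naive scheme from Section 2.1: each step sees data from one task only,
    so the losses of all other tasks are effectively zero at that step."""
    iters = [itertools.cycle(loader) for loader in task_loaders]
    for step in range(steps):
        t = step % len(task_loaders)      # round-robin over tasks
        x, y = next(iters[t])
        out = model(x)[t]                 # only the branch that has ground truth
        loss = task_criteria[t](out, y)
        optimizer.zero_grad()
        loss.backward()                   # only θ_sh and θ_t receive gradients
        optimizer.step()
```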
2.2. Instant Noodle is Your Previous Self!
At WACV 2018, a very simple approach was proposed (Kim et al. 2018). The idea is: if you don't have ground truths for the other tasks — just make sure that the model's outputs on the other branches stay the same as before. On each training step where you have input samples \(x^t\) and only the ground truths for the \(t\)-th task \(y^t\), instead of setting \(\mathcal{L}^k(\cdot) = 0\) for \(k \ne t\) as in the naive approach above, you enforce the outputs on the other tasks \(k\) to be similar to your previous outputs using a Knowledge Distillation loss.
For example, in the setting of the previous subsection, when the model is fed with Caption data, it also tries to be similar to its previous self with respect to its outputs on the Action branch, as illustrated in figure \((c)\); when the model is fed with Action data, it also tries to be similar to its previous self with respect to its outputs on the Caption branch, as illustrated in figure \((d)\).
Knowledge Distillation is a family of techniques, first proposed by Hinton et al. (2015), to make a model learn from another model as well while training on a specific dataset. In the case of classification, consider a Teacher Network (a pre-trained network) and a Student Network (the one to be trained) that have \( \text{Softmax}(\cdot) \) as the output layer and produce outputs \( y = (y_1, \ldots, y_n) \) and \( \hat{y} = (\hat{y}_1, \ldots, \hat{y}_n) \) respectively, where \( n \) is the number of classes. The Knowledge Distillation loss applied to the Student Network, to keep its outputs close to the Teacher's, is defined as follows:
$$ \begin{equation} \mathcal{L} _ {\text{distill}}(y, \hat{y}) = -\sum _ {k=1}^n {y' _ k \log \hat{y}' _ k}\,, \quad y' _ k = \frac{y_k ^ {1/T}}{ \sum_j y_j ^ {1/T}} \tag{2.2.1} \label{eq:kd} \end{equation} $$where \( T \) is called the temperature — the parameter that makes the \( \text{Softmax}(\cdot) \) activations softer (\(\hat{y}'_k\) is defined analogously). In a sense, the Distillation loss above is almost the same as the cross-entropy loss used for classification, but softer. This makes it ideal for our MTL setting — we want our outputs on the other tasks to be similar to those of a learned model, but not identical. Demanding identical outputs might prevent our model from learning. One can construct a Distillation loss for other kinds of loss functions as well.
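Written on top of logits (where raising the probabilities to the power \(1/T\) and renormalizing is equivalent to dividing the logits by \(T\)), a minimal sketch of \eqref{eq:kd} might look like this; the default temperature is just my placeholder:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """Soft cross-entropy between temperature-softened outputs (eq. 2.2.1)."""
    teacher_soft = F.softmax(teacher_logits / T, dim=1)           # y'
    student_log_soft = F.log_softmax(student_logits / T, dim=1)   # log ŷ'
    return -(teacher_soft * student_log_soft).sum(dim=1).mean()
```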
2.3. New requirement from boss: 5 nets should be 5x faster, till tomorrow!
A hell of an unreasonable requirement, but we will all get there at some low point of our lives. What should we do if we already have a bunch of trained networks? How do we merge them into one network?
A more general version of the idea in Section 2.2 is to distill the knowledge from a collection of single-task networks, each already trained on its task \(k\), for all tasks \( k \ne t \), while training on task \( t \), as illustrated below. This way, we can pretend that a label is there when we actually don't have it.
One can even go a step further — employ more aggressive knowledge transfer techniques that distill hidden representations, such as FitNets (Romero et al. 2015), to train the Multi-Task model faster (however, I won't recommend more constrained distillation methods, such as Yim et al. 2017). This can be helpful when one needs to perform a Neural Architecture Search for the most efficient MTL architecture. Simple, yet effective. A true Instant Noodle!
Note that this approach can simply be combined with the approach described in Section 1.3, with the minor modification described in Section 1.4. For each pair of a task (objective) and its distillation counterpart, we require the gradient direction to be more biased towards the real objective. Strictly speaking, we declare \( \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t \) to be a more important objective than \( \nabla_{\theta^{sh}} \hat{\mathcal{L}}^t_{\text{distill}} \).
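Here is a rough sketch of one training step of such a merged network, using frozen single-task teachers to stand in for the missing labels. The function names, the weighting coefficient, and the reuse of the distillation_loss sketch from Section 2.2 are my own assumptions, not the paper's code.

```python
import torch

def merged_step(student, teachers, optimizer, x, y_t, t, criteria,
                distill_w=1.0, T=2.0):
    """Train the merged net on task t's data; hallucinate targets for the
    other tasks from their frozen single-task teachers (Section 2.3)."""
    outs = student(x)                          # one output per task
    loss = criteria[t](outs[t], y_t)           # real objective for task t
    for k, teacher in enumerate(teachers):
        if k == t:
            continue
        with torch.no_grad():                  # teachers are frozen
            pseudo = teacher(x)                # stand-in for the missing label
        loss = loss + distill_w * distillation_loss(outs[k], pseudo, T)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```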
3. No Instant Noodle Architecture yet :(
Unfortunately, I can’t think of any multi-task architecture that can be used everywhere, i.e. a Truly Instant Noodle. In this section, I will instead discuss the pros and cons of commonly used architectures for Multi-Task Learning (especially in Computer Vision tasks). Since they are not Instant Noodles, I will not delve deep into details in this section.
Existing architectures of MTL can be classified according to how they share parameters between tasks, as shown in the figure below (Meyerson & Miikkulainen, 2018). The common trait between them is that they all have some sort of shared hierarchies.
3.1. Shared encoder, task-specific decoders
This is the most straightforward approach, as shown in Fig. \((3.a)\), and the most natural architecture one can come up with, dating back to Caruana (1997). In the Multi-Task Learning literature, this approach is also referred to as Hard Parameter Sharing. Sometimes, it is extended with task-specific encoders (Luong et al., 2016). This is the most widely used architecture as well (sort of an okay-ish instant noodle).
Pros of this family:
- Dead simple — simple to implement, simple to train, simple to debug. Lots of tutorials are available as well: for TensorFlow (here), for Keras (here & here), for PyTorch (here).
- Well-studied — a huge body of literature has accumulated ever since Caruana (1997), both theoretical (Kendall et al. 2018; Chen et al. 2017; Sener & Koltun, 2018) and practical (Ranjan et al. 2016; Wu et al. 2015; Jaderberg et al. 2017).
- The fastest of all — it shares everything possible, so the inference time will not be much different from executing a single network.
Cons of this family:
- Not flexible — forcing all tasks to share a common encoder is dumb. Some tasks are more similar than others, so logically depth prediction and surface normal prediction should share more parameters with each other than with an object detection task.
- Pretending to share — as highlighted by Liu & Huang (2018), these kinds of architectures just collect all the features together into a common layer, instead of learning shared parameters (weights) across different tasks.
- Fight for resources — as a consequence, the tasks often fight with each other for resources (e.g. convolution kernels) within a layer. If the tasks are closely related, it's OK, but otherwise this architecture is very inconvenient. This makes the issue of negative transfer (i.e. one task corrupting useful features of other tasks) more probable.
3.2. A body for each task
This family of architectures is also referred to as Soft Parameter Sharing in the literature; the core idea is shown in Fig. \((3.b)\) — each task has its own layer of task-specific parameters at each shared depth. These architectures then define a mechanism to share knowledge (i.e. parameters) between tasks at each shared depth (i.e. sharing between columns).
The most instant-noodley approach is Cross-Stitch Networks (Misra et al. 2016), illustrated in Fig. \((a)\). It allows the model to determine how the task-specific columns should combine knowledge from the other columns, by learning a linear combination of the outputs of the previous layers. Use this if you need a noodley.
A generalization of Cross-Stitch Networks is Sluice Networks (Ruder et al. 2017). They combine elements of hard parameter sharing, cross-stitch networks, and other good stuff to create a task hierarchy, as illustrated in Fig. \((b)\). Use this if you're feeling naughty.
Another interesting yet extremely simple column-based approach is Progressive Networks (Rusu et al. 2016), illustrated in Fig. \((c)\). This is arguably another breed — it is intended to solve a problem more general than MTL, the Learning Without Forgetting (LWF) problem. The tasks are learned gradually, one by one. This works best when you have already learned a task and want to learn similar tasks quickly. This is a very specific noodley.
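Here is a rough sketch of a cross-stitch unit for two columns. I use a single scalar mixing matrix for the whole activation; Misra et al. also consider finer-grained (per-channel) variants, so the granularity and the near-identity initialization here are my simplifications.

```python
import torch
import torch.nn as nn

class CrossStitch2(nn.Module):
    """Learnable 2x2 linear combination of two task columns' activations."""
    def __init__(self):
        super().__init__()
        # initialized near identity so each column starts out mostly "itself"
        self.alpha = nn.Parameter(torch.tensor([[0.9, 0.1],
                                                [0.1, 0.9]]))

    def forward(self, x_a, x_b):
        out_a = self.alpha[0, 0] * x_a + self.alpha[0, 1] * x_b
        out_b = self.alpha[1, 0] * x_a + self.alpha[1, 1] * x_b
        return out_a, out_b
```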
A recent work (He et al. 2018), presented at NeurIPS 2018, allows one to merge fully-trained networks by leveraging results from Network Pruning.
Pros of this family:
- Explicit sharing mechanism — the tasks decide for themselves what to keep and what to share at each pre-defined level, so this family has fewer problems like fighting for resources or pretending to share.
Cons of this family:
- Soooo sloooooow, soooo faaat — the architecture is very bulky (a whole network for each task), so the approach is impractical. Current trend in tech requires lighter and faster networks for On-Device AI.
- Huge variety, no silver bullet — there is a huge variety within this family of networks. None of them seems much superior to the others, so choosing the right architecture for a specific need might be tricky.
- Not end-to-end — this family of networks usually requires the task-specific columns (at least some of them) to be already pre-trained.
3.3. Branching at custom depth
This approach is based on the shared-encoder one discussed in Section 3.1, with a small modification — instead of having all task-specific branches split off from the main body (the shared part) at a fixed layer, each of them now detaches at a different depth, as shown in Fig. \((3.c)\).
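A toy sketch of what such branching looks like in code; the depths, channel sizes, and head shapes are arbitrary placeholders of mine.

```python
import torch.nn as nn

class BranchingNet(nn.Module):
    """Shared trunk; task heads detach at different depths (Section 3.3)."""
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU())
        self.block2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())
        self.block3 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU())
        self.head_early = nn.Conv2d(32, 1, 1)   # branches right after block1
        self.head_late = nn.Conv2d(64, 10, 1)   # branches after block3

    def forward(self, x):
        z1 = self.block1(x)
        z3 = self.block3(self.block2(z1))
        return self.head_early(z1), self.head_late(z3)
```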
In my personal experience, I choose the branching place of each task experimentally — I just run a bunch of experiments over the weekends on a Huge Ass cluster with a bunch of GPUs to decide the best performing yet most compact one. Basically, a brute force, which is very inefficient.
A more promising way of finding efficient architectures is to dynamically figure out where to branch out from the main body. At CVPR 2017, an approach was proposed by Lu et al. (2017) that starts from a fully-shared architecture, then dynamically splits the layers out greedily according to task affinity heuristics (which should correlate with task similarity). Basically, it is a Neural Architecture Search (NAS) algorithm. This approach has many drawbacks as well (it is very hard to choose hyperparameters, the resulting architecture may not be optimal at all, the affinity metric is questionable, etc. — just my opinion), but it is still an interesting direction of research.
Pros of this family:
- Dead simple and well-studied — the theoretical and practical results for the architectures in Section 3.1 work here as well, so it has all the pros described previously.
- Still fast AF — not as fast as the family of architectures in Section 3.1, but still faster than everything else. In this family of architectures, you still share as much as you can between tasks.
- Ideal case is ideal — different tasks tend to share low-level features and diverge in the deeper layers (He et al. 2018). If the branching is done ideally, combined with ideas from the family of networks in Section 3.2, there shouldn't be any fighting-for-resources or pretending-to-share problems as in Section 3.1.
Cons of this family:
- No one dares to do it — not everyone has the luxury of going for full brute force like me. Dynamic approaches based on heuristics (Lu et al. 2017) are very unreliable. If done incorrectly, this family of architectures can inherit all the drawbacks of all families of MTL nets combined!!!
3.4. Beyond sharing, beyond hierarchies, beyond this world
This family, schematically illustrated in Fig. \((3.d)\), builds on the observation that the tasks can share all parameters in the main body except the normalization scaling factors (Bilen and Vedaldi, 2017). Basically, the tasks share the whole network, and the only task-specific parameters are the Instance Normalization parameters. At ICLR last year, Meyerson & Miikkulainen (2018) quickly escalated this idea a step further by allowing the weights themselves to be freely permuted. The idea of changing the order of layers is not new by itself (Veit et al. 2016), but learning the best permutation of weights across different tasks is very creative.
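A small sketch of the "share everything except normalization" idea; the choice of InstanceNorm2d and the block layout are my own assumptions for illustration, not the exact setup of the cited papers.

```python
import torch.nn as nn

class SharedBlockTaskNorm(nn.Module):
    """Convolution shared by all tasks; only the normalization layer is task-specific."""
    def __init__(self, channels, num_tasks):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)   # θ_sh
        self.norms = nn.ModuleList(                               # tiny θ_t per task
            [nn.InstanceNorm2d(channels, affine=True) for _ in range(num_tasks)])
        self.act = nn.ReLU()

    def forward(self, x, task_id):
        return self.act(self.norms[task_id](self.conv(x)))
```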
Pros of this family:
- Lightweight — they share every penny that they can, so the resulting model will have about as many parameters as one single-task network.
- Just WOW — sharing every layer, even in permuted order, is very counter-intuitive. It makes you wonder "what are features? what is knowledge? what is life?" You can even use it to hook up some girls!
Cons of this family:
- Still slooooooow — as in Section 3.2, you still have to propagate through the whole network for each task. If you don't intend to execute all tasks and just want to save some space, this is not a con at all.
- Still vulnerable — this family can still suffer from the fight-for-resources and pretending-to-share problems described in Section 3.1.
3.5. Remarks
I just want to make a quick note that the Instant Noodles in Section 1 and Section 2 can be applied to any of the architectures above, with the limitation that the upper-bound approximation of \eqref{eq:lambdaopt} may not apply to architectures with no explicit encoder/decoder. A true Instant Noodle Architecture might come out of Neural Architecture Search (NAS) for MTL, which looks very promising as the industry moves towards smaller and faster models.
Okay, that’s all cool, but how should I Multi-Task?
Just use the Instant Noodles described above. More specifically, combine the Multiple Gradient Descent technique from Section 1 with the gradient normalization described in Appendix A.1. Then, if you have problems with missing data, use the stochastic gradient update or the self-distillation idea analyzed in detail in Section 2. Choosing the right architecture might be tricky, so you will need to weigh all the pros and cons in Section 3. Finally, subscribe and comment on this blog for more cool stuff!
Appendix A: other noodles that ain’t the yummiest noodle
In this section, I will describe the other approaches that I have experience with, but wouldn't recommend to others. On their own, they are quite good and convenient, just not the best out there (compared to the methods described above).
A.1. GradNorm: Gradients Normalization for Adaptive \(\lambda_t\) Balancing
Description. This approach (Chen et al. 2017), presented at ICML 2018, attempts to normalize the magnitudes of gradients to roughly the same scale during the backward pass. The motivation is simple: the raw value of a loss component of \eqref{eq:mtloss} does not reflect how much your model "cares" about that component; e.g. an \(L_1\) loss can report arbitrarily large values depending on its scale; the gradient's magnitude is what actually matters.
At each training step \(s\), the average gradient norm \(\bar{g}_{\theta}(s)\) is chosen as a common scale for the gradients. The relative inverse training rate of task \(t\), \(r_t(s)\), is used to balance the gradients — the higher its value, the higher the gradient magnitude should be for task \(t\), in order to encourage the task to train more quickly. The desired gradient magnitude for each task \(t\) is therefore:
$$ \begin{equation} g^t _ \theta (s) \mapsto \bar{g} _ \theta (s) \times r_t(s)^\alpha \tag{A.1.1} \label{eq:gradnorm1} \end{equation} $$where \(\alpha\) is an additional hyperparameter. It sets the strength of the restoring force that pulls tasks back to a common training rate. If the tasks are very different, leading to dramatically different learning dynamics between tasks, a higher value of \(\alpha\) should be used. At each training step \(s\), we then encourage the gradients to be closer to the desired magnitude:
$$ \begin{equation} \mathcal{L} _ {\text{grad}} (s, \lambda_1, \dots, \lambda_T) = \sum _ {t=1}^T {\left\vert g^t _ \theta (s) - \bar{g} _ \theta (s) \times r_t(s)^\alpha \right\vert} _ {1} \tag{A.1.2} \label{eq:gradnorm2} \end{equation} $$The loss \eqref{eq:gradnorm2} is then differentiated only w.r.t. \(\lambda_t\), and the weights are updated via the standard backpropagation update. In other words, the weight coefficients \(\lambda_t\) are used to manipulate the gradient norms and move them towards the desired target \eqref{eq:gradnorm1}.
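Here is a rough sketch of what one GradNorm update of the weights could look like. The choice of which shared tensor to measure gradients on, the renormalization step, and the initial-loss bookkeeping are my own simplifications, so don't read it as the authors' code.

```python
import torch

def gradnorm_update(task_losses, initial_losses, lambdas, shared_weight,
                    lambda_opt, alpha=1.5):
    """One GradNorm update of the loss weights λ_t (eqs. A.1.1 and A.1.2)."""
    T = len(task_losses)
    # gradient norm of each *weighted* task loss w.r.t. a shared weight tensor
    norms = torch.stack([
        torch.autograd.grad(lambdas[t] * task_losses[t], shared_weight,
                            retain_graph=True, create_graph=True)[0].norm()
        for t in range(T)])
    g_bar = norms.mean().detach()                        # common scale ḡ(s)
    # relative inverse training rate r_t(s): slower tasks get a larger target norm
    loss_ratio = torch.stack([task_losses[t].detach() / initial_losses[t]
                              for t in range(T)])
    r = loss_ratio / loss_ratio.mean()
    target = (g_bar * r ** alpha).detach()               # eq. (A.1.1)
    grad_loss = (norms - target).abs().sum()             # eq. (A.1.2)
    lam_grad, = torch.autograd.grad(grad_loss, lambdas, retain_graph=True)
    lambda_opt.zero_grad()
    lambdas.grad = lam_grad                              # gradients flow into λ only
    lambda_opt.step()
    with torch.no_grad():                                # keep Σ λ_t = T
        lambdas.mul_(T / lambdas.sum())
```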
Comments. This approach requires choosing an additional hyperparameter \(\alpha\) that you have to guess. This pisses me off because I don't have the insane intuition for hyperparameter tuning that the Kagglers do. Furthermore, it introduces another loss function \(\mathcal{L} _ {\text{grad}}\) that regularizes the gradient magnitudes. Lol, good job bro — you've just introduced another hyperparameter to optimize your existing hyperparameters, and introduced another loss into the sum to simplify the bunch of losses that you have (sarcasm). Somehow it works, but not as well as the instant noodle above.
A.2. Using uncertainties of losses \(\mathcal{L}^t(\cdot)\) to balance \(\lambda_t\)
Description. At CVPR 2018, another approach was proposed by Kendall et al. (2018) that models the network output's homoscedastic uncertainty with a probabilistic model. We will use the notation of \eqref{eq:mtnn} from Section 1.1. For a single task, we model the network output uncertainty with a density function \(p\left( y \vert f(x, \theta) \right)\) (how likely the true answer is to be \(y\), given the network's response). In the case of multiple network outputs \(y^1, \dots y^T\), we obtain the following multi-task likelihood:
$$ \begin{equation} \tag{A.2.1} \label{eq:mtlikelihood} p \left( y^1, \ldots, y^T \vert f(x, \theta) \right) = p\left(y^1 \vert f^1(x, \theta)\right) \ldots p\left(y^T \vert f^T(x, \theta)\right) \to \max \end{equation} $$Instead of balancing the weights of the loss functions as in \eqref{eq:mtloss}, we can now require the likelihood \eqref{eq:mtlikelihood} to be maximal, i.e. we have a maximum likelihood inference problem, where the objective is to minimize \(-\log p(y^1, \ldots, y^T \vert f(x, \theta))\) with respect to \(\theta\). The trick now is to construct a likelihood \(p(y^t \vert f^t(x,\theta))\) for each task such that it contains the loss term \(\mathcal{L}^t(\cdot)\). This way, we can build a bridge between the maximum likelihood \eqref{eq:mtlikelihood} and the summation loss \eqref{eq:mtloss}. The \(\log(\cdot)\) will also convert the multiplication into a summation, which basically brings the maximum likelihood objective to the summation form.
As an example of this dark magic approach, consider a multi-task regression where your objective is to optimize the loss \(\mathcal{L}^t(\theta) = \| y^t - f^t(x, \theta) \|^2\) for all tasks \(t \in \{1 \ldots T\}\). The likelihood is defined artificially as a Gaussian with mean given by the model's output and standard deviation given by a noise factor \(\sigma\):
$$ \begin{equation} p\left(y \vert f(x, \theta)\right) = \mathcal{N}(f(x, \theta), \sigma^2) \tag{A.2.2} \label{eq:gausslike} \end{equation} $$The noise scalar \( \sigma \) is learned during training, i.e. it is a trainable parameter. In essence, it is the parameter that captures the uncertainty. So, our objective now is to maximize \eqref{eq:mtlikelihood} with respect to \( \sigma \) as well (one \(\sigma_t\) per task). After careful computations, our negative log-likelihood takes the following form:
$$ \begin{equation} -\log p \left( y^1, \ldots, y^T \vert f(x, \theta) \right) \propto \underbrace{\sum_{t=1}^T {\frac{1}{2\sigma _ t^2} \| y^t - f^t(x, \theta) \|^2}} _ {\text{the same as}\, \sum_{t=1}^T \lambda^t\mathcal{L}^t(\theta)} + \underbrace{\log \prod_{t=1}^T \sigma_t} _ {\text{regularization}} \tag{A.2.3} \label{eq:l2ll} \end{equation} $$which is the same as the summation loss \eqref{eq:mtloss}, where we assign \(\lambda_t = \frac{1}{2}\sigma_t^{-2}\), plus a regularization term that discourages \(\sigma_t\) from growing too large (which would effectively ignore the data).
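For the regression case in \eqref{eq:l2ll}, a minimal sketch of this weighting could look like the following. I parameterize \(\log \sigma_t^2\) instead of \(\sigma_t\) for numerical stability, which is a common implementation trick and my own choice rather than part of the paper's equations.

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Combine task losses as Σ_t (1/(2σ_t²)) L_t + log σ_t, with learnable σ_t."""
    def __init__(self, num_tasks):
        super().__init__()
        self.log_var = nn.Parameter(torch.zeros(num_tasks))   # log σ_t²

    def forward(self, task_losses):
        total = 0.0
        for t, loss in enumerate(task_losses):
            precision = torch.exp(-self.log_var[t])            # 1 / σ_t²
            # 0.5 * log_var[t] equals log σ_t, the regularization term
            total = total + 0.5 * precision * loss + 0.5 * self.log_var[t]
        return total
```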
Comments. What kind of black magic is this? Basically, to optimize a loss function, you need to construct a whole distribution, the logarithm of which gives you the loss function multiplied by a learnable scalar \(\sigma\), and make sure that this distribution is physically meaningful! Or, get rid of the notion of a "loss function" altogether and just make hypotheses about the form of the \(\mathcal{L}(\cdot)\) uncertainty. This is too much pain in the ass for a lazy engineer. There is no guarantee that the density you constructed is correct either.