The team led by Professor Xiaokang Yang has proposed the first continual predictive learning framework

In recent years, deep predictive learning has been applied to many long-term decision-making tasks, including industrial manufacturing and autonomous driving. To cope with more realistic, non-stationary physical environments, the team led by Prof. Yang improves on existing continual learning methods and proposes a new continual predictive learning framework. Directed by Professor Xiaokang Yang and Assistant Professor Yunbo Wang, the work “Continual Predictive Learning from Videos” has been accepted to CVPR 2022 and selected for oral presentation (top 5%).


Predictive learning is an unsupervised learning technique that builds a world model of the environment by learning the consequences of actions from historical observations, action sequences, and the corresponding future frames. The standard predictive learning setup assumes that the model operates in a stationary environment with relatively fixed physical dynamics. However, the assumption of stationarity does not always hold in more realistic scenarios, such as continual learning (CL), where the model is trained on tasks that arrive sequentially. For example, in robotics (see Fig. 1), world models often serve as the representation learners of model-based control systems, while the agent may be exposed to non-stationary environments in different training periods. Under these circumstances, it is not practical to maintain a separate model for each environment or each task, nor is it practical to collect data from all environments at all times. A primary finding of this paper is that most existing predictive networks perform poorly when trained in non-stationary environments, suffering from a phenomenon known as catastrophic forgetting.
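For readers unfamiliar with this setup, here is a minimal PyTorch sketch of the standard predictive-learning objective: an action-conditioned recurrent model is trained to predict the next observation from the current one and the action taken. The WorldModel class, its dimensions, and the training step are illustrative assumptions, not the architecture used in the paper.

```python
import torch
import torch.nn as nn

class WorldModel(nn.Module):
    """Toy action-conditioned predictor; names and sizes are illustrative."""
    def __init__(self, obs_dim=128, act_dim=4, hidden_dim=256):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.rnn = nn.GRUCell(obs_dim + act_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, obs_dim)

    def forward(self, obs_seq, act_seq):
        # obs_seq: (T, B, obs_dim); act_seq: (T, B, act_dim)
        h = torch.zeros(obs_seq.size(1), self.hidden_dim)
        preds = []
        for obs, act in zip(obs_seq, act_seq):
            h = self.rnn(torch.cat([obs, act], dim=-1), h)
            preds.append(self.decoder(h))  # one-step-ahead prediction
        return torch.stack(preds)

model = WorldModel()
obs = torch.randn(10, 8, 128)  # 10 frames (as feature vectors), batch of 8
act = torch.randn(10, 8, 4)
# Frame t together with action t is trained to predict frame t+1.
loss = nn.functional.mse_loss(model(obs[:-1], act[:-1]), obs[1:])
loss.backward()
```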



Figure 1. The new problem of continual predictive learning and the general framework of our approach at test time


We formalize this problem setup as continual predictive learning, in which the world model is trained in time-varying environments (i.e., “tasks” in the context of continual learning) with non-stationary physical dynamics. The model is expected to handle both newer and older tasks after the entire training phase. To this end, we propose a novel continual predictive learning (CPL) approach, which comprises three main components: the mixture world model, the predictive experience replay scheme, and the non-parametric task inference strategy. The overall architecture is shown in Figure 2.
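Concretely, the training and evaluation protocol can be pictured as follows. This is a schematic sketch that reuses the toy WorldModel from above; make_loader is a hypothetical helper, not the paper's API.

```python
import torch

def make_loader(task):
    """Hypothetical per-task data source yielding (obs_seq, act_seq) batches."""
    for _ in range(100):
        yield torch.randn(10, 8, 128), torch.randn(10, 8, 4)

model = WorldModel()  # the toy predictor from the sketch above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
tasks = ["boxing", "handclapping", "handwaving", "walking", "jogging", "running"]

for task in tasks:                      # tasks arrive strictly one at a time
    for obs, act in make_loader(task):  # only the current task's data is visible
        loss = torch.nn.functional.mse_loss(model(obs[:-1], act[:-1]), obs[1:])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# After the final task, the model is evaluated on every task seen so far;
# catastrophic forgetting shows up as degraded predictions on the earliest tasks.
```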



Figure 2. The overall network architecture of CPL


• Mixture world model: A new recurrent network that captures multi-modal visual dynamics in task-specific latent subspaces. Unlike existing world models [9,18], the learned priors take the form of a mixture of Gaussians to cope with the dynamics shift (see the sketch after this list).

• Predictive experience replay: A new rehearsal-based training scheme that combats forgetting within the world model while remaining memory-efficient.

• Non-parametric task inference: Instead of using a parametric task inference model, which may itself introduce forgetting issues, we use a trial-and-error strategy over the task label set to determine the current task (also sketched below). Before making predictions, we then perform several steps of self-supervised test-time adaptation to recall the previously learned knowledge of the inferred task.
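To make the first and third components concrete, the following sketch shows a mixture-of-Gaussians prior with one component per task, and a trial-and-error task inference loop that picks the task label minimizing the prediction error on the observed context. Both are simplified assumptions for illustration; model.predict is a hypothetical interface, not the code released with the paper.

```python
import torch
import torch.nn as nn
import torch.distributions as D

class MixturePrior(nn.Module):
    """Sketch of a mixture-of-Gaussians latent prior: one Gaussian component
    per task, so each task keeps its own mode instead of collapsing into one."""
    def __init__(self, num_tasks, latent_dim):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(num_tasks, latent_dim))
        self.log_std = nn.Parameter(torch.zeros(num_tasks, latent_dim))

    def component(self, k):
        # Task-specific Gaussian over the latent state.
        return D.Normal(self.mu[k], self.log_std[k].exp())

def infer_task(model, context_obs, context_act, num_tasks):
    """Trial-and-error task inference: condition the predictor on each
    candidate task label and keep the label whose prediction error on the
    observed context frames is lowest."""
    errors = []
    for k in range(num_tasks):
        pred = model.predict(context_obs[:-1], context_act[:-1], task=k)
        errors.append((pred - context_obs[1:]).pow(2).mean())
    return int(torch.stack(errors).argmin())
```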


We quantitatively and qualitatively evaluate CPL on two real-world datasets: RoboNet and KTH. Following previous literature, we adopt SSIM and PSNR to evaluate the prediction results. We run the continual learning procedure 10 times and report the means and standard deviations of the two metrics.
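For reference, both metrics can be computed per frame with scikit-image (the channel_axis argument assumes scikit-image ≥ 0.19); the arrays and shapes below are placeholders:

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def frame_quality(pred, gt):
    """PSNR/SSIM for a pair of uint8 RGB frames of shape (H, W, 3)."""
    psnr = peak_signal_noise_ratio(gt, pred, data_range=255)
    ssim = structural_similarity(gt, pred, data_range=255, channel_axis=-1)
    return psnr, ssim

# Average the per-frame scores over a predicted sequence of shape (T, H, W, 3).
pred_seq = np.random.randint(0, 256, (10, 64, 64, 3), dtype=np.uint8)
gt_seq = np.random.randint(0, 256, (10, 64, 64, 3), dtype=np.uint8)
scores = [frame_quality(p, g) for p, g in zip(pred_seq, gt_seq)]
mean_psnr = float(np.mean([s[0] for s in scores]))
mean_ssim = float(np.mean([s[1] for s in scores]))
```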

Results on the KTH dataset are shown in Figure 3. On this dataset, the learning sequence is set as (boxing -> handclapping -> handwaving -> walking -> jogging -> running). After training on the last task, “running”, we test the model on the first task, “boxing”, and show the results in Figure 3. Compared with other methods, the proposed model “CPL-full” generates clear video sequences that are consistent with the ground truth.



Figure 3. Results on the KTH dataset. (The last one is our method)


To test whether our method can be applied to robotic control systems, we collect videos from the Meta-World environment and evaluate our method; the results are shown in Figure 4. The learning sequence on this dataset is set as (hammer -> assembly -> sweep). After training on the last task, “sweep”, we test the model on all tasks and show the results in Figure 4. The baseline model suffers from the catastrophic forgetting problem, whereas the proposed model “CPL-full” generates clear video sequences that are consistent with the ground truth.



Figure 4. Results on the Meta-World dataset. (The last column is our method)


In this paper, we explored a new research problem, continual predictive learning, which is meaningful in realistic application scenarios such as vision-based robot control. It is also a challenging problem, mainly because of the coexistence of covariate shift, dynamics shift, and target shift. We thus proposed an approach named CPL to overcome the forgetting of multi-modal distributions over inputs, dynamics, and targets.


Paper link: https://arxiv.org/abs/2204.05624

Project link: https://github.com/jc043/CPL


[ 2022-05-07 ]