Commit e296f393 authored by Annabi Louis

Update model_training.ipynb

parent 83e7d7f4
...
@@ -20,7 +20,6 @@
 "# You should have received a copy of the GNU General Public License\n",
 "# along with this program. If not, see <https://www.gnu.org/licenses/>.\n",
 "\n",
-"test\n",
 "\n",
 "import numpy as np\n",
 "from tqdm import tqdm_notebook as tqdm\n",
...
%% Cell type:code id: tags:
``` python
# Copyright (C) 2021 Louis Annabi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import numpy as np
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt
import torch
from torch import nn
import pickle as pk
```
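%% Cell type:markdown id: tags:
Note: the weights below are initialized with `torch.randn`, so results vary from run to run. Fixing the random seeds (an optional sketch, not in the original notebook) makes runs reproducible:
%% Cell type:code id: tags:
``` python
# Optional: fix the random seeds for reproducible weight initialization
torch.manual_seed(0)
np.random.seed(0)
```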
%% Cell type:markdown id: tags:
## 1. The RNN model
%% Cell type:code id: tags:
``` python
class RNN(nn.Module):
    def __init__(self, states_dim, causes_dim, output_dim, factor_dim, tau):
        super(RNN, self).__init__()
        self.states_dim = states_dim
        self.causes_dim = causes_dim
        self.output_dim = output_dim
        self.factor_dim = factor_dim
        # Time constant of the RNN
        self.tau = tau
        # Output weights initialization
        self.w_o = torch.randn(self.states_dim, self.output_dim) * 5 / self.states_dim
        # Recurrent weights factorization
        self.w_pd = torch.randn(self.states_dim, self.factor_dim) * 0.2 / np.sqrt(self.factor_dim)
        self.w_fd = self.w_pd.clone()
        self.w_cd = torch.nn.Softmax(1)(0.5*torch.randn(self.causes_dim, self.factor_dim))*self.factor_dim
        self.w_pd += torch.randn_like(self.w_pd) / np.sqrt(self.factor_dim)
        self.w_fd += torch.randn_like(self.w_fd) / np.sqrt(self.factor_dim)
        # Predictions, states and errors are temporarily stored for batch learning
        # Learning can be performed online, but computations are slower
        self.x_pred = None
        self.error = None
        self.h_prior = None
        self.h_post = None
        self.s = None

    def forward(self, x, c_init, h_init=0, lr_c=0.2, lr_h=0.2):
        """
        Pass through the network: the forward (prediction) and backward (inference)
        passes are performed at the same time. Online learning could be performed here,
        but to improve computation speed, we use seq_len as a batch dimension in a
        separate function.
        Parameters:
        - x : target sequences, Tensor of shape (seq_len, batch_size, output_dim)
        - c_init : causes of the sequences, Tensor of shape (batch_size, causes_dim)
        - h_init : states of the sequences, Tensor of shape (batch_size, states_dim)
        - lr_c : learning rate associated with the hidden causes, double
        - lr_h : learning rate associated with the hidden state, double
        """
        seq_len, batch_size, _ = x.shape
        # Temporary storing of the predictions, states and errors
        x_pred = torch.zeros_like(x)
        h_prior = torch.zeros(seq_len, batch_size, self.states_dim)
        h_post = torch.zeros(seq_len, batch_size, self.states_dim)
        c = torch.zeros(seq_len+1, batch_size, self.causes_dim)
        error_h = torch.zeros(seq_len, batch_size, self.states_dim)
        error = torch.zeros_like(x)
        # Initial hidden state and hidden causes
        c[0] = c_init
        old_h_post = h_init
        for t in range(seq_len):
            # Top-down pass
            # Compute h_prior according to past h_post and c
            h_prior[t] = (1-1/self.tau) * old_h_post + (1/self.tau) * torch.mm(
                torch.mm(
                    torch.tanh(old_h_post),
                    self.w_pd
                ) * torch.mm(
                    c[t],
                    self.w_cd
                ),
                self.w_fd.T
            )
            # Compute x_pred according to h_prior
            x_pred[t] = torch.mm(torch.tanh(h_prior[t]), self.w_o)
            # Bottom-up pass
            # Compute the error on the sensory level
            error[t] = x_pred[t] - x[t]
            # Infer h_post according to h_prior and the error on the sensory level
            h_post[t] = h_prior[t] - (1-torch.tanh(h_prior[t])**2)*lr_h*torch.mm(error[t], self.w_o.T)
            # Compute the error on the hidden state level
            error_h[t] = h_prior[t] - h_post[t]
            # Infer c according to its past value and the error on the hidden state level
            c[t+1] = c[t] - lr_c*torch.mm(
                torch.mm(
                    torch.tanh(old_h_post),
                    self.w_pd
                ) * torch.mm(
                    error_h[t],
                    self.w_fd
                ),
                self.w_cd.T
            )
            old_h_post = h_post[t]
        self.x_pred = x_pred
        self.error = error
        self.error_h = error_h
        self.h_prior = h_prior
        self.h_post = h_post
        self.c = c

    def learn(self, lr_o, lr_r):
        """
        Performs learning of the RNN weights. For computational efficiency, sequence
        length and batch size are merged into a single batch dimension in the following
        computations.
        Parameters:
        - lr_o : learning rate for the output weights
        - lr_r : learning rate for the recurrent weights
        """
        seq_len, batch_size, _ = self.x_pred.shape
        # Output weights
        grad_o = lr_o * torch.mean(
            torch.bmm(
                torch.tanh(self.h_prior.reshape(seq_len * batch_size, self.states_dim, 1)),
                self.error.reshape(seq_len * batch_size, 1, self.output_dim)
            ),
            axis=0
        )
        self.w_o -= grad_o
        nbatch = (seq_len-1)*batch_size
        # Recurrent weights
        grad_pd = lr_r * torch.mean(
            torch.bmm(
                torch.tanh(self.h_post[:-1]).reshape(nbatch, self.states_dim, 1),
                (
                    torch.mm(
                        self.error_h[1:].reshape(nbatch, self.states_dim),
                        self.w_fd
                    ) *
                    torch.mm(
                        self.c[1:-1].reshape(nbatch, self.causes_dim),
                        self.w_cd
                    )
                ).reshape(nbatch, 1, self.factor_dim)
            ),
            axis=0
        )
        self.w_pd -= grad_pd
        grad_cd = 10 * lr_r * torch.mean(
            torch.bmm(
                self.c[1:-1].reshape(nbatch, self.causes_dim, 1),
                (
                    torch.mm(
                        self.error_h[1:].reshape(nbatch, self.states_dim),
                        self.w_fd
                    ) *
                    torch.mm(
                        torch.tanh(self.h_post[:-1]).reshape(nbatch, self.states_dim),
                        self.w_pd
                    )
                ).reshape(nbatch, 1, self.factor_dim)
            ),
            axis=0
        )
        self.w_cd -= grad_cd
        grad_fd = lr_r * torch.mean(
            torch.bmm(
                torch.tanh(self.error_h[1:]).reshape(nbatch, self.states_dim, 1),
                (
                    torch.mm(
                        torch.tanh(self.h_post[:-1]).reshape(nbatch, self.states_dim),
                        self.w_pd
                    ) *
                    torch.mm(
                        self.c[1:-1].reshape(nbatch, self.causes_dim),
                        self.w_cd
                    )
                ).reshape(nbatch, 1, self.factor_dim)
            ),
            axis=0
        )
        self.w_fd -= grad_fd
```
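%% Cell type:markdown id: tags:
In equation form, the top-down prediction and bottom-up inference implemented in `forward` read (restating the code, with $\odot$ the elementwise product and $\lambda_h$ the state learning rate `lr_h`):

$$h^{prior}_t = \left(1-\tfrac{1}{\tau}\right) h^{post}_{t-1} + \tfrac{1}{\tau}\left[\left(\tanh(h^{post}_{t-1})\, W_{pd}\right) \odot \left(c_t\, W_{cd}\right)\right] W_{fd}^\top$$

$$x^{pred}_t = \tanh(h^{prior}_t)\, W_o, \qquad e_t = x^{pred}_t - x_t$$

$$h^{post}_t = h^{prior}_t - \lambda_h \left(1-\tanh^2(h^{prior}_t)\right) \odot \left(e_t\, W_o^\top\right)$$

A minimal shape check on random data (a sketch with hypothetical toy dimensions, not in the original notebook):
%% Cell type:code id: tags:
``` python
# Sanity-check the tensor shapes handled by the model
toy = RNN(states_dim=10, causes_dim=3, output_dim=2, factor_dim=5, tau=7)
x = torch.randn(50, 4, 2)  # (seq_len, batch_size, output_dim)
toy.forward(x, torch.ones(4, 3)/3, torch.zeros(4, 10))
toy.learn(lr_o=0.1, lr_r=3)
print(toy.x_pred.shape, toy.c.shape)  # (50, 4, 2) and (51, 4, 3)
```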
%% Cell type:markdown id: tags:
## 2. Load the dataset of handwritten trajectories
%% Cell type:code id: tags:
``` python
import scipy.io as sio
# The dataset can be downloaded here: https://archive.ics.uci.edu/ml/datasets/Character+Trajectories
# Loading and preprocessing of the dataset
trajectories = sio.loadmat('data/mixoutALL_shifted.mat')['mixout'][0]
trajectories = [trajectory[:, np.sum(np.abs(trajectory), 0) > 1e-3] for trajectory in trajectories]
trajectories = [np.cumsum(trajectory, axis=-1) for trajectory in trajectories]
# Normalize dataset trajectory length
traj_len = 60
normalized_trajectories = np.zeros((len(trajectories), 2, traj_len))
for i, traj in enumerate(trajectories):
    tlen = traj.shape[1]
    for t in range(traj_len):
        normalized_trajectories[i, :, t] = traj[:2, int(t*tlen/traj_len)]
# Rescale the trajectories
trajectories = normalized_trajectories/10
```
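%% Cell type:markdown id: tags:
As a quick sanity check (a sketch, not in the original notebook), one trajectory per class can be plotted after preprocessing; the indices 0, 97 and 170 correspond to the class boundaries defined in the next cell:
%% Cell type:code id: tags:
``` python
# Plot the first preprocessed trajectory of each of the first three classes (a, b, c)
for idx in [0, 97, 170]:
    plt.plot(trajectories[idx, 0], trajectories[idx, 1])
plt.axis('equal')
plt.show()
```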
%% Cell type:code id: tags:
``` python
# Index ranges corresponding to the first three classes (a, b, c)
labels_range = np.zeros((3, 2))
labels_range[0] = np.array([0, 97])
labels_range[1] = np.array([97, 170])
labels_range[2] = np.array([170, 225])
```
%% Cell type:markdown id: tags:
## 3. Train the visual prediction RNN
%% Cell type:code id: tags:
``` python
# Number of training iterations
iterations = 2000
# Number of trajectory classes
p = 3
# Dimension of the RNN hidden state
states_dim = 100
batch_size = p * 20
# Select 20 trajectories per class for training (20 others per class are kept for testing)
traj = torch.cat([
    torch.Tensor(trajectories[int(labels_range[k][0]):int(labels_range[k][0])+20])
    for k in range(p)
]).transpose(1, 2).transpose(0, 1)
# Initialize the RNN
rnn = RNN(states_dim=states_dim, causes_dim=p, output_dim=2, factor_dim=states_dim//2, tau=7)
# Initial hidden causes and hidden state of the RNN
c_init = torch.eye(p)
h_init = torch.randn(1, rnn.states_dim).repeat(p, 1)
c_init = c_init.unsqueeze(1).repeat(1, 20, 1).reshape(batch_size, p)
h_init = h_init.unsqueeze(1).repeat(1, 20, 1).reshape(batch_size, rnn.states_dim)
# Store the prediction errors throughout training
errors = np.zeros(iterations)
# Train the network
for i in tqdm(range(iterations)):
    # Learning rates
    lr_o = 0.1/(2**(i//1000))
    lr_r = 3
    # Forward (prediction and inference) pass through the RNN
    c = c_init.clone()
    h = h_init.clone()
    rnn.forward(traj, c, h, lr_c=0.0, lr_h=0.001)
    # Learning
    rnn.learn(lr_o, lr_r)
    # Store the prediction error
    errors[i] = torch.mean(rnn.error**2).item()
```
%%%% Output: display_data
%%%% Output: stream
%% Cell type:code id: tags:
``` python
plt.plot(errors)
plt.yscale('log')
plt.show()
```
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
rnn.forward(torch.zeros(60, batch_size, 2), c_init.clone(), h_init.clone(), lr_c=0.0, lr_h=0.0)
for k in range(p):
    plt.figure()
    plt.plot(rnn.x_pred[:, k*20, 0], rnn.x_pred[:, k*20, 1])
    plt.show()
```
%%%% Output: display_data
%%%% Output: display_data
%%%% Output: display_data
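%% Cell type:markdown id: tags:
The 20 remaining trajectories of each class are held out for testing. A possible check (a sketch, not in the original notebook) is to infer the hidden causes of held-out sequences by starting from uniform causes and running the forward pass with a non-zero cause learning rate:
%% Cell type:code id: tags:
``` python
# Held-out trajectories: the 20 per class following the training ones
test_traj = torch.cat([
    torch.Tensor(trajectories[int(labels_range[k][0])+20:int(labels_range[k][0])+40])
    for k in range(p)
]).transpose(1, 2).transpose(0, 1)
# Start from uniform causes and let inference sharpen them
c_test = torch.ones(test_traj.shape[1], p) / p
rnn.forward(test_traj, c_test, h_init, lr_c=0.05, lr_h=0.001)
# The inferred causes at the end of the sequence hint at the recognized class
print(rnn.c[-1].argmax(dim=1))
```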
%% Cell type:markdown id: tags:
## 4. AIF control model
%% Cell type:code id: tags:
``` python
def forward_model(x):
    """
    Forward model mapping motor commands (joint angles) to the resulting visual
    observation (the position of the arm's end effector)
    Parameters:
    - x : joint angles, Tensor of shape (seq_len, batch_size, joints)
    Returns : Tensor of shape (seq_len, batch_size, 2) corresponding to the trajectory in Euclidean coordinates
    """
    x = x.clone()
    lens = [6, 4, 2]
    seq_len = x.shape[0]
    batch_size = x.shape[1]
    joints = x.shape[2]
    # Base of the arm at (-6, -6)
    pos = torch.zeros(seq_len, batch_size, 2) - 6
    # Rest angles (0, pi/2, 0), commands bounded to +/- pi/4 around them
    angles = torch.Tensor([0, np.pi/2, 0]).unsqueeze(0).unsqueeze(0).repeat(seq_len, batch_size, 1) + 0.25*np.pi*torch.tanh(x)
    angle = 0
    for j in range(joints):
        pos[:, :, 0] += lens[j] * torch.cos(angle + angles[:, :, j])
        pos[:, :, 1] += lens[j] * torch.sin(angle + angles[:, :, j])
        angle += angles[:, :, j]
    return pos
```
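%% Cell type:markdown id: tags:
With a zero motor command, the joint angles take their rest values $(0, \pi/2, 0)$: starting from the base at $(-6, -6)$, the segments of lengths 6, 4 and 2 place the end effector at the origin. A quick check (a sketch, not in the original notebook):
%% Cell type:code id: tags:
``` python
# A zero command should place the end effector at (0, 0)
print(forward_model(torch.zeros(1, 1, 3)))  # expected: tensor([[[0., 0.]]]) up to rounding
```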
%% Cell type:code id: tags:
``` python
class Controller(object):
    """
    Controller class connecting the two RNN generative models
    """
    def __init__(self, lr, prnn, mrnn, batch_size, threshold):
        self.batch_size = batch_size
        # Perceptual network
        self.prnn = prnn
        self.c_p = None
        self.h_p = None
        self.c_p_init = None
        self.h_p_init = None
        self.sensory_dim = prnn.output_dim
        # Motor network
        self.mrnn = mrnn
        self.c_m = None
        self.h_m = None
        self.c_m_init = None
        self.h_m_init = None
        self.motor_dim = mrnn.output_dim
        # Sensory prediction
        self.mu = None
        # Learning parameters
        self.lr = lr
        self.optimizer = None
        # Threshold for intermittent control
        self.threshold = threshold

    def prediction_error(self, m):
        """
        Computes the prediction error between the sensory target mu and the observation
        resulting from the motor command m
        Parameters:
        - m : Tensor of shape (batch_size, motor_dim)
        Returns : scalar, mean squared prediction error
        """
        o_m = forward_model(m.unsqueeze(0))[0]
        return torch.mean((o_m-self.mu)**2)

    def step(self, lr=0.1):
        """
        Performs one step of control
        Parameters:
        - lr : double, learning rate used in the motor hidden state update
        Returns:
        - control : boolean, whether the output was controlled at this timestep
        - loss : scalar Tensor, the error between the target and predicted outcome
        - m_target : Tensor of shape (batch_size, 3), the motor target obtained through AIF
        - m_prior : Tensor of shape (batch_size, 3), the motor output predicted at time t
        - m_post : Tensor of shape (batch_size, 3), the motor output that would be predicted
          with the posterior hidden state
        """
        # MRNN prediction
        self.mrnn.forward(
            torch.zeros(1, self.batch_size, self.motor_dim),
            self.c_m,
            self.h_m,
            0,
            0
        )
        m_prior = self.mrnn.x_pred[0]
        # Loss prediction
        loss = self.prediction_error(m_prior)
        # Control
        if loss.item() > self.threshold:
            control = True
            # We compute the gradient on the output level
            m_target = torch.nn.Parameter(m_prior.clone(), requires_grad=True)
            self.optimizer = torch.optim.SGD([m_target], lr=self.lr)
            self.optimizer.zero_grad()
            loss = self.prediction_error(m_target)
            loss.backward()
            self.optimizer.step()
        else:
            control = False
            m_target = m_prior.clone()
        # MRNN state update with the controlled value
        self.mrnn.forward(
            m_target.reshape(1, self.batch_size, self.mrnn.output_dim),
            self.c_m,
            self.h_m,
            0,
            lr
        )
        self.h_m = self.mrnn.h_post[-1]
        # PRNN states and sensory prediction update
        self.update_perceptual_state(forward_model(m_prior.detach().unsqueeze(0))[0])
        # Posterior motor prediction
        m_post = torch.mm(torch.tanh(self.h_m), self.mrnn.w_o)
        return control, loss, m_target, m_prior, m_post

    def reset(self):
        """
        Resets the motor and perceptual states to initiate a new trajectory
        """
        self.c_p = self.c_p_init
        self.h_p = self.h_p_init
        self.c_m = self.c_m_init
        self.h_m = self.h_m_init

    def update_perceptual_state(self, o, lr_c=0, lr_h=0):
        """
        Updates the perceptual states based on the observation
        Parameters:
        - o : Tensor of shape (batch_size, sensory_dim), the observation resulting from the motor output
        - lr_c : double, learning rate for the hidden causes of the PRNN
        - lr_h : double, learning rate for the hidden states of the PRNN
        """
        o = o.unsqueeze(0)
        self.prnn.forward(o, self.c_p, self.h_p, lr_c=lr_c, lr_h=lr_h)
        self.c_p = self.prnn.c[-1]
        self.h_p = self.prnn.h_post[-1]
        # Update the sensory prediction
        self.update_sensory_prediction()

    def update_sensory_prediction(self):
        """
        Updates the sensory prediction made by the perceptual network
        """
        self.prnn.forward(torch.zeros(1, self.batch_size, self.sensory_dim), self.c_p, self.h_p, lr_c=0, lr_h=0)
        self.mu = self.prnn.x_pred[-1]
```
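%% Cell type:markdown id: tags:
In equation form, whenever the prediction loss exceeds the threshold, `step` refines the motor command with one gradient step on the extrinsic error (restating the SGD update above, with $f$ the arm forward model, $\mu$ the sensory prediction of the perceptual RNN and $\eta$ the controller learning rate):

$$m_{target} = m_{prior} - \eta\, \nabla_m\, \mathrm{MSE}\big(f(m), \mu\big)\Big|_{m = m_{prior}}$$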
%% Cell type:markdown id: tags:
## 5. Training the motor RNN
%% Cell type:code id: tags:
``` python
# Parameters
iterations = 10000
# The perception RNN trained previously
prnn, c_p_init, h_p_init = rnn, c_init, h_init
# Declare the motor RNN
mrnn = RNN(states_dim=states_dim, causes_dim=p, output_dim=3, factor_dim=states_dim//2, tau=7)
# Declare the controller
controller = Controller(lr=5, prnn=prnn, mrnn=mrnn, batch_size=batch_size, threshold=0.0)
# Initialize the RNNs' hidden states and causes
c_m_init = torch.eye(p)
h_m_init = torch.randn(1, mrnn.states_dim).repeat(p, 1)
c_m_init = c_m_init.unsqueeze(1).repeat(1, 20, 1).reshape(batch_size, p)
h_m_init = h_m_init.unsqueeze(1).repeat(1, 20, 1).reshape(batch_size, mrnn.states_dim)
controller.c_p_init = c_p_init
controller.h_p_init = h_p_init
controller.c_m_init = c_m_init
controller.h_m_init = h_m_init
# Store the motor RNN errors throughout training
errors = np.zeros(iterations)
for i in tqdm(range(iterations)):
    # Reset the motor and perception RNNs
    controller.reset()
    controller.update_sensory_prediction()
    # Save the target trajectory for learning
    target_motor_trajectory = torch.Tensor(traj_len, batch_size, 3)
    for t in range(traj_len):
        # Controller step
        control, loss, m_target, m_prior, m_post = controller.step(lr=0.0001)
        # Save outputs
        target_motor_trajectory[t] = m_target.detach()
    # Learning on the trajectory
    controller.mrnn.forward(target_motor_trajectory, controller.c_m_init, controller.h_m_init, lr_c=0., lr_h=0.0001)
    errors[i] = torch.mean(controller.mrnn.error**2).item()
    controller.mrnn.learn(0.3, 100)
```
%%%% Output: display_data
%%%% Output: stream
%% Cell type:code id: tags:
``` python
plt.plot(errors)
plt.yscale('log')
plt.show()
```
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
controller.mrnn.forward(target_motor_trajectory, controller.c_m_init, controller.h_m_init, 0., 0.)
visual_trajectory = forward_model(controller.mrnn.x_pred)
for k in range(p):
    plt.figure()
    plt.plot(visual_trajectory[:, k*20, 0], visual_trajectory[:, k*20, 1])
    plt.show()
```
%%%% Output: display_data
%%%% Output: display_data
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
```