State-Space Modeling for Reinforcement Learning: An Introduction to RSSM with a PyTorch Implementation
Recurrent State Space Models (RSSM) were first introduced by Danijar Hafner et al. in the paper "Learning Latent Dynamics for Planning from Pixels". The model plays a key role in modern model-based reinforcement learning (MBRL), where the main goal is to build a reliable predictive model of the environment dynamics. With such a learned model, an agent can simulate future trajectories and plan its behavior ahead of time.

In the rest of this article we walk through RSSM with a hands-on example.
Environment Setup
Setting up the environment is the first step of the implementation. We use the familiar Gym API and design several modular wrappers that initialize parameters and convert observations into the required format.
InitialWrapper allows a given number of observations to be taken without executing any action, and it can also repeat the same action several times before returning an observation. This is particularly useful for environments that respond with a noticeable delay.
The PreprocessFrame wrapper converts observations to the correct data type (numpy arrays in this article) and optionally converts them to grayscale.
```python
from typing import Sequence, Tuple

import cv2 as cv
import gym
import numpy as np
import torch
from gym.core import ActType, ObsType


class InitialWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, no_ops: int = 0, repeat: int = 1):
        super(InitialWrapper, self).__init__(env)
        self.repeat = repeat
        self.no_ops = no_ops
        self.op_counter = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        # Execute "no-op" actions before the agent is allowed to act
        if self.op_counter < self.no_ops:
            obs, reward, terminated, truncated, info = self.env.step(0)
            self.op_counter += 1

        # Repeat the chosen action and accumulate the reward
        total_reward = 0.0
        terminated, truncated = False, False
        for _ in range(self.repeat):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info


class PreprocessFrame(gym.ObservationWrapper):
    def __init__(self, env: gym.Env, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = False):
        super(PreprocessFrame, self).__init__(env)
        self.shape = new_shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=self.shape, dtype=np.float32)
        self.grayscale = grayscale
        if self.grayscale:
            self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(*self.shape[:-1], 1), dtype=np.float32)

    def observation(self, obs: np.ndarray) -> torch.Tensor:
        # Resize, optionally convert to grayscale, and scale pixel values to [0, 1]
        obs = obs.astype(np.uint8)
        new_frame = cv.resize(obs, self.shape[:-1], interpolation=cv.INTER_AREA)
        if self.grayscale:
            new_frame = cv.cvtColor(new_frame, cv.COLOR_RGB2GRAY)
            new_frame = np.expand_dims(new_frame, -1)
        torch_frame = torch.from_numpy(new_frame).float()
        torch_frame = torch_frame / 255.0
        return torch_frame


def make_env(env_name: str, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = True, **kwargs):
    env = gym.make(env_name, **kwargs)
    env = PreprocessFrame(env, new_shape, grayscale=grayscale)
    return env
```
The make_env function creates an environment instance with the specified configuration.
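A quick sanity check of the wrapper stack might look like the following (a sketch, assuming gym >= 0.26 where reset returns an (obs, info) tuple, and the discrete CarRacing-v2 setup used later in this article):

```python
# Build a grayscale 128x128 CarRacing environment and inspect a single observation
env = make_env("CarRacing-v2", new_shape=(128, 128, 3), grayscale=True,
               render_mode="rgb_array", continuous=False)

obs, info = env.reset()
print(env.observation_space.shape)  # (128, 128, 1)
print(obs.shape, obs.dtype)         # torch.Size([128, 128, 1]) torch.float32
print(env.action_space)             # Discrete(5)
```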
Model Architecture
The RSSM implementation relies on several key components. Concretely, we need the following four core modules:
- Encoder for the raw observations
- Dynamics model: captures the temporal dependencies of the encoded observations through a deterministic state h and a stochastic state s
- Decoder: maps the stochastic and deterministic states back to the original observation space
- Reward model: maps the stochastic and deterministic states to a reward value

Structure of the RSSM components. The model maintains a stochastic state s and a deterministic state h.
Encoder Implementation
The encoder is a simple convolutional neural network (CNN) that compresses the input image into a one-dimensional embedding. BatchNorm layers are used to improve training stability.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderCNN(nn.Module):
    def __init__(self, in_channels: int, embedding_dim: int = 2048, input_shape: Tuple[int, int] = (128, 128)):
        super(EncoderCNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(self._compute_conv_output((in_channels, input_shape[0], input_shape[1])), embedding_dim)
        # Batch normalization layers
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)

    def _compute_conv_output(self, shape: Tuple[int, int, int]):
        # Infer the flattened size of the conv stack with a dummy forward pass
        with torch.no_grad():
            x = torch.randn(1, shape[0], shape[1], shape[2])
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
            x = self.conv4(x)
        return x.shape[1] * x.shape[2] * x.shape[3]

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.bn1(x)
        x = torch.relu(self.conv2(x))
        x = self.bn2(x)
        x = torch.relu(self.conv3(x))
        x = self.bn3(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x
```
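A quick shape check (a sketch, assuming the grayscale 128x128 observations produced by make_env above): four stride-2 convolutions reduce 128x128 to 8x8, so the flattened feature has 256 * 8 * 8 = 16384 entries before the final linear projection.

```python
encoder = EncoderCNN(in_channels=1, embedding_dim=1024)

# Fake batch of 16 grayscale frames in (N, C, H, W) layout
frames = torch.randn(16, 1, 128, 128)
embedding = encoder(frames)
print(embedding.shape)  # torch.Size([16, 1024])
```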
Decoder Implementation
The decoder follows a conventional autoencoder design: it reconstructs the original observation from the encoded latent state.
```python
class DecoderCNN(nn.Module):
    def __init__(self, hidden_size: int, state_size: int, embedding_size: int,
                 use_bn: bool = True, output_shape: Tuple[int, int, int] = (3, 128, 128)):
        super(DecoderCNN, self).__init__()
        self.output_shape = output_shape
        self.embedding_size = embedding_size
        # Fully connected layers map (h, s) to a spatial feature map
        self.fc1 = nn.Linear(hidden_size + state_size, embedding_size)
        self.fc2 = nn.Linear(embedding_size, 256 * (output_shape[1] // 16) * (output_shape[2] // 16))
        # Transposed convolutions upsample back to the observation resolution
        self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)  # x2
        self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)   # x2
        self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)    # x2
        self.conv4 = nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1)
        # Batch normalization layers
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(32)
        self.use_bn = use_bn

    def forward(self, h: torch.Tensor, s: torch.Tensor):
        x = torch.cat([h, s], dim=-1)
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = x.view(-1, 256, self.output_shape[1] // 16, self.output_shape[2] // 16)
        if self.use_bn:
            x = torch.relu(self.bn1(self.conv1(x)))
            x = torch.relu(self.bn2(self.conv2(x)))
            x = torch.relu(self.bn3(self.conv3(x)))
        else:
            x = torch.relu(self.conv1(x))
            x = torch.relu(self.conv2(x))
            x = torch.relu(self.conv3(x))
        x = self.conv4(x)
        return x
```
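Mirroring the encoder check (again a sketch with hypothetical sizes), the decoder maps a concatenated (h, s) pair back to a full frame:

```python
decoder = DecoderCNN(hidden_size=1024, state_size=512, embedding_size=1024,
                     output_shape=(1, 128, 128))

h = torch.randn(16, 1024)  # deterministic state
s = torch.randn(16, 512)   # stochastic state
recon = decoder(h, s)
print(recon.shape)  # torch.Size([16, 1, 128, 128])
```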
Reward Model Implementation
The reward model is a three-layer feed-forward network that maps the stochastic state s and the deterministic state h to the parameters of a Normal distribution, from which the reward prediction is sampled.
```python
class RewardModel(nn.Module):
    def __init__(self, hidden_dim: int, state_dim: int):
        super(RewardModel, self).__init__()
        self.fc1 = nn.Linear(hidden_dim + state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Two outputs: mean and log-variance of the reward distribution
        self.fc3 = nn.Linear(hidden_dim, 2)

    def forward(self, h: torch.Tensor, s: torch.Tensor):
        x = torch.cat([h, s], dim=-1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
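Downstream (see the trainer later), the two outputs are interpreted as the mean and log-variance of a Normal distribution over the reward. A minimal sketch of that interpretation, reusing the h and s tensors from the decoder check above:

```python
from torch.distributions import Normal

reward_model = RewardModel(hidden_dim=1024, state_dim=512)

params = reward_model(h, s)               # shape (16, 2)
mean, logvar = torch.chunk(params, 2, dim=-1)
reward_dist = Normal(mean, torch.exp(F.softplus(logvar)))
predicted_reward = reward_dist.rsample()  # differentiable sample, shape (16, 1)
```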
Dynamics Model Implementation
The dynamics model is the most complex component of the RSSM. It has to provide both a prior and a posterior transition model:
- Posterior transition model: used when the true observation is available (mainly during training); it approximates the posterior distribution of the stochastic state given the observation and the previous state.
- Prior transition model: approximates the prior state distribution from the previous state alone, without access to the observation. It is used at inference time, when no observation is available.
Both models are parameterized by single-layer feed-forward networks that output the mean and log-variance of their respective Normal distributions, from which the stochastic state s is sampled. The networks are kept deliberately simple here, but they can be swapped for more expressive architectures if needed.
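In the notation of the PlaNet paper, these components correspond to the following factorization, where h is the deterministic state, s the stochastic state, a the action and o the observation (the implementation below additionally feeds the action into the prior head alongside h):

$$
\begin{aligned}
\text{deterministic state:} \quad & h_t = f(h_{t-1}, s_{t-1}, a_{t-1}) \\
\text{prior transition:} \quad & s_t \sim p(s_t \mid h_t) = \mathcal{N}\big(\mu_\theta(h_t),\, \sigma_\theta(h_t)\big) \\
\text{posterior transition:} \quad & s_t \sim q(s_t \mid h_t, o_t) = \mathcal{N}\big(\mu_\phi(h_t, o_t),\, \sigma_\phi(h_t, o_t)\big) \\
\text{observation / reward models:} \quad & o_t \sim p(o_t \mid h_t, s_t), \qquad r_t \sim p(r_t \mid h_t, s_t)
\end{aligned}
$$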
The deterministic state is implemented with a gated recurrent unit (GRU). Its inputs are:
- the previous hidden state,
- the one-hot encoded action,
- the previous stochastic state s (posterior or prior, depending on whether an observation is available).
Together these inputs give the model the action history and the system state it needs. The implementation follows:
```python
class DynamicsModel(nn.Module):
    def __init__(self, hidden_dim: int, action_dim: int, state_dim: int, embedding_dim: int, rnn_layer: int = 1):
        super(DynamicsModel, self).__init__()
        self.hidden_dim = hidden_dim
        # Recurrent core; supports stacking several GRU cells
        self.rnn = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(rnn_layer)])
        # Projection of the (state, action) pair into the GRU input space
        self.project_state_action = nn.Linear(action_dim + state_dim, hidden_dim)
        # Prior network: outputs the parameters of a Normal distribution
        self.prior = nn.Linear(hidden_dim, state_dim * 2)
        self.project_hidden_action = nn.Linear(hidden_dim + action_dim, hidden_dim)
        # Posterior network: outputs the parameters of a Normal distribution
        self.posterior = nn.Linear(hidden_dim, state_dim * 2)
        self.project_hidden_obs = nn.Linear(hidden_dim + embedding_dim, hidden_dim)
        self.state_dim = state_dim
        self.embedding_dim = embedding_dim
        self.act_fn = nn.ReLU()

    def forward(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor,
                obs: torch.Tensor = None, dones: torch.Tensor = None):
        """
        Forward pass of the dynamics model.

        Args:
            prev_hidden: previous RNN hidden state, shape (batch_size, hidden_dim)
            prev_state: previous stochastic state, shape (batch_size, state_dim)
            actions: one-hot encoded action sequence, shape (batch_size, sequence_length, action_dim)
            obs: observation embeddings from the encoder, shape (batch_size, sequence_length, embedding_dim);
                 None when rolling out from the prior alone
            dones: episode termination flags
        """
        B, T, _ = actions.size()
        # Buffers for the unrolled sequences
        hiddens_list = []
        posterior_means_list = []
        posterior_logvars_list = []
        prior_means_list = []
        prior_logvars_list = []
        prior_states_list = []
        posterior_states_list = []
        # Store the initial states
        hiddens_list.append(prev_hidden.unsqueeze(1))
        prior_states_list.append(prev_state.unsqueeze(1))
        posterior_states_list.append(prev_state.unsqueeze(1))
        # Unroll over time
        for t in range(T - 1):
            # Current action, observation embedding and previous stochastic state
            action_t = actions[:, t, :]
            obs_t = obs[:, t, :] if obs is not None else torch.zeros(B, self.embedding_dim, device=actions.device)
            state_t = posterior_states_list[-1][:, 0, :] if obs is not None else prior_states_list[-1][:, 0, :]
            # Reset the state at episode boundaries
            state_t = state_t if dones is None else state_t * (1 - dones[:, t, :])
            hidden_t = hiddens_list[-1][:, 0, :]
            # Combine state and action
            state_action = torch.cat([state_t, action_t], dim=-1)
            state_action = self.act_fn(self.project_state_action(state_action))
            # Update the deterministic state with the RNN
            for i in range(len(self.rnn)):
                hidden_t = self.rnn[i](state_action, hidden_t)
            # Prior distribution
            hidden_action = torch.cat([hidden_t, action_t], dim=-1)
            hidden_action = self.act_fn(self.project_hidden_action(hidden_action))
            prior_params = self.prior(hidden_action)
            prior_mean, prior_logvar = torch.chunk(prior_params, 2, dim=-1)
            # Sample from the prior (reparameterized)
            prior_dist = torch.distributions.Normal(prior_mean, torch.exp(F.softplus(prior_logvar)))
            prior_state_t = prior_dist.rsample()
            # Posterior distribution
            if obs is None:
                posterior_mean = prior_mean
                posterior_logvar = prior_logvar
            else:
                hidden_obs = torch.cat([hidden_t, obs_t], dim=-1)
                hidden_obs = self.act_fn(self.project_hidden_obs(hidden_obs))
                posterior_params = self.posterior(hidden_obs)
                posterior_mean, posterior_logvar = torch.chunk(posterior_params, 2, dim=-1)
            # Sample from the posterior (reparameterized)
            posterior_dist = torch.distributions.Normal(posterior_mean, torch.exp(F.softplus(posterior_logvar)))
            posterior_state_t = posterior_dist.rsample()
            # Store the results of this step
            posterior_means_list.append(posterior_mean.unsqueeze(1))
            posterior_logvars_list.append(posterior_logvar.unsqueeze(1))
            prior_means_list.append(prior_mean.unsqueeze(1))
            prior_logvars_list.append(prior_logvar.unsqueeze(1))
            prior_states_list.append(prior_state_t.unsqueeze(1))
            posterior_states_list.append(posterior_state_t.unsqueeze(1))
            hiddens_list.append(hidden_t.unsqueeze(1))
        # Concatenate along the time dimension
        hiddens = torch.cat(hiddens_list, dim=1)
        prior_states = torch.cat(prior_states_list, dim=1)
        posterior_states = torch.cat(posterior_states_list, dim=1)
        prior_means = torch.cat(prior_means_list, dim=1)
        prior_logvars = torch.cat(prior_logvars_list, dim=1)
        posterior_means = torch.cat(posterior_means_list, dim=1)
        posterior_logvars = torch.cat(posterior_logvars_list, dim=1)
        return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars
```
Note that the observation fed into the dynamics model is not the raw observation but the embedding produced by the encoder. This keeps the computation tractable and improves generalization.
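A quick rollout with random tensors (a sketch with hypothetical sizes) makes the output shapes concrete. Because the loop runs T - 1 steps and the initial state is prepended, the hidden and state sequences have T entries while the distribution parameters have T - 1:

```python
dynamics = DynamicsModel(hidden_dim=1024, action_dim=5, state_dim=512, embedding_dim=1024)

B, T = 8, 10
prev_hidden = torch.zeros(B, 1024)
prev_state = torch.zeros(B, 512)
actions = F.one_hot(torch.randint(0, 5, (B, T)), num_classes=5).float()
obs_embeddings = torch.randn(B, T, 1024)  # encoder outputs, one per time step

out = dynamics(prev_hidden, prev_state, actions, obs=obs_embeddings)
hiddens, prior_states, posterior_states, prior_means, *_ = out
print(hiddens.shape)       # torch.Size([8, 10, 1024])
print(prior_states.shape)  # torch.Size([8, 10, 512])
print(prior_means.shape)   # torch.Size([8, 9, 512])
```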
The Full RSSM
We now assemble the components above into the complete RSSM. Its core is the generate_rollout method, which calls the dynamics model and produces a sequence of latent representations of the environment dynamics. When no previous latent state is available (typically at the start of a trajectory), the method initializes one. The full implementation:
```python
class RSSM:
    def __init__(self,
                 encoder: EncoderCNN,
                 decoder: DecoderCNN,
                 reward_model: RewardModel,
                 dynamics_model: nn.Module,
                 hidden_dim: int,
                 state_dim: int,
                 action_dim: int,
                 embedding_dim: int,
                 device: str = "mps"):
        """
        Recurrent State Space Model (RSSM).

        Args:
            encoder: observation encoder
            decoder: observation reconstruction decoder
            reward_model: reward prediction model
            dynamics_model: state dynamics model
            hidden_dim: dimensionality of the RNN hidden state
            state_dim: dimensionality of the stochastic state
            action_dim: dimensionality of the action space
            embedding_dim: dimensionality of the observation embedding
            device: compute device
        """
        super(RSSM, self).__init__()
        # Model components
        self.dynamics = dynamics_model
        self.encoder = encoder
        self.decoder = decoder
        self.reward_model = reward_model
        # Dimensions
        self.hidden_dim = hidden_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim
        # Move all components to the target device
        self.dynamics.to(device)
        self.encoder.to(device)
        self.decoder.to(device)
        self.reward_model.to(device)

    def generate_rollout(self, actions: torch.Tensor, hiddens: torch.Tensor = None, states: torch.Tensor = None,
                         obs: torch.Tensor = None, dones: torch.Tensor = None):
        """
        Generate a latent rollout.

        Args:
            actions: action sequence
            hiddens: initial hidden state (optional)
            states: initial stochastic state (optional)
            obs: observation sequence (optional)
            dones: termination flags

        Returns:
            The full latent rollout.
        """
        # Initialize the latent state if none is provided
        if hiddens is None:
            hiddens = torch.zeros(actions.size(0), self.hidden_dim).to(actions.device)
        if states is None:
            states = torch.zeros(actions.size(0), self.state_dim).to(actions.device)
        # Unroll the dynamics model
        dynamics_result = self.dynamics(hiddens, states, actions, obs, dones)
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars = dynamics_result
        return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

    def train(self):
        """Switch all components to training mode."""
        self.dynamics.train()
        self.encoder.train()
        self.decoder.train()
        self.reward_model.train()

    def eval(self):
        """Switch all components to evaluation mode."""
        self.dynamics.eval()
        self.encoder.eval()
        self.decoder.eval()
        self.reward_model.eval()

    def encode(self, obs: torch.Tensor):
        """Encode an observation."""
        return self.encoder(obs)

    def decode(self, h: torch.Tensor, s: torch.Tensor):
        """Decode a latent state back into an observation."""
        return self.decoder(h, s)

    def predict_reward(self, h: torch.Tensor, s: torch.Tensor):
        """Predict the reward from a latent state."""
        return self.reward_model(h, s)

    def parameters(self):
        """Return all trainable parameters."""
        return list(self.dynamics.parameters()) + list(self.encoder.parameters()) + \
            list(self.decoder.parameters()) + list(self.reward_model.parameters())

    def save(self, path: str):
        """Save the model state."""
        torch.save({
            "dynamics": self.dynamics.state_dict(),
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            "reward_model": self.reward_model.state_dict()
        }, path)

    def load(self, path: str):
        """Load the model state."""
        checkpoint = torch.load(path)
        self.dynamics.load_state_dict(checkpoint["dynamics"])
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])
        self.reward_model.load_state_dict(checkpoint["reward_model"])
```
This provides a complete RSSM framework with training and evaluation switching as well as checkpoint saving and loading. It can serve as a base structure to be extended and tuned for specific applications.
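Because generate_rollout falls back to the prior whenever obs is None, the same method can be used for pure imagination: predicting latent trajectories for a sequence of candidate actions without any observations. A sketch, reusing the components and hypothetical sizes from the earlier snippets:

```python
rssm = RSSM(encoder, decoder, reward_model, dynamics,
            hidden_dim=1024, state_dim=512, action_dim=5, embedding_dim=1024,
            device="cpu")

candidate_actions = F.one_hot(torch.randint(0, 5, (8, 10)), num_classes=5).float()
hiddens, prior_states, *_ = rssm.generate_rollout(candidate_actions)  # obs=None -> prior only
imagined_rewards = rssm.predict_reward(hiddens, prior_states)
print(imagined_rewards.shape)  # torch.Size([8, 10, 2]): mean and log-variance per step
```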
Training System
The RSSM training system consists of two core components: an experience replay buffer and an agent. The buffer stores past experience for training, while the agent acts as the interface between the environment and the RSSM and implements the data collection policy.
Experience Replay Buffer
The buffer is a circular array that stores observations, actions, rewards and termination flags. The sample method draws random training sequences from it.
```python
class Buffer:
    def __init__(self, buffer_size: int, obs_shape: tuple, action_shape: tuple, device: torch.device):
        """
        Experience replay buffer.

        Args:
            buffer_size: capacity of the buffer
            obs_shape: shape of a single observation
            action_shape: shape of a single action
            device: compute device
        """
        self.buffer_size = buffer_size
        self.obs_buffer = np.zeros((buffer_size, *obs_shape), dtype=np.float32)
        self.action_buffer = np.zeros((buffer_size, *action_shape), dtype=np.int32)
        self.reward_buffer = np.zeros((buffer_size, 1), dtype=np.float32)
        self.done_buffer = np.zeros((buffer_size, 1), dtype=np.bool_)
        self.device = device
        self.idx = 0

    def add(self, obs: torch.Tensor, action: int, reward: float, done: bool):
        """Add a single transition; the write index wraps around when the buffer is full."""
        self.obs_buffer[self.idx] = obs
        self.action_buffer[self.idx] = action
        self.reward_buffer[self.idx] = reward
        self.done_buffer[self.idx] = done
        self.idx = (self.idx + 1) % self.buffer_size

    def sample(self, batch_size: int, sequence_length: int):
        """
        Sample random experience sequences.

        Args:
            batch_size: number of sequences
            sequence_length: length of each sequence

        Returns:
            A tuple (observations, actions, rewards, dones).
        """
        # Random starting positions (assumes the buffer has not wrapped around yet)
        starting_idxs = np.random.randint(0, (self.idx % self.buffer_size) - sequence_length, (batch_size,))
        # Build the full index matrix for each sequence
        index_tensor = np.stack([np.arange(start, start + sequence_length) for start in starting_idxs])
        # Gather the sequences
        obs_sequence = self.obs_buffer[index_tensor]
        action_sequence = self.action_buffer[index_tensor]
        reward_sequence = self.reward_buffer[index_tensor]
        done_sequence = self.done_buffer[index_tensor]
        return obs_sequence, action_sequence, reward_sequence, done_sequence

    def save(self, path: str):
        """Save the buffer contents."""
        np.savez(path, obs_buffer=self.obs_buffer, action_buffer=self.action_buffer,
                 reward_buffer=self.reward_buffer, done_buffer=self.done_buffer, idx=self.idx)

    def load(self, path: str):
        """Load the buffer contents."""
        data = np.load(path)
        self.obs_buffer = data["obs_buffer"]
        self.action_buffer = data["action_buffer"]
        self.reward_buffer = data["reward_buffer"]
        self.done_buffer = data["done_buffer"]
        self.idx = data["idx"]
```
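A short usage sketch of the buffer with dummy data (hypothetical sizes matching the CarRacing setup used later):

```python
buffer = Buffer(buffer_size=1000, obs_shape=(128, 128, 1), action_shape=(), device="cpu")

# Fill with dummy transitions, then sample a batch of sequences
for _ in range(200):
    buffer.add(np.zeros((128, 128, 1), dtype=np.float32), action=0, reward=0.0, done=False)

obs_seq, act_seq, rew_seq, done_seq = buffer.sample(batch_size=4, sequence_length=20)
print(obs_seq.shape, act_seq.shape, rew_seq.shape, done_seq.shape)
# (4, 20, 128, 128, 1) (4, 20) (4, 20, 1) (4, 20, 1)
```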
Agent Design
The agent implements data collection and planning. The current implementation collects data with a simple random policy, but the framework makes it easy to plug in more sophisticated strategies.
```python
from abc import ABC, abstractmethod

from gym import Env
from tqdm import tqdm


class Policy(ABC):
    """Base class for rollout policies."""
    @abstractmethod
    def __call__(self, obs):
        pass


class RandomPolicy(Policy):
    """Uniform random sampling from the action space."""
    def __init__(self, env: Env):
        self.env = env

    def __call__(self, obs):
        return self.env.action_space.sample()


class Agent:
    def __init__(self, env: Env, rssm: RSSM, buffer_size: int = 100000,
                 collection_policy: str = "random", device="mps"):
        """
        Agent wrapping the environment, the RSSM and the replay buffer.

        Args:
            env: environment instance
            rssm: RSSM model instance
            buffer_size: capacity of the experience buffer
            collection_policy: type of the data collection policy
            device: compute device
        """
        self.env = env
        # Select the data collection policy
        match collection_policy:
            case "random":
                self.rollout_policy = RandomPolicy(env)
            case _:
                raise ValueError("Invalid rollout policy")
        self.buffer = Buffer(buffer_size, env.observation_space.shape,
                             env.action_space.shape, device=device)
        self.rssm = rssm

    def data_collection_action(self, obs):
        """Choose an action for data collection."""
        return self.rollout_policy(obs)

    def collect_data(self, num_steps: int):
        """
        Collect training data with the rollout policy.

        Args:
            num_steps: number of environment steps to collect
        """
        obs, _ = self.env.reset()
        iterator = tqdm(range(num_steps), desc="Data Collection")
        for _ in iterator:
            action = self.data_collection_action(obs)
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            self.buffer.add(next_obs, action, reward, done)
            obs = next_obs
            if done:
                obs, _ = self.env.reset()

    def imagine_rollout(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor,
                        actions: torch.Tensor):
        """
        Imagine a rollout in latent space.

        Args:
            prev_hidden: previous hidden state
            prev_state: previous stochastic state
            actions: action sequence

        Returns:
            The full model output: hidden states, prior states, posterior states, their
            distribution parameters, and the predicted rewards.
        """
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
            posterior_means, posterior_logvars = self.rssm.generate_rollout(
                actions, prev_hidden, prev_state)
        # During imagination the reward is predicted from the prior states
        rewards = self.rssm.predict_reward(hiddens, prior_states)
        return hiddens, prior_states, posterior_states, prior_means, \
            prior_logvars, posterior_means, posterior_logvars, rewards

    def plan(self, num_steps: int, prev_hidden: torch.Tensor,
             prev_state: torch.Tensor, actions: torch.Tensor):
        """
        Plan by repeatedly imagining rollouts.

        Args:
            num_steps: number of planning steps
            prev_hidden: initial hidden state
            prev_state: initial stochastic state
            actions: action sequence

        Returns:
            The sequences of hidden states and prior states produced during planning.
        """
        hidden_states = []
        prior_states = []
        hiddens = prev_hidden
        states = prev_state
        for _ in range(num_steps):
            hiddens, states, _, _, _, _, _, _ = self.imagine_rollout(
                hiddens, states, actions)
            hidden_states.append(hiddens)
            prior_states.append(states)
        hidden_states = torch.stack(hidden_states)
        prior_states = torch.stack(prior_states)
        return hidden_states, prior_states
```
This gives us a complete data management and interaction framework: the replay buffer stores and reuses past experience efficiently, the abstract policy interface makes it easy to swap in different collection strategies, and the agent additionally implements model-based imagination and planning, which later decision making can build on.
Trainer and Experiments
Trainer Design
The trainer is the last key component of the RSSM implementation: it orchestrates the training process. It receives the RSSM, the agent and the optimizer, and implements the actual training logic.
```python
import logging
import os

import matplotlib.pyplot as plt
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),                       # console output
        logging.FileHandler("training.log", mode="w")  # file output
    ]
)
logger = logging.getLogger(__name__)


class Trainer:
    def __init__(self, rssm: RSSM, agent: Agent, optimizer: torch.optim.Optimizer,
                 device: torch.device):
        """
        Trainer for the RSSM.

        Args:
            rssm: RSSM model instance
            agent: agent instance
            optimizer: optimizer instance
            device: compute device
        """
        self.rssm = rssm
        self.optimizer = optimizer
        self.device = device
        self.agent = agent
        self.writer = SummaryWriter()  # tensorboard logger

    def train_batch(self, batch_size: int, seq_len: int, iteration: int,
                    save_images: bool = False):
        """
        Train on a single batch.

        Args:
            batch_size: batch size
            seq_len: sequence length
            iteration: current iteration
            save_images: whether to save reconstructed images
        """
        # Sample training data
        obs, actions, rewards, dones = self.agent.buffer.sample(batch_size, seq_len)
        # Preprocess the batch
        actions = torch.tensor(actions).long().to(self.device)
        actions = F.one_hot(actions, self.rssm.action_dim).float()
        obs = torch.tensor(obs, requires_grad=True).float().to(self.device)
        rewards = torch.tensor(rewards, requires_grad=True).float().to(self.device)
        dones = torch.tensor(dones).float().to(self.device)
        # Encode the observations
        encoded_obs = self.rssm.encoder(obs.reshape(-1, *obs.shape[2:]).permute(0, 3, 1, 2))
        encoded_obs = encoded_obs.reshape(batch_size, seq_len, -1)
        # Unroll the RSSM
        rollout = self.rssm.generate_rollout(actions, obs=encoded_obs, dones=dones)
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
            posterior_means, posterior_logvars = rollout
        # Reconstruct the observations
        hiddens_reshaped = hiddens.reshape(batch_size * seq_len, -1)
        posterior_states_reshaped = posterior_states.reshape(batch_size * seq_len, -1)
        decoded_obs = self.rssm.decoder(hiddens_reshaped, posterior_states_reshaped)
        decoded_obs = decoded_obs.reshape(batch_size, seq_len, *obs.shape[-3:])
        # Predict the rewards
        reward_params = self.rssm.reward_model(hiddens, posterior_states)
        mean, logvar = torch.chunk(reward_params, 2, dim=-1)
        logvar = F.softplus(logvar)
        reward_dist = Normal(mean, torch.exp(logvar))
        predicted_rewards = reward_dist.rsample()
        # Visualization
        if save_images:
            batch_idx = np.random.randint(0, batch_size)
            seq_idx = np.random.randint(0, seq_len - 3)
            fig = self._visualize(obs, decoded_obs, rewards, predicted_rewards,
                                  batch_idx, seq_idx, iteration, grayscale=True)
            if not os.path.exists("reconstructions"):
                os.makedirs("reconstructions")
            fig.savefig(f"reconstructions/iteration_{iteration}.png")
            self.writer.add_figure("Reconstructions", fig, iteration)
            plt.close(fig)
        # Compute the losses
        reconstruction_loss = self._reconstruction_loss(decoded_obs, obs)
        kl_loss = self._kl_loss(prior_means, F.softplus(prior_logvars),
                                posterior_means, F.softplus(posterior_logvars))
        reward_loss = self._reward_loss(rewards, predicted_rewards)
        loss = reconstruction_loss + kl_loss + reward_loss
        # Backpropagation and optimization step
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.rssm.parameters(), 1, norm_type=2)
        self.optimizer.step()
        return loss.item(), reconstruction_loss.item(), kl_loss.item(), reward_loss.item()

    def train(self, iterations: int, batch_size: int, seq_len: int):
        """
        Run the full training loop.

        Args:
            iterations: total number of iterations
            batch_size: batch size
            seq_len: sequence length
        """
        self.rssm.train()
        iterator = tqdm(range(iterations), desc="Training", total=iterations)
        losses = []
        infos = []
        last_loss = float("inf")
        for i in iterator:
            # Train on a single batch
            loss, reconstruction_loss, kl_loss, reward_loss = self.train_batch(
                batch_size, seq_len, i, save_images=i % 100 == 0)
            # Log the training metrics
            self.writer.add_scalar("Loss", loss, i)
            self.writer.add_scalar("Reconstruction Loss", reconstruction_loss, i)
            self.writer.add_scalar("KL Loss", kl_loss, i)
            self.writer.add_scalar("Reward Loss", reward_loss, i)
            # Save the best model so far
            if loss < last_loss:
                self.rssm.save("rssm.pth")
                last_loss = loss
            # Record detailed information
            info = {
                "Loss": loss,
                "Reconstruction Loss": reconstruction_loss,
                "KL Loss": kl_loss,
                "Reward Loss": reward_loss
            }
            losses.append(loss)
            infos.append(info)
            # Periodically print the training status
            if i % 10 == 0:
                logger.info("\n----------------------------")
                logger.info(f"Iteration: {i}")
                logger.info(f"Loss: {loss:.4f}")
                logger.info(f"Running average of last 20 losses: {sum(losses[-20:]) / len(losses[-20:]):.4f}")
                logger.info(f"Reconstruction Loss: {reconstruction_loss:.4f}")
                logger.info(f"KL Loss: {kl_loss:.4f}")
                logger.info(f"Reward Loss: {reward_loss:.4f}")
```
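The trainer calls three loss helpers (_reconstruction_loss, _kl_loss, _reward_loss) and a _visualize method that are not shown in this excerpt. A minimal sketch of the loss helpers, consistent with how they are called above (mean squared error for reconstruction and reward, analytic KL between the diagonal Gaussians; the author's exact implementation may differ):

```python
# Hypothetical Trainer methods, sketched to match the calls in train_batch above.
def _reconstruction_loss(self, decoded_obs: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    # Pixel-wise mean squared error between reconstruction and ground truth
    return F.mse_loss(decoded_obs, obs)

def _reward_loss(self, rewards: torch.Tensor, predicted_rewards: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(predicted_rewards, rewards)

def _kl_loss(self, prior_mean, prior_std_param, posterior_mean, posterior_std_param) -> torch.Tensor:
    # KL( q(s | h, o) || p(s | h) ) between diagonal Gaussians; the std parameters
    # arrive already passed through softplus, matching the sampling code above.
    prior = Normal(prior_mean, torch.exp(prior_std_param))
    posterior = Normal(posterior_mean, torch.exp(posterior_std_param))
    return torch.distributions.kl_divergence(posterior, prior).mean()
```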
Experiment Example
Below is a complete example of training the RSSM in the CarRacing environment:
```python
# Environment setup
env = make_env("CarRacing-v2", render_mode="rgb_array", continuous=False, grayscale=True)

# Model hyperparameters
hidden_size = 1024
embedding_dim = 1024
state_dim = 512

# Instantiate the model components
encoder = EncoderCNN(in_channels=1, embedding_dim=embedding_dim)
decoder = DecoderCNN(hidden_size=hidden_size, state_size=state_dim,
                     embedding_size=embedding_dim, output_shape=(1, 128, 128))
reward_model = RewardModel(hidden_dim=hidden_size, state_dim=state_dim)
dynamics_model = DynamicsModel(hidden_dim=hidden_size, state_dim=state_dim,
                               action_dim=5, embedding_dim=embedding_dim)

# Build the RSSM
rssm = RSSM(dynamics_model=dynamics_model,
            encoder=encoder,
            decoder=decoder,
            reward_model=reward_model,
            hidden_dim=hidden_size,
            state_dim=state_dim,
            action_dim=5,
            embedding_dim=embedding_dim,
            device="cuda")

# Training setup
optimizer = torch.optim.Adam(rssm.parameters(), lr=1e-3)
agent = Agent(env, rssm)
trainer = Trainer(rssm, agent, optimizer=optimizer, device="cuda")

# Data collection and training
agent.collect_data(20000)        # collect 20000 steps of experience
agent.buffer.save("buffer.npz")  # save the experience buffer
trainer.train(10000, 32, 20)     # train for 10000 iterations
```
Summary
This article walked through a complete PyTorch implementation of RSSM. The architecture is more involved than a plain VAE or RNN because it mixes stochastic and deterministic states. Implementing it by hand is a good way to understand the underlying theory and why it is so powerful: an RSSM can recursively generate future latent state trajectories, which is the basis for planning an agent's behavior.
A nice property of this implementation is its moderate computational cost: it can be trained on a single consumer GPU, and with enough time even on a CPU. The work is based on the paper "Learning Latent Dynamics for Planning from Pixels", which laid the foundation for RSSM-style dynamics models. Follow-up work such as "Dream to Control: Learning Behaviors by Latent Imagination" developed the architecture further; those refinements are worth exploring next, as they offer important insight into MBRL methods.