import sys import numpy as np import torch from torch import nn sys.path.insert(0, sys.path[0]+"/../") from typing import ( Any, Dict, List, Optional, Sequence, Tuple, Type, Union, no_type_check, ) import torch.nn as nn from tianshou.utils.net.discrete import NoisyLinear ModuleType = Type[nn.Module] import random from collections import namedtuple, deque from itertools import count import math import torch import torch.optim as optim from transformers import AutoModel, AutoTokenizer import torch.nn.functional as F from tianshou.utils.net.common import ModuleType, Net, MLP def bert_embedding(x, max_length=512, device='cuda'): from transformers import logging logging.set_verbosity_error() model_name = 'bert-base-uncased' tokenizer = AutoTokenizer.from_pretrained(model_name) bert_model = AutoModel.from_pretrained(model_name) text = x if isinstance(text, np.ndarray): text = list(text) tokens = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt') input_ids = tokens['input_ids'] attention_mask = tokens['attention_mask'] with torch.no_grad(): outputs = bert_model(input_ids, attention_mask=attention_mask) embeddings = outputs.last_hidden_state return embeddings class Net_GRU(nn.Module): def __init__(self, input_size, n_actions, hidden_dim, n_layers, dropout, bidirectional): super(Net_GRU, self).__init__() self.input_size = input_size self.hidden_dim = hidden_dim self.num_classes = n_actions self.n_layers = n_layers self.dropout = dropout self.bidirectional = bidirectional # Layers self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers, batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional) self.final_layer = nn.Linear(self.hidden_dim*(1 + int(self.bidirectional)), self.num_classes) def forward(self, x): # Input shape: (batch_size, seq_length) batch_size, seq_length, emb_size = x.size() gru_out, hidden = self.gru(x) # Use the final state # hidden -> (num_direction, batch, hidden_size) if self.bidirectional: hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim) final_hidden = torch.cat((hidden[:, -1, :, :].squeeze(0), hidden[:, 0, :, :].squeeze(0)), 1) else: final_hidden = hidden.squeeze(0) # final_hidden -> (batch_size, num_classes) logits = self.final_layer(final_hidden) return logits class MyGRU(nn.Module): def __init__(self, input_size, hidden_dim, n_layers, dropout, bidirectional, output_dim): super(MyGRU, self).__init__() self.input_size = input_size self.hidden_dim = hidden_dim self.n_layers = n_layers self.dropout = dropout self.bidirectional = bidirectional # Layers self.gru = nn.GRU(self.input_size, self.hidden_dim, self.n_layers, batch_first=True, dropout=self.dropout, bidirectional=self.bidirectional) self.final_layer = nn.Linear(self.hidden_dim*(1 + int(self.bidirectional)), output_dim) def forward(self, x): batch_size, seq_length, emb_size = x.size() gru_out, hidden = self.gru(x) # Use the final state # hidden -> (num_direction, batch, hidden_size) if self.bidirectional: hidden = hidden.view(self.n_layers, 2, batch_size, self.hidden_dim) final_hidden = torch.cat((hidden[:, -1, :, :].squeeze(0), hidden[:, 0, :, :].squeeze(0)), 1) else: final_hidden = hidden.squeeze(0) # final_hidden -> (batch_size, num_classes) logits = self.final_layer(final_hidden) return logits class MyCNN(nn.Module): def __init__(self, input_dim: int, output_dim: int = 0, hidden_sizes: Sequence[int] = (), norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None, activation: ModuleType = nn.ReLU, device: Optional[Union[str, int, torch.device]] = None, linear_layer: Type[nn.Linear] = nn.Linear, flatten_input: bool = True,) -> None: super().__init__() self.model = [] input_dim_temp = input_dim for h in hidden_sizes: self.model.append(nn.Conv1d(in_channels=input_dim_temp, out_channels=h, kernel_size=3, padding=1)) self.model.append(activation()) self.model.append(nn.MaxPool1d(kernel_size=2)) input_dim_temp = h self.model = nn.Sequential(*self.model) self.fc = nn.Linear(in_features=input_dim_temp, out_features=output_dim) def forward(self, x): x = self.model(x.transpose(1, 2)) x.transpose_(1, 2) x = self.fc(x) return x class Net_GRU_Bert_tianshou(Net): def __init__( self, state_shape: Union[int, Sequence[int]], action_shape: Union[int, Sequence[int]] = 0, hidden_sizes: Sequence[int] = (), norm_layer: Optional[ModuleType] = None, activation: Optional[ModuleType] = nn.ReLU, device: Union[str, int, torch.device] = "cpu", softmax: bool = False, concat: bool = False, num_atoms: int = 1, dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, linear_layer: Type[nn.Linear] = nn.Linear, hidden_dim: int = 128, bidirectional: bool = True, dropout: float = 0., n_layers: int = 1, max_length: int = 512, trans_model_name: str = 'bert-base-uncased', ) -> None: nn.Module.__init__(self) self.device = device self.softmax = softmax self.num_atoms = num_atoms self.hidden_dim = hidden_dim self.bidirectional = bidirectional self.dropout = dropout self.n_layers = n_layers self.trans_model_name = trans_model_name self.max_length = max_length input_dim = int(np.prod(state_shape)) action_dim = int(np.prod(action_shape)) * num_atoms if concat: input_dim += action_dim self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 self.output_dim = output_dim or hidden_dim self.model = MyGRU(768, self.hidden_dim, self.n_layers, self.dropout, self.bidirectional, self.output_dim) if self.use_dueling: # dueling DQN q_kwargs, v_kwargs = dueling_param # type: ignore q_output_dim, v_output_dim = 0, 0 if not concat: q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, "output_dim": q_output_dim, "device": self.device } v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, "output_dim": v_output_dim, "device": self.device } self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name) from transformers import logging logging.set_verbosity_error() def bert_embedding(self, x, max_length=512): text = x if isinstance(text, np.ndarray): text = list(text) tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt') input_ids = tokens['input_ids'].to(self.device) attention_mask = tokens['attention_mask'].to(self.device) with torch.no_grad(): outputs = self.bert_model(input_ids, attention_mask=attention_mask) embeddings = outputs.last_hidden_state return embeddings def forward( self, obs: Union[np.ndarray, torch.Tensor], state: Any = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: """Mapping: obs -> flatten (inside MLP)-> logits.""" embedding = self.bert_embedding(obs, max_length=self.max_length) logits = self.model(embedding) bsz = logits.shape[0] if self.use_dueling: # Dueling DQN q, v = self.Q(logits), self.V(logits) if self.num_atoms > 1: q = q.view(bsz, -1, self.num_atoms) v = v.view(bsz, -1, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v elif self.num_atoms > 1: logits = logits.view(bsz, -1, self.num_atoms) if self.softmax: logits = torch.softmax(logits, dim=-1) return logits, state class Net_Bert_CLS_tianshou(Net): def __init__( self, state_shape: Union[int, Sequence[int]], action_shape: Union[int, Sequence[int]] = 0, hidden_sizes: Sequence[int] = (), norm_layer: Optional[ModuleType] = None, activation: Optional[ModuleType] = nn.ReLU, device: Union[str, int, torch.device] = "cpu", softmax: bool = False, concat: bool = False, num_atoms: int = 1, dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, linear_layer: Type[nn.Linear] = nn.Linear, hidden_dim: int = 128, bidirectional: bool = True, dropout: float = 0., n_layers: int = 1, max_length: int = 512, trans_model_name: str = 'bert-base-uncased', ) -> None: nn.Module.__init__(self) self.device = device self.softmax = softmax self.num_atoms = num_atoms self.hidden_dim = hidden_dim self.bidirectional = bidirectional self.dropout = dropout self.n_layers = n_layers self.trans_model_name = trans_model_name self.max_length = max_length input_dim = int(np.prod(state_shape)) action_dim = int(np.prod(action_shape)) * num_atoms if concat: input_dim += action_dim self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 self.output_dim = output_dim or hidden_dim self.model = MLP(768, output_dim, hidden_sizes, norm_layer, activation, device, linear_layer) if self.use_dueling: # dueling DQN q_kwargs, v_kwargs = dueling_param # type: ignore q_output_dim, v_output_dim = 0, 0 if not concat: q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, "output_dim": q_output_dim, "device": self.device } v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, "output_dim": v_output_dim, "device": self.device } self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name) from transformers import logging logging.set_verbosity_error() def bert_CLS_embedding(self, x, max_length=512): text = x if isinstance(text, np.ndarray): text = list(text) tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt') input_ids = tokens['input_ids'].to(self.device) attention_mask = tokens['attention_mask'].to(self.device) with torch.no_grad(): outputs = self.bert_model(input_ids, attention_mask=attention_mask) embeddings = outputs[0][:, 0, :] return embeddings def forward( self, obs: Union[np.ndarray, torch.Tensor], state: Any = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: """Mapping: obs -> flatten (inside MLP)-> logits.""" embedding = self.bert_CLS_embedding(obs, max_length=self.max_length) logits = self.model(embedding) bsz = logits.shape[0] if self.use_dueling: # Dueling DQN q, v = self.Q(logits), self.V(logits) if self.num_atoms > 1: q = q.view(bsz, -1, self.num_atoms) v = v.view(bsz, -1, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v elif self.num_atoms > 1: logits = logits.view(bsz, -1, self.num_atoms) if self.softmax: logits = torch.softmax(logits, dim=-1) return logits, state class Net_Bert_CNN_tianshou(Net_GRU_Bert_tianshou): def __init__( self, state_shape: Union[int, Sequence[int]], action_shape: Union[int, Sequence[int]] = 0, hidden_sizes: Sequence[int] = (), norm_layer: Optional[ModuleType] = None, activation: Optional[ModuleType] = nn.ReLU, device: Union[str, int, torch.device] = "cpu", softmax: bool = False, concat: bool = False, num_atoms: int = 1, dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, linear_layer: Type[nn.Linear] = nn.Linear, hidden_dim: int = 128, bidirectional: bool = True, dropout: float = 0., n_layers: int = 1, max_length: int = 512, trans_model_name: str = 'bert-base-uncased', ) -> None: nn.Module.__init__(self) self.device = device self.softmax = softmax self.num_atoms = num_atoms self.hidden_dim = hidden_dim self.bidirectional = bidirectional self.dropout = dropout self.n_layers = n_layers self.trans_model_name = trans_model_name self.max_length = max_length input_dim = int(np.prod(state_shape)) action_dim = int(np.prod(action_shape)) * num_atoms if concat: input_dim += action_dim self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 self.output_dim = output_dim or hidden_dim self.model = MyCNN(768, output_dim, hidden_sizes, norm_layer, activation, device, linear_layer, flatten_input=False) if self.use_dueling: # dueling DQN q_kwargs, v_kwargs = dueling_param # type: ignore q_output_dim, v_output_dim = 0, 0 if not concat: q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, "output_dim": q_output_dim, "device": self.device } v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, "output_dim": v_output_dim, "device": self.device } self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name) from transformers import logging logging.set_verbosity_error() class DQN_GRU(nn.Module): """Reference: Human-level control through deep reinforcement learning. """ def __init__( self, state_shape: Union[int, Sequence[int]], action_shape: Sequence[int], device: Union[str, int, torch.device] = "cpu", features_only: bool = False, output_dim: Optional[int] = None, hidden_dim: int = 128, n_layers: int = 1, dropout: float = 0., bidirectional: bool = True, trans_model_name: str = 'bert-base-uncased', max_length: int = 512, ) -> None: super().__init__() self.device = device self.max_length = max_length action_dim = int(np.prod(action_shape)) self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional, hidden_dim) if not features_only: self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional, action_dim) self.output_dim = action_dim elif output_dim is not None: self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional, output_dim) self.output_dim = output_dim else: self.net = MyGRU(768, hidden_dim, n_layers, dropout, bidirectional, hidden_dim) self.output_dim = hidden_dim self.trans_model_name = trans_model_name self.bert_model = AutoModel.from_pretrained(self.trans_model_name).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name) from transformers import logging logging.set_verbosity_error() def bert_embedding(self, x, max_length=512): text = x if isinstance(text, np.ndarray): text = list(text) tokens = self.tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt') input_ids = tokens['input_ids'].to(self.device) attention_mask = tokens['attention_mask'].to(self.device) with torch.no_grad(): outputs = self.bert_model(input_ids, attention_mask=attention_mask) embeddings = outputs.last_hidden_state return embeddings def forward( self, obs: Union[np.ndarray, torch.Tensor], state: Optional[Any] = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: s -> Q(s, \*).""" embedding = self.bert_embedding(obs, max_length=self.max_length) return self.net(embedding), state class Rainbow_GRU(DQN_GRU): """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. """ def __init__( self, state_shape: Union[int, Sequence[int]], action_shape: Sequence[int], num_atoms: int = 51, noisy_std: float = 0.5, device: Union[str, int, torch.device] = "cpu", is_dueling: bool = True, is_noisy: bool = True, output_dim: Optional[int] = None, hidden_dim: int = 128, n_layers: int = 1, dropout: float = 0., bidirectional: bool = True, trans_model_name: str = 'bert-base-uncased', max_length: int = 512, ) -> None: super().__init__(state_shape, action_shape, device, features_only=True, output_dim=output_dim, hidden_dim=hidden_dim, n_layers=n_layers, dropout=dropout, bidirectional=bidirectional, trans_model_name=trans_model_name) self.action_num = np.prod(action_shape) self.num_atoms = num_atoms def linear(x, y): if is_noisy: return NoisyLinear(x, y, noisy_std) else: return nn.Linear(x, y) self.Q = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), linear(512, self.action_num * self.num_atoms) ) self._is_dueling = is_dueling if self._is_dueling: self.V = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), linear(512, self.num_atoms) ) self.output_dim = self.action_num * self.num_atoms def forward( self, obs: Union[np.ndarray, torch.Tensor], state: Optional[Any] = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" obs, state = super().forward(obs) q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) if self._is_dueling: v = self.V(obs) v = v.view(-1, 1, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v else: logits = q probs = logits.softmax(dim=2) return probs, state class Net_GRU_nn_emb_tianshou(Net): def __init__( self, action_shape: Union[int, Sequence[int]] = 0, hidden_sizes: Sequence[int] = (), norm_layer: Optional[ModuleType] = None, activation: Optional[ModuleType] = nn.ReLU, device: Union[str, int, torch.device] = "cpu", softmax: bool = False, concat: bool = False, num_atoms: int = 1, dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, linear_layer: Type[nn.Linear] = nn.Linear, hidden_dim: int = 128, bidirectional: bool = True, dropout: float = 0., n_layers: int = 1, max_length: int = 512, trans_model_name: str = 'bert-base-uncased', word_emb_dim: int = 128, ) -> None: nn.Module.__init__(self) self.device = device self.softmax = softmax self.num_atoms = num_atoms self.hidden_dim = hidden_dim self.bidirectional = bidirectional self.dropout = dropout self.n_layers = n_layers self.trans_model_name = trans_model_name self.max_length = max_length action_dim = int(np.prod(action_shape)) * num_atoms self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 self.output_dim = output_dim or hidden_dim self.tokenizer = AutoTokenizer.from_pretrained(trans_model_name) from transformers import logging logging.set_verbosity_error() self.vocab_size = self.tokenizer.vocab_size self.embedding = nn.Embedding(self.vocab_size, word_emb_dim) self.model = MyGRU(word_emb_dim, self.hidden_dim, self.n_layers, self.dropout, self.bidirectional, self.output_dim) if self.use_dueling: # dueling DQN q_kwargs, v_kwargs = dueling_param # type: ignore q_output_dim, v_output_dim = 0, 0 if not concat: q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, "output_dim": q_output_dim, "device": self.device } v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, "output_dim": v_output_dim, "device": self.device } self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim def forward( self, obs: Union[np.ndarray, torch.Tensor], state: Any = None, info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: """Mapping: obs -> flatten (inside MLP)-> logits.""" if isinstance(obs, np.ndarray): text = list(obs) else: text = obs tokens = self.tokenizer(text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt') input_ids = tokens['input_ids'].to(self.device) attention_mask = tokens['attention_mask'].to(self.device) embedding = self.embedding(input_ids) mask = attention_mask.unsqueeze(-1).expand(embedding.size()).float() embedding = embedding * mask logits = self.model(embedding) bsz = logits.shape[0] if self.use_dueling: # Dueling DQN q, v = self.Q(logits), self.V(logits) if self.num_atoms > 1: q = q.view(bsz, -1, self.num_atoms) v = v.view(bsz, -1, self.num_atoms) logits = q - q.mean(dim=1, keepdim=True) + v elif self.num_atoms > 1: logits = logits.view(bsz, -1, self.num_atoms) if self.softmax: logits = torch.softmax(logits, dim=-1) return logits, state