|
import numpy as np |
|
import torch |
|
|
|
|
|
class Agent:
    """Epsilon-greedy action selector built on top of a Q-value network.

    With probability ``epsilon`` the agent samples an action from
    ``action_space``; otherwise it returns the index of the largest value
    produced by ``net`` for the given state.

    ``epsilon`` starts at ``0.`` (purely greedy) and only changes when
    :meth:`update_epsilon` is called.
    """

    def __init__(self,
                 net,
                 action_space=None,
                 exploration_initial_eps=None,
                 exploration_decay=None,
                 exploration_final_eps=None):
        """Store the network and the exploration schedule parameters.

        Args:
            net: Callable module exposing ``eval()``; called with a batched
                state tensor and expected to return per-action values along
                dimension 1.
            action_space: Object exposing ``sample()``; required only for
                random (exploratory) actions.
            exploration_initial_eps: Epsilon at step 0 of the schedule.
            exploration_decay: Per-step multiplicative decay base.
            exploration_final_eps: Lower bound of the schedule.
        """
        self.net = net
        self.action_space = action_space
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_decay = exploration_decay
        self.exploration_final_eps = exploration_final_eps
        # Greedy until update_epsilon() is called for the first time.
        self.epsilon = 0.

    def __call__(self, state, device=torch.device('cpu')):
        """Return an epsilon-greedy action for ``state``."""
        if np.random.random() < self.epsilon:
            return self.get_random_action()
        return self.get_action(state, device)

    def get_random_action(self):
        """Return an action sampled from ``self.action_space``."""
        return self.action_space.sample()

    def get_action(self, state, device=torch.device('cpu')):
        """Return the greedy action (arg-max of the network output).

        Args:
            state: A ``torch.Tensor`` that is already batched, or a single
                unbatched observation (NumPy array, list, scalar, ...) which
                gets a leading batch dimension of size 1.
            device: ``torch.device`` or device string. The state is moved
                only when the device is not the CPU, so a tensor already
                living on an accelerator is left in place when the default
                is used.

        Returns:
            The selected action index as a Python ``int``.
        """
        # Accept both torch.device objects and strings such as 'cuda:0'.
        device = torch.device(device)

        if isinstance(state, np.ndarray):
            # Convert the array directly; wrapping it in a list first sends
            # PyTorch down its slow list-of-ndarrays conversion path.
            state = torch.tensor(state).unsqueeze(0)
        elif not isinstance(state, torch.Tensor):
            state = torch.tensor([state])

        if device.type != 'cpu':
            # .to() works for any accelerator type, not just CUDA.
            state = state.to(device)

        # Action selection never needs gradients; skip graph construction.
        # NOTE: eval() switches the network to evaluation mode and does not
        # switch it back, same as before.
        with torch.no_grad():
            q_values = self.net.eval()(state)

        _, action = torch.max(q_values, dim=1)
        return int(action.item())

    def update_epsilon(self, step):
        """Set and return epsilon for the given training ``step``.

        Computes ``final + (initial - final) * decay ** step`` and clamps it
        from below at ``exploration_final_eps``.
        """
        eps_span = self.exploration_initial_eps - self.exploration_final_eps
        decayed = (self.exploration_final_eps +
                   eps_span * self.exploration_decay**step)
        self.epsilon = max(self.exploration_final_eps, decayed)
        return self.epsilon
|
|