Spaces:
Runtime error
Runtime error
import numpy as np | |
import random | |
from loguru import logger | |
from typing import Dict, Any, List | |
class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy action policy.

    The Q-table is a nested dictionary ``{state_key: {action: q_value}}``
    where ``state_key`` is the string produced by :meth:`serialize_state`.
    State-action pairs that were never written have an implicit Q-value
    of 0.0 and take up no space in the table.
    """

    # Id of the loguru file sink shared by every agent; stays ``None`` until
    # the first agent registers it, so the sink is added at most once.
    _log_sink_id = None

    def __init__(self, learning_rate: float = 0.1,
                 discount_factor: float = 0.9, exploration_rate: float = 0.1):
        """Initialize the Q-Learning Agent.

        Args:
            learning_rate: Step size applied to each temporal-difference
                update in :meth:`update_q_table`.
            discount_factor: Weight given to the best next-state Q-value.
            exploration_rate: Probability that :meth:`choose_action` picks
                a uniformly random action instead of the greedy one.
        """
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_table = {}  # {state_key: {action: q_value}}
        self.setup_logger()

    def setup_logger(self):
        """Configure logging for the agent.

        The file sink is registered only once per process: adding it again
        for every new agent would make each message appear in the log file
        once per agent instance.
        """
        if QLearningAgent._log_sink_id is None:
            QLearningAgent._log_sink_id = logger.add(
                "logs/q_learning_agent.log", rotation="500 MB")

    def get_q_value(self, state: Dict[str, Any], action: str) -> float:
        """Get the Q-value for a given state-action pair.

        This is a pure read: looking up an unseen state does not add an
        entry to the Q-table.

        Returns:
            The stored Q-value, or 0.0 if the pair has never been set.
        """
        state_key = self.serialize_state(state)
        return self.q_table.get(state_key, {}).get(action, 0.0)

    def set_q_value(self, state: Dict[str, Any], action: str, value: float):
        """Set the Q-value for a given state-action pair."""
        state_key = self.serialize_state(state)
        self.q_table.setdefault(state_key, {})[action] = value

    def choose_action(self, state: Dict[str, Any],
                      available_actions: List[str]) -> str:
        """Choose an action based on the current state and Q-table.

        With probability ``exploration_rate`` a uniformly random action is
        returned; otherwise the action with the highest Q-value is returned,
        with ties broken uniformly at random.

        Raises:
            ValueError: If ``available_actions`` is empty.
        """
        if not available_actions:
            raise ValueError("available_actions must not be empty")

        if random.random() < self.exploration_rate:
            # Explore: choose a random action
            return random.choice(available_actions)

        # Exploit: choose the action with the highest Q-value
        q_values = self._action_values(state, available_actions)
        max_q = max(q_values)
        # If multiple actions share the max Q-value, choose randomly among
        # them so the first-listed action is not systematically favoured.
        best_actions = [
            action
            for action, q in zip(available_actions, q_values)
            if q == max_q
        ]
        return random.choice(best_actions)

    def update_q_table(self, state: Dict[str, Any], action: str,
                       reward: float, next_state: Dict[str, Any],
                       next_actions: List[str]):
        """Update the Q-table based on the observed reward and next state.

        Applies ``Q(s, a) += learning_rate * (reward + discount_factor *
        max_a' Q(s', a') - Q(s, a))``. When ``next_actions`` is empty the
        best next-state value is taken to be 0.
        """
        state_key = self.serialize_state(state)
        current_q = self.q_table.get(state_key, {}).get(action, 0.0)
        max_next_q = max(
            self._action_values(next_state, next_actions), default=0.0)
        new_q = current_q + self.learning_rate * (
            reward + self.discount_factor * max_next_q - current_q)
        self.q_table.setdefault(state_key, {})[action] = new_q
        logger.info(
            f"Q-table updated for state-action pair: ({state_key}, {action})")

    def serialize_state(self, state: Dict[str, Any]) -> str:
        """Serialize the state into a string for use as a dictionary key.

        Dictionaries (including nested ones) are rebuilt with their keys in
        a fixed order first, so two equal states map to the same key no
        matter in which order their entries were inserted. Non-dict values
        are passed to ``str`` unchanged.
        """
        return str(self._canonical(state))

    def _action_values(self, state: Dict[str, Any],
                       actions: List[str]) -> List[float]:
        """Return the Q-value of each action in ``state``, in input order."""
        # Serialize once and reuse the row instead of once per action.
        row = self.q_table.get(self.serialize_state(state), {})
        return [row.get(action, 0.0) for action in actions]

    @staticmethod
    def _canonical(value: Any) -> Any:
        """Return ``value`` with every dict rebuilt in a fixed key order."""
        if isinstance(value, dict):
            # Sort on repr so keys of mixed, non-comparable types still work.
            ordered = sorted(value.items(), key=lambda item: repr(item[0]))
            return {key: QLearningAgent._canonical(val)
                    for key, val in ordered}
        return value