# agenticAi/agents/q_learning_agent.py
# Initial commit by Cline (0af0a55)
import numpy as np
import random
from loguru import logger
from typing import Dict, Any, List
class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy action selection.

    States are arbitrary dictionaries serialized to strings for use as
    Q-table keys; actions are strings. The Q-table is a nested dict:
    ``{state_key: {action: q_value}}``.
    """

    # Class-level guard so the loguru file sink is added only once,
    # no matter how many agent instances are created (adding a sink per
    # instance would duplicate every log line).
    _logger_configured = False

    def __init__(self, learning_rate: float = 0.1,
                 discount_factor: float = 0.9, exploration_rate: float = 0.1):
        """Initialize the Q-Learning Agent.

        Args:
            learning_rate: Step size (alpha) for Q-value updates.
            discount_factor: Discount (gamma) applied to future rewards.
            exploration_rate: Probability (epsilon) of taking a random action.
        """
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_table = {}  # Maps state_key -> {action: q_value}
        self.setup_logger()

    def setup_logger(self):
        """Configure file logging for the agent (idempotent)."""
        if not QLearningAgent._logger_configured:
            logger.add("logs/q_learning_agent.log", rotation="500 MB")
            QLearningAgent._logger_configured = True

    def get_q_value(self, state: Dict[str, Any], action: str) -> float:
        """Return the Q-value for a state-action pair (0.0 if unseen)."""
        state_key = self.serialize_state(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = {}
        return self.q_table[state_key].get(action, 0.0)  # Default Q-value is 0.0

    def set_q_value(self, state: Dict[str, Any], action: str, value: float):
        """Set the Q-value for a given state-action pair."""
        state_key = self.serialize_state(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = {}
        self.q_table[state_key][action] = value

    def choose_action(self, state: Dict[str, Any],
                      available_actions: List[str]) -> str:
        """Choose an action via epsilon-greedy selection.

        With probability ``exploration_rate`` a random action is returned;
        otherwise the action with the highest Q-value, breaking ties
        uniformly at random.

        Raises:
            IndexError / ValueError: if ``available_actions`` is empty.
        """
        if random.random() < self.exploration_rate:
            # Explore: choose a random action.
            return random.choice(available_actions)
        # Exploit: choose the action with the highest Q-value.
        q_values = [self.get_q_value(state, action)
                    for action in available_actions]
        max_q = max(q_values)
        # If multiple actions share the max Q-value, pick randomly among them.
        best_actions = [action for action, q
                        in zip(available_actions, q_values) if q == max_q]
        return random.choice(best_actions)

    def update_q_table(self, state: Dict[str, Any], action: str,
                       reward: float, next_state: Dict[str, Any],
                       next_actions: List[str]):
        """Apply the Q-learning update rule for an observed transition.

        Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)).
        With no next actions (terminal state), the future term is 0.0.
        """
        current_q = self.get_q_value(state, action)
        # float default keeps arithmetic consistently in float.
        max_next_q = max([self.get_q_value(next_state, next_action)
                          for next_action in next_actions], default=0.0)
        new_q = current_q + self.learning_rate * \
            (reward + self.discount_factor * max_next_q - current_q)
        self.set_q_value(state, action, new_q)
        logger.info(
            f"Q-table updated for state-action pair: ({self.serialize_state(state)}, {action})")

    def serialize_state(self, state: Dict[str, Any]) -> str:
        """Serialize a state dict into a stable string key.

        Items are sorted by key so that dicts with identical contents but
        different insertion order map to the same Q-table entry
        (``str(dict)`` alone is insertion-order sensitive).
        """
        return str(sorted(state.items()))