import random
from loguru import logger
from typing import Dict, Any, List


class QLearningAgent:
    """Tabular Q-learning agent with epsilon-greedy action selection."""

    def __init__(self, learning_rate: float = 0.1,
                 discount_factor: float = 0.9, exploration_rate: float = 0.1):
        """Initialize the Q-learning agent.

        Args:
            learning_rate: Step size (alpha) applied to each Q-value update.
            discount_factor: Weight (gamma) given to future rewards.
            exploration_rate: Probability (epsilon) of taking a random
                action instead of the greedy one.
        """
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_rate = exploration_rate
        self.q_table = {}  # Initialize Q-table as an empty dictionary
        self.setup_logger()

    def setup_logger(self):
        """Configure logging for the agent."""
        # Each call registers an additional log sink, so creating several
        # agents will write duplicate entries to the same file; loguru
        # rotates the file once it reaches 500 MB.
        logger.add("logs/q_learning_agent.log", rotation="500 MB")

    def get_q_value(self, state: Dict[str, Any], action: str) -> float:
        """Get the Q-value for a given state-action pair."""
        state_key = self.serialize_state(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = {}

        # Unseen state-action pairs default to a Q-value of 0.0
        return self.q_table[state_key].get(action, 0.0)

    def set_q_value(self, state: Dict[str, Any], action: str, value: float):
        """Set the Q-value for a given state-action pair."""
        state_key = self.serialize_state(state)
        if state_key not in self.q_table:
            self.q_table[state_key] = {}
        self.q_table[state_key][action] = value

    def choose_action(self, state: Dict[str, Any],
                      available_actions: List[str]) -> str:
        """Choose an action based on the current state and Q-table."""
        if random.random() < self.exploration_rate:
            # Explore: choose a random action
            return random.choice(available_actions)
        else:
            # Exploit: choose the action with the highest Q-value
            q_values = [self.get_q_value(state, action)
                        for action in available_actions]
            max_q = max(q_values)
            # Break ties randomly among actions sharing the maximum Q-value
            best_actions = [action for action, q in zip(available_actions, q_values)
                            if q == max_q]
            return random.choice(best_actions)

    def update_q_table(self, state: Dict[str, Any], action: str,
                       reward: float, next_state: Dict[str, Any], next_actions: List[str]):
        """Update the Q-table based on the observed reward and next state."""
        current_q = self.get_q_value(state, action)
        max_next_q = max([self.get_q_value(next_state, next_action)
                          for next_action in next_actions], default=0.0)
        # Standard Q-learning update:
        # Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        new_q = current_q + self.learning_rate * (
            reward + self.discount_factor * max_next_q - current_q)
        self.set_q_value(state, action, new_q)
        logger.info(
            f"Q-table updated for state-action pair: "
            f"({self.serialize_state(state)}, {action})")

    def serialize_state(self, state: Dict[str, Any]) -> str:
        """Serialize the state into a string key for the Q-table."""
        # Sort by key so that logically identical states built with different
        # insertion orders map to the same Q-table entry.
        return str(sorted(state.items()))
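

# --- Usage sketch ----------------------------------------------------------
# A minimal illustration of how the agent might be driven by a training loop.
# The random transition and goal check below stand in for a real environment
# and are purely hypothetical; any environment that yields a dict-shaped
# state, a reward, and a list of available actions would fit the same shape.
if __name__ == "__main__":
    agent = QLearningAgent(learning_rate=0.1, discount_factor=0.9,
                           exploration_rate=0.1)
    actions = ["up", "down", "left", "right"]

    for episode in range(10):
        # Toy two-field state; a real environment would supply richer dicts.
        state = {"x": 0, "y": 0}
        for step in range(20):
            action = agent.choose_action(state, actions)
            # Hypothetical transition and reward; replace with the real
            # environment's step logic.
            next_state = {"x": random.randint(0, 4), "y": random.randint(0, 4)}
            reward = 1.0 if next_state == {"x": 4, "y": 4} else 0.0
            agent.update_q_table(state, action, reward, next_state, actions)
            state = next_state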