File size: 3,254 Bytes
6831f1f
3d7f69e
6831f1f
 
 
3d7f69e
 
457d4b2
3d7f69e
 
 
 
 
 
 
 
 
6831f1f
 
 
3d7f69e
 
 
 
 
 
6831f1f
 
 
457d4b2
6831f1f
 
457d4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6831f1f
 
457d4b2
 
 
6831f1f
 
 
 
 
 
 
 
 
 
 
 
5813146
 
 
457d4b2
5813146
6831f1f
 
 
 
457d4b2
6831f1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d7f69e
6831f1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from threading import Thread
from multiprocessing import Queue
from typing import Dict, Any
import json
import re
import logging
import sys
from mistralai.client import MistralClient

# Logging setup: records at INFO level and above are written to stdout in a
# "time - logger name - level - message" layout.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_stdout_handler = logging.StreamHandler(sys.stdout)

logging.basicConfig(
    handlers=[_stdout_handler],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Logger for this module; records propagate to the root handler above.
logger = logging.getLogger(__name__)


class ActionProcessor(Thread):
    def __init__(
        self,
        text_queue: "Queue[str]",
        action_queue: "Queue[str]",
        mistral_api_key: str,
    ):
        super().__init__()
        self.text_queue = text_queue
        self.action_queue = action_queue
        self.mistral_client = MistralClient(api_key=mistral_api_key)
        self.daemon = True  # Thread will exit when main program exits

    def get_sentiment(self, input_text: str) -> str:
        """Get sentiment analysis for input text."""
        stream_response = self.mistral_client.chat_stream(
            model="mistral-large-latest",
            messages=[
                {
                    "role": "user",
                    "content": f"""You are a sentiment classifier of positive or negative parenting.
                    Classify the following sentence, output "negative" or "positive", do not justify:
                    "{input_text}"
                    """,
                },
            ],
        )

        response = ""
        for chunk in stream_response:
            response += chunk.choices[0].delta.content

        return response.strip()

    def process_text(self, text: str) -> Dict[str, Any] | None:
        """Convert text into an action if a complete command is detected."""
        # Get sentiment first
        sentiment = self.get_sentiment(text)

        # Define command patterns
        command_patterns = {
            r"(?i)\b(stop|now)\b": "stop",
            r"(?i)\b(come back|get back)\b": "return",
            r"(?i)\b(easy)\b": "slow",
            r"(?i)\b(stop drinking)\b": "pause_liquid",
            r"(?i)\b(stop eating)\b": "pause_solid",
            r"(?i)\b(look at me)\b": "look_at_me",
            r"(?i)\b(look away)\b": "look_away",
            r"(?i)\b(don't do that)\b": "stop",
        }

        # TODO: Remove this test thing
        if len(text) <= 3:
            return None
        return {"type": text, "sentiment": sentiment}

        # Check each pattern
        for pattern, action_type in command_patterns.items():
            match = re.search(pattern, text.lower())
            if match:
                return {"type": action_type, "sentiment": sentiment}

        return None

    def run(self) -> None:
        """Main processing loop."""
        while True:
            try:
                # Get text from queue, blocks until text is available
                text = self.text_queue.get()

                # Process the text into an action
                action = self.process_text(text)

                # If we got a valid action, add it to the action queue
                if action:
                    self.action_queue.put(json.dumps(action))

            except Exception as e:
                logger.error(f"Error processing text: {str(e)}")
                continue