Spaces: Running on A10G
from threading import Thread | |
from multiprocessing import Queue | |
from typing import Dict, Any | |
import json | |
import re | |
import logging | |
import sys | |
from mistralai.client import MistralClient | |
# Root logging setup: INFO and above, written to stdout as
# "<time> - <logger name> - <level> - <message>".
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ActionProcessor(Thread): | |
def __init__( | |
self, | |
text_queue: "Queue[str]", | |
action_queue: "Queue[str]", | |
mistral_api_key: str, | |
): | |
super().__init__() | |
self.text_queue = text_queue | |
self.action_queue = action_queue | |
self.mistral_client = MistralClient(api_key=mistral_api_key) | |
self.daemon = True # Thread will exit when main program exits | |
def get_sentiment(self, input_text: str) -> str: | |
"""Get sentiment analysis for input text.""" | |
stream_response = self.mistral_client.chat_stream( | |
model="mistral-large-latest", | |
messages=[ | |
{ | |
"role": "user", | |
"content": f"""You are a sentiment classifier of positive or negative parenting. | |
Classify the following sentence, output "negative" or "positive", do not justify: | |
"{input_text}" | |
""", | |
}, | |
], | |
) | |
response = "" | |
for chunk in stream_response: | |
response += chunk.choices[0].delta.content | |
return response.strip() | |
def process_text(self, text: str) -> Dict[str, Any] | None: | |
"""Convert text into an action if a complete command is detected.""" | |
# Get sentiment first | |
sentiment = self.get_sentiment(text) | |
# Define command patterns | |
command_patterns = { | |
r"(?i)\b(stop|now)\b": "stop", | |
r"(?i)\b(come back|get back)\b": "return", | |
r"(?i)\b(easy)\b": "slow", | |
r"(?i)\b(stop drinking)\b": "pause_liquid", | |
r"(?i)\b(stop eating)\b": "pause_solid", | |
r"(?i)\b(look at me)\b": "look_at_me", | |
r"(?i)\b(look away)\b": "look_away", | |
r"(?i)\b(don't do that)\b": "stop", | |
} | |
# TODO: Remove this test thing | |
if len(text) <= 3: | |
return None | |
return {"type": text, "sentiment": sentiment} | |
# Check each pattern | |
for pattern, action_type in command_patterns.items(): | |
match = re.search(pattern, text.lower()) | |
if match: | |
return {"type": action_type, "sentiment": sentiment} | |
return None | |
def run(self) -> None: | |
"""Main processing loop.""" | |
while True: | |
try: | |
# Get text from queue, blocks until text is available | |
text = self.text_queue.get() | |
# Process the text into an action | |
action = self.process_text(text) | |
# If we got a valid action, add it to the action queue | |
if action: | |
self.action_queue.put(json.dumps(action)) | |
except Exception as e: | |
logger.error(f"Error processing text: {str(e)}") | |
continue | |