Spaces:
Runtime error
Upload 23 files
- .gitattributes +35 -35
- .gitignore +0 -0
- Dockerfile +26 -0
- README.md +11 -11
- __init__.py +667 -0
- app.log +0 -0
- app.py +544 -0
- app2.py +560 -0
- bm25retriever.pkl +3 -0
- chain.py +28 -0
- chat.py +667 -0
- chatflask.py +646 -0
- config.py +18 -0
- embeddings.py +62 -0
- flasktest.py +49 -0
- index.html +70 -0
- llm.py +45 -0
- logging_config.py +38 -0
- main.py +100 -0
- rag.py +114 -0
- requirements.txt +30 -0
- retriever.py +53 -0
- tools.py +188 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
All 35 lines are removed and re-added with byte-identical content, which points to a whitespace or line-ending-only change. The file tracks the following patterns with Git LFS:

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
Binary file (38 Bytes).
Dockerfile
ADDED
@@ -0,0 +1,26 @@
# Use an official Python runtime as a parent image
FROM python:3.11-slim

# Set environment variables to avoid interactive prompts
ENV DEBIAN_FRONTEND=noninteractive

# Set the working directory in the container
WORKDIR /app

# Copy the current directory contents into the container at /app
COPY . /app

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy and set environment variables from .env file
COPY .env .env

# Expose the port the Flask app runs on
EXPOSE 5000

# Expose the port the Streamlit app runs on
EXPOSE 8501

# Run the Streamlit app (note: this CMD starts only Streamlit, not the Flask app)
CMD ["streamlit", "run", "app.py"]
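Although the image exposes both 5000 (Flask) and 8501 (Streamlit), the CMD above starts only the Streamlit process. A minimal, hypothetical launcher sketch that would serve both from a single container follows; it assumes app.py exposes a Flask instance named `app`, and the file name launcher.py is illustrative, not part of this upload:

# launcher.py — hypothetical sketch, not part of this commit.
# Runs the Flask API in a daemon thread and keeps Streamlit as the
# foreground process so the container stays alive.
import subprocess
import threading

from app import app  # assumes app.py defines a Flask instance named `app`


def run_flask():
    # Flask's built-in server is adequate for a demo container, not production.
    app.run(host="0.0.0.0", port=5000)


if __name__ == "__main__":
    threading.Thread(target=run_flask, daemon=True).start()
    subprocess.run(["streamlit", "run", "app.py", "--server.port", "8501"])

With a launcher like this, the Dockerfile's last line would become CMD ["python", "launcher.py"].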
README.md
CHANGED
@@ -1,11 +1,11 @@
All 11 lines are removed and re-added unchanged, again suggesting a whitespace or line-ending-only change. The file contents:

---
title: Financial Chatbot
emoji: 🚀
colorFrom: indigo
colorTo: blue
sdk: docker
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py
ADDED
@@ -0,0 +1,667 @@
import requests
from bs4 import BeautifulSoup
import yfinance as yf
import pandas as pd
from datetime import datetime, timedelta
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_google_genai import ChatGoogleGenerativeAI
from config import Config
import numpy as np
from typing import Optional, Tuple, List, Dict
from rag import get_answer
import time
from tenacity import retry, stop_after_attempt, wait_exponential

# Set up logging
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])

logger = logging.getLogger(__name__)

# Initialize the Gemini model
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID


@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    return llm.invoke(prompt)


class DataSummarizer:
    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[str]:
        start_time = time.time()
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            search_results = response.json()
            logger.info("google_search took %.2f seconds", time.time() - start_time)

            # Summarize the search results using Gemini
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = df['close'].rolling(window=window).mean()
            logger.info("calculate_moving_average took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            logger.info("calculate_rsi took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            logger.info("calculate_ema took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            logger.info("calculate_bollinger_bands took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26,
                       signal_window: int = 9) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            logger.info("calculate_macd took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            logger.info("calculate_volatility took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            logger.info("calculate_atr took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            logger.info("calculate_obv took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            logger.info("calculate_yearly_summary took %.2f seconds", time.time() - start_time)
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            logger.info("get_full_last_year took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        start_time = time.time()
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            logger.info("calculate_ytd_performance took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        start_time = time.time()
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            logger.info("calculate_pe_ratio took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            response = requests.get(search_url, headers=headers)
            soup = BeautifulSoup(response.text, 'html.parser')
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None


def extract_ticker_from_response(response: str) -> Optional[str]:
    start_time = time.time()
    try:
        if "is **" in response and "**." in response:
            result = response.split("is **")[1].split("**.")[0].strip()
            logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
            return result
        result = response.strip()
        logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
        return result
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None


def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    try:
        start_time = time.time()

        # Step 1: Detect Language
        prompt = f"Detect the language for the following text: {query}"
        response = invoke_llm(prompt)
        detected_language = response.content.strip()
        logger.info(f"Language detected: {detected_language}")

        # Step 2: Translate to English (if necessary)
        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = invoke_llm(prompt)
            translated_query = response.content.strip()
            logger.info(f"Translation completed: {translated_query}")
            print(f"Translation: {translated_query}")

        # Step 3: Detect Entity
        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = invoke_llm(prompt)
        detected_entity = response.content.strip()
        logger.info(f"Entity detected: {detected_entity}")
        print(f"Entity: {detected_entity}")

        if not detected_entity:
            logger.error("No entity detected")
            return detected_language, None, translated_query, None

        # Step 4: Get Stock Ticker
        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = invoke_llm(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        if not stock_ticker:
            logger.error("No stock ticker detected")
            return detected_language, detected_entity, translated_query, None

        logger.info("detect_translate_entity_and_ticker took %.2f seconds", time.time() - start_time)
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None


def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)

        historical_data = stock.history(start=start_date, end=end_date)
        if historical_data.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        historical_data = historical_data.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )

        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in historical_data.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        logger.info("fetch_stock_data_yahoo took %.2f seconds", time.time() - start_time)
        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()


def fetch_current_stock_price(symbol: str) -> Optional[float]:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        result = stock.info['currentPrice']
        logger.info("fetch_current_stock_price took %.2f seconds", time.time() - start_time)
        return result
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None


def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    start_time = time.time()
    try:
        if stock_data.empty:
            return "No historical data available."

        formatted_data = "Historical stock data for the last three years:\n\n"
        formatted_data += "Date | Open | High | Low | Close | Volume\n"
        formatted_data += "------------------------------------------------------\n"

        for index, row in stock_data.iterrows():
            formatted_data += f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | {row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n"

        logger.info("format_stock_data_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."


def fetch_company_info_yahoo(symbol: str) -> Dict:
    start_time = time.time()
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        stock = yf.Ticker(symbol)
        company_info = stock.info
        logger.info("fetch_company_info_yahoo took %.2f seconds", time.time() - start_time)
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A")
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}


def format_company_info_for_gemini(company_info: Dict) -> str:
    start_time = time.time()
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        formatted_info = (f"\nCompany Information:\n"
                          f"Name: {company_info['name']}\n"
                          f"Sector: {company_info['sector']}\n"
                          f"Industry: {company_info['industry']}\n"
                          f"Market Cap: {company_info['marketCap']}\n"
                          f"Summary: {company_info['summary']}\n"
                          f"Website: {company_info['website']}\n"
                          f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}\n"
                          f"Phone: {company_info['phone']}\n")

        logger.info("format_company_info_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."


def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        news = stock.news
        if not news:
            raise ValueError(f"No news found for symbol: {symbol}")
        logger.info("fetch_company_news_yahoo took %.2f seconds", time.time() - start_time)
        return news
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []


def format_company_news_for_gemini(news: List[Dict]) -> str:
    start_time = time.time()
    try:
        if not news:
            return "No news available."

        formatted_news = "Latest company news:\n\n"
        for article in news:
            formatted_news += (f"Title: {article['title']}\n"
                               f"Publisher: {article['publisher']}\n"
                               f"Link: {article['link']}\n"
                               f"Published: {article['providerPublishTime']}\n\n")

        logger.info("format_company_news_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."


def send_to_gemini_for_summarization(content: str) -> str:
    start_time = time.time()
    try:
        unified_content = " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        logger.info("send_to_gemini_for_summarization took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."


def answer_question_with_data(question: str, data: Dict) -> str:
    start_time = time.time()
    try:
        data_str = ""
        for key, value in data.items():
            data_str += f"{key}:\n{value}\n\n"

        prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        logger.info("answer_question_with_data took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."


def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    start_time = time.time()
    try:
        moving_average = summarizer.calculate_moving_average(stock_data)
        rsi = summarizer.calculate_rsi(stock_data)
        ema = summarizer.calculate_ema(stock_data)
        bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
        macd = summarizer.calculate_macd(stock_data)
        volatility = summarizer.calculate_volatility(stock_data)
        atr = summarizer.calculate_atr(stock_data)
        obv = summarizer.calculate_obv(stock_data)
        yearly_summary = summarizer.calculate_yearly_summary(stock_data)
        ytd_performance = summarizer.calculate_ytd_performance(stock_data)

        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics = {
                "Moving Average": moving_average.to_string(),
                "RSI": rsi.to_string(),
                "EMA": ema.to_string(),
                "Bollinger Bands": bollinger_bands.to_string(),
                "MACD": macd.to_string(),
                "Volatility": volatility.to_string(),
                "ATR": atr.to_string(),
                "OBV": obv.to_string(),
                "Yearly Summary": yearly_summary.to_string(),
                "YTD Performance": f"{ytd_performance:.2f}%",
                "P/E Ratio": f"{pe_ratio:.2f}"
            }
        else:
            formatted_metrics = {
                "Moving Average": moving_average.to_string(),
                "RSI": rsi.to_string(),
                "EMA": ema.to_string(),
                "Bollinger Bands": bollinger_bands.to_string(),
                "MACD": macd.to_string(),
                "Volatility": volatility.to_string(),
                "ATR": atr.to_string(),
                "OBV": obv.to_string(),
                "Yearly Summary": yearly_summary.to_string(),
                "YTD Performance": f"{ytd_performance:.2f}%"
            }

        logger.info("calculate_metrics took %.2f seconds", time.time() - start_time)
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}


def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str,
                 rag_response: str) -> Dict[str, str]:
    start_time = time.time()
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        "Calculations": formatted_metrics
    }
    collected_data.update(formatted_metrics)
    logger.info("prepare_data took %.2f seconds", time.time() - start_time)
    return collected_data


def main():
    print("Welcome to the Financial Data Chatbot. How can I assist you today?")

    summarizer = DataSummarizer()
    conversation_history = []

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'bye']:
            print("Goodbye! Have a great day!")
            break

        conversation_history.append(f"You: {user_input}")

        try:
            # Detect language, entity, translation, and stock ticker
            language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

            logger.info(
                f"Detected Language: {language}, Entity: {entity}, Translation: {translation}, Stock Ticker: {stock_ticker}")

            if entity and stock_ticker:
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                        executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                        executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                        executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                        executor.submit(get_answer, user_input): "rag_response",
                        executor.submit(summarizer.google_search, user_input): "google_results",
                        executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                stock_data = results.get("stock_data", pd.DataFrame())
                formatted_stock_data = format_stock_data_for_gemini(
                    stock_data) if not stock_data.empty else "No historical data available."

                company_info = results.get("company_info", {})
                formatted_company_info = format_company_info_for_gemini(
                    company_info) if company_info else "No company info available."

                company_news = results.get("company_news", [])
                formatted_company_news = format_company_news_for_gemini(
                    company_news) if company_news else "No news available."

                current_stock_price = results.get("current_stock_price", None)

                formatted_metrics = calculate_metrics(stock_data, summarizer,
                                                      company_info) if not stock_data.empty else {
                    "Error": "No stock data for metrics"}

                google_results = results.get("google_results", "No additional news found through Google Search.")
                google_snippet = results.get("google_snippet", "Snippet not found.")

                rag_response = results.get("rag_response", "No response from RAG.")

                collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                              google_results, formatted_metrics, google_snippet, rag_response)
                collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {translation}", collected_data)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

            else:
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(get_answer, user_input): "rag_response",
                        executor.submit(summarizer.google_search, user_input): "google_results",
                        executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                google_results = results.get("google_results", "No additional news found through Google Search.")
                google_snippet = results.get("google_snippet", "Snippet not found.")
                rag_response = results.get("rag_response", "No response from RAG.")

                collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

        except Exception as e:
            logger.error(f"An error occurred: {e}")
            response = "An error occurred while processing your request. Please try again later."
            print(f"Bot: {response}")
            conversation_history.append(f"Bot: {response}")


if __name__ == "__main__":
    main()
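Since most of the numeric logic above lives in DataSummarizer, a short smoke-test sketch on synthetic data may help when reading it. The column names (date, open, high, low, close, volume) match what fetch_stock_data_yahoo produces; everything else here is illustrative and not part of this commit:

# Hypothetical smoke test for the indicator helpers above; not part of this upload.
# Assumes DataSummarizer from the module above is importable in scope.
import numpy as np
import pandas as pd

# Build ~120 business days of synthetic OHLCV data in the same shape
# that fetch_stock_data_yahoo returns.
rng = np.random.default_rng(0)
dates = pd.bdate_range("2024-01-01", periods=120)
close = 100 + np.cumsum(rng.normal(0, 1, len(dates)))
df = pd.DataFrame({
    "date": dates.date,
    "open": close + rng.normal(0, 0.5, len(dates)),
    "high": close + 1.0,
    "low": close - 1.0,
    "close": close,
    "volume": rng.integers(1_000, 10_000, len(dates)),
})

summarizer = DataSummarizer()
print(summarizer.calculate_rsi(df).dropna().tail())              # 14-day RSI Series
print(summarizer.calculate_macd(df).dropna().tail())             # MACD and signal line
print(summarizer.calculate_bollinger_bands(df).dropna().tail())  # MA plus upper/lower bands

Each helper returns a pandas Series or DataFrame (or None on error), so results can be inspected with ordinary pandas calls as above.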
app.log
ADDED
The diff for this file is too large to render.
app.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify
|
2 |
+
import requests
|
3 |
+
from bs4 import BeautifulSoup
|
4 |
+
import yfinance as yf
|
5 |
+
import pandas as pd
|
6 |
+
from datetime import datetime, timedelta
|
7 |
+
import logging
|
8 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
9 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
+
from config import Config
|
11 |
+
import numpy as np
|
12 |
+
from typing import Optional, Tuple, List, Dict
|
13 |
+
from rag import get_answer
|
14 |
+
import time
|
15 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
16 |
+
import threading
|
17 |
+
import streamlit as st
|
18 |
+
import json
|
19 |
+
|
20 |
+
# Initialize Flask app
|
21 |
+
app = Flask(__name__)
|
22 |
+
|
23 |
+
# Set up logging
|
24 |
+
logging.basicConfig(level=logging.DEBUG,
|
25 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
26 |
+
handlers=[logging.FileHandler("app.log"),
|
27 |
+
logging.StreamHandler()])
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
# Initialize the Gemini model
|
32 |
+
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)
|
33 |
+
|
34 |
+
# Configuration for Google Custom Search API
|
35 |
+
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
|
36 |
+
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID
|
37 |
+
|
38 |
+
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
|
39 |
+
def invoke_llm(prompt):
|
40 |
+
return llm.invoke(prompt)
|
41 |
+
|
42 |
+
class DataSummarizer:
|
43 |
+
def google_search(self, query: str) -> Optional[str]:
|
44 |
+
try:
|
45 |
+
url = "https://www.googleapis.com/customsearch/v1"
|
46 |
+
params = {
|
47 |
+
'key': GOOGLE_API_KEY,
|
48 |
+
'cx': SEARCH_ENGINE_ID,
|
49 |
+
'q': query
|
50 |
+
}
|
51 |
+
response = requests.get(url, params=params)
|
52 |
+
response.raise_for_status()
|
53 |
+
search_results = response.json()
|
54 |
+
items = search_results.get('items', [])
|
55 |
+
content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
|
56 |
+
prompt = f"Summarize the following search results:\n\n{content}"
|
57 |
+
summary_response = invoke_llm(prompt)
|
58 |
+
return summary_response.content.strip()
|
59 |
+
except Exception as e:
|
60 |
+
logger.error(f"Error during Google Search API request: {e}")
|
61 |
+
return None
|
62 |
+
|
63 |
+
def extract_content_from_item(self, item: Dict) -> Optional[str]:
|
64 |
+
try:
|
65 |
+
snippet = item.get('snippet', '')
|
66 |
+
title = item.get('title', '')
|
67 |
+
return f"{title}\n{snippet}"
|
68 |
+
except Exception as e:
|
69 |
+
logger.error(f"Error extracting content from item: {e}")
|
70 |
+
return None
|
71 |
+
|
72 |
+
def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
|
73 |
+
try:
|
74 |
+
result = df['close'].rolling(window=window).mean()
|
75 |
+
return result
|
76 |
+
except Exception as e:
|
77 |
+
logger.error(f"Error calculating moving average: {e}")
|
78 |
+
return None
|
79 |
+
|
80 |
+
def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
|
81 |
+
try:
|
82 |
+
delta = df['close'].diff()
|
83 |
+
gain = delta.where(delta > 0, 0).rolling(window=window).mean()
|
84 |
+
loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
|
85 |
+
rs = gain / loss
|
86 |
+
result = 100 - (100 / (1 + rs))
|
87 |
+
return result
|
88 |
+
except Exception as e:
|
89 |
+
logger.error(f"Error calculating RSI: {e}")
|
90 |
+
return None
|
91 |
+
|
92 |
+
def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
|
93 |
+
try:
|
94 |
+
result = df['close'].ewm(span=window, adjust=False).mean()
|
95 |
+
return result
|
96 |
+
except Exception as e:
|
97 |
+
logger.error(f"Error calculating EMA: {e}")
|
98 |
+
return None
|
99 |
+
|
100 |
+
def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
|
101 |
+
try:
|
102 |
+
ma = df['close'].rolling(window=window).mean()
|
103 |
+
std = df['close'].rolling(window=window).std()
|
104 |
+
upper_band = ma + (std * 2)
|
105 |
+
lower_band = ma - (std * 2)
|
106 |
+
result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
|
107 |
+
return result
|
108 |
+
except Exception as e:
|
109 |
+
logger.error(f"Error calculating Bollinger Bands: {e}")
|
110 |
+
return None
|
111 |
+
|
112 |
+
def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> Optional[pd.DataFrame]:
|
113 |
+
try:
|
114 |
+
short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
|
115 |
+
long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
|
116 |
+
macd = short_ema - long_ema
|
117 |
+
signal = macd.ewm(span=signal_window, adjust=False).mean()
|
118 |
+
result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
|
119 |
+
return result
|
120 |
+
except Exception as e:
|
121 |
+
logger.error(f"Error calculating MACD: {e}")
|
122 |
+
return None
|
123 |
+
|
124 |
+
def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
|
125 |
+
try:
|
126 |
+
log_returns = np.log(df['close'] / df['close'].shift(1))
|
127 |
+
result = log_returns.rolling(window=window).std() * np.sqrt(window)
|
128 |
+
return result
|
129 |
+
except Exception as e:
|
130 |
+
logger.error(f"Error calculating volatility: {e}")
|
131 |
+
return None
|
132 |
+
|
133 |
+
def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
|
134 |
+
try:
|
135 |
+
high_low = df['high'] - df['low']
|
136 |
+
high_close = np.abs(df['high'] - df['close'].shift())
|
137 |
+
low_close = np.abs(df['low'] - df['close'].shift())
|
138 |
+
true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
|
139 |
+
result = true_range.rolling(window=window).mean()
|
140 |
+
return result
|
141 |
+
except Exception as e:
|
142 |
+
logger.error(f"Error calculating ATR: {e}")
|
143 |
+
return None
|
144 |
+
|
145 |
+
def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
|
146 |
+
try:
|
147 |
+
result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
|
148 |
+
return result
|
149 |
+
except Exception as e:
|
150 |
+
logger.error(f"Error calculating OBV: {e}")
|
151 |
+
return None
|
152 |
+
|
153 |
+
def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
|
154 |
+
try:
|
155 |
+
df['year'] = pd.to_datetime(df['date']).dt.year
|
156 |
+
yearly_summary = df.groupby('year').agg({
|
157 |
+
'close': ['mean', 'max', 'min'],
|
158 |
+
'volume': 'sum'
|
159 |
+
})
|
160 |
+
yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
|
161 |
+
return yearly_summary
|
162 |
+
except Exception as e:
|
163 |
+
logger.error(f"Error calculating yearly summary: {e}")
|
164 |
+
return None
|
165 |
+
|
166 |
+
def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
|
167 |
+
try:
|
168 |
+
today = datetime.today().date()
|
169 |
+
last_year_start = datetime(today.year - 1, 1, 1).date()
|
170 |
+
last_year_end = datetime(today.year - 1, 12, 31).date()
|
171 |
+
mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
|
172 |
+
result = df.loc[mask]
|
173 |
+
return result
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"Error filtering data for the last year: {e}")
|
176 |
+
return None
|
177 |
+
|
178 |
+
def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
|
179 |
+
try:
|
180 |
+
today = datetime.today().date()
|
181 |
+
year_start = datetime(today.year, 1, 1).date()
|
182 |
+
mask = (df['date'] >= year_start) & (df['date'] <= today)
|
183 |
+
ytd_data = df.loc[mask]
|
184 |
+
opening_price = ytd_data.iloc[0]['open']
|
185 |
+
closing_price = ytd_data.iloc[-1]['close']
|
186 |
+
result = ((closing_price - opening_price) / opening_price) * 100
|
187 |
+
return result
|
188 |
+
except Exception as e:
|
189 |
+
logger.error(f"Error calculating YTD performance: {e}")
|
190 |
+
return None
|
191 |
+
|
192 |
+
def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
|
193 |
+
try:
|
194 |
+
if eps == 0:
|
195 |
+
raise ValueError("EPS cannot be zero for P/E ratio calculation.")
|
196 |
+
result = current_price / eps
|
197 |
+
return result
|
198 |
+
except Exception as e:
|
199 |
+
logger.error(f"Error calculating P/E ratio: {e}")
|
200 |
+
return None
|
201 |
+
|
202 |
+
def fetch_google_snippet(self, query: str) -> Optional[str]:
|
203 |
+
try:
|
204 |
+
search_url = f"https://www.google.com/search?q={query}"
|
205 |
+
headers = {
|
206 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
|
207 |
+
}
|
208 |
+
response = requests.get(search_url, headers=headers)
|
209 |
+
soup = BeautifulSoup(response.text, 'html.parser')
|
210 |
+
snippet_classes = [
|
211 |
+
'BNeawe iBp4i AP7Wnd',
|
212 |
+
'BNeawe s3v9rd AP7Wnd',
|
213 |
+
'BVG0Nb',
|
214 |
+
'kno-rdesc'
|
215 |
+
]
|
216 |
+
snippet = None
|
217 |
+
for cls in snippet_classes:
|
218 |
+
snippet = soup.find('div', class_=cls)
|
219 |
+
if snippet:
|
220 |
+
break
|
221 |
+
return snippet.get_text() if snippet else "Snippet not found."
|
222 |
+
except Exception as e:
|
223 |
+
logger.error(f"Error fetching Google snippet: {e}")
|
224 |
+
return None
|
225 |
+
|
226 |
+
def extract_ticker_from_response(response: str) -> Optional[str]:
|
227 |
+
try:
|
228 |
+
if "is **" in response and "**." in response:
|
229 |
+
return response.split("is **")[1].split("**.")[0].strip()
|
230 |
+
return response.strip()
|
231 |
+
except Exception as e:
|
232 |
+
logger.error(f"Error extracting ticker from response: {e}")
|
233 |
+
return None
|
234 |
+
|
235 |
+
def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
|
236 |
+
try:
|
237 |
+
# Step 1: Detect Language
|
238 |
+
prompt = f"Detect the language for the following text: {query}"
|
239 |
+
response = invoke_llm(prompt)
|
240 |
+
detected_language = response.content.strip()
|
241 |
+
|
242 |
+
# Step 2: Translate to English (if necessary)
|
243 |
+
translated_query = query
|
244 |
+
if detected_language != "English":
|
245 |
+
prompt = f"Translate the following text to English: {query}"
|
246 |
+
response = invoke_llm(prompt)
|
247 |
+
translated_query = response.content.strip()
|
248 |
+
|
249 |
+
# Step 3: Detect Entity
|
250 |
+
prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
|
251 |
+
response = invoke_llm(prompt)
|
252 |
+
detected_entity = response.content.strip()
|
253 |
+
|
254 |
+
if not detected_entity:
|
255 |
+
return detected_language, None, translated_query, None
|
256 |
+
|
257 |
+
# Step 4: Get Stock Ticker
|
258 |
+
prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
|
259 |
+
response = invoke_llm(prompt)
|
260 |
+
stock_ticker = extract_ticker_from_response(response.content.strip())
|
261 |
+
|
262 |
+
if not stock_ticker:
|
263 |
+
return detected_language, detected_entity, translated_query, None
|
264 |
+
|
265 |
+
return detected_language, detected_entity, translated_query, stock_ticker
|
266 |
+
except Exception as e:
|
267 |
+
logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
|
268 |
+
return None, None, None, None
|
269 |
+
|
270 |
+
def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
|
271 |
+
try:
|
272 |
+
stock = yf.Ticker(symbol)
|
273 |
+
end_date = datetime.now()
|
274 |
+
start_date = end_date - timedelta(days=3 * 365)
|
275 |
+
historical_data = stock.history(start=start_date, end=end_date)
|
276 |
+
        historical_data = historical_data.rename(columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"})
        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]
        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()

def fetch_current_stock_price(symbol: str) -> Optional[float]:
    try:
        stock = yf.Ticker(symbol)
        result = stock.info['currentPrice']
        return result
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None

def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    try:
        if stock_data.empty:
            return "No historical data available."
        formatted_data = "Historical stock data for the last three years:\n\n"
        formatted_data += "Date | Open | High | Low | Close | Volume\n"
        formatted_data += "------------------------------------------------------\n"
        for index, row in stock_data.iterrows():
            formatted_data += f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | {row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n"
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."

def fetch_company_info_yahoo(symbol: str) -> Dict:
    try:
        stock = yf.Ticker(symbol)
        company_info = stock.info
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A")
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}

def format_company_info_for_gemini(company_info: Dict) -> str:
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"
        formatted_info = (f"\nCompany Information:\n"
                          f"Name: {company_info['name']}\n"
                          f"Sector: {company_info['sector']}\n"
                          f"Industry: {company_info['industry']}\n"
                          f"Market Cap: {company_info['marketCap']}\n"
                          f"Summary: {company_info['summary']}\n"
                          f"Website: {company_info['website']}\n"
                          f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}\n"
                          f"Phone: {company_info['phone']}\n")
        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."

def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    try:
        stock = yf.Ticker(symbol)
        news = stock.news
        return news if news else []
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []

def format_company_news_for_gemini(news: List[Dict]) -> str:
    try:
        if not news:
            return "No news available."
        formatted_news = "Latest company news:\n\n"
        for article in news:
            formatted_news += (f"Title: {article['title']}\n"
                               f"Publisher: {article['publisher']}\n"
                               f"Link: {article['link']}\n"
                               f"Published: {article['providerPublishTime']}\n\n")
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."

def send_to_gemini_for_summarization(content: str) -> str:
    try:
        # Join only list inputs; " ".join on a plain string would interleave
        # a space between every character.
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."

def answer_question_with_data(question: str, data: Dict) -> str:
    try:
        data_str = ""
        for key, value in data.items():
            data_str += f"{key}:\n{value}\n\n"
        prompt = (f"You are a financial advisor. Begin your answer and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."

def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    try:
        moving_average = summarizer.calculate_moving_average(stock_data)
        rsi = summarizer.calculate_rsi(stock_data)
        ema = summarizer.calculate_ema(stock_data)
        bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
        macd = summarizer.calculate_macd(stock_data)
        volatility = summarizer.calculate_volatility(stock_data)
        atr = summarizer.calculate_atr(stock_data)
        obv = summarizer.calculate_obv(stock_data)
        yearly_summary = summarizer.calculate_yearly_summary(stock_data)
        ytd_performance = summarizer.calculate_ytd_performance(stock_data)
        formatted_metrics = {
            "Moving Average": moving_average.to_string(),
            "RSI": rsi.to_string(),
            "EMA": ema.to_string(),
            "Bollinger Bands": bollinger_bands.to_string(),
            "MACD": macd.to_string(),
            "Volatility": volatility.to_string(),
            "ATR": atr.to_string(),
            "OBV": obv.to_string(),
            "Yearly Summary": yearly_summary.to_string(),
            "YTD Performance": f"{ytd_performance:.2f}%"
        }
        # P/E ratio is only available when Yahoo reports a trailing EPS.
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}

def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> Dict[str, str]:
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        "Calculations": formatted_metrics
    }
    collected_data.update(formatted_metrics)
    return collected_data

@app.route('/ask', methods=['POST'])
def ask():
    try:
        user_input = request.json.get('question')
        summarizer = DataSummarizer()
        language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)
        if entity and stock_ticker:
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                    executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                    executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                    executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}
            stock_data = results.get("stock_data", pd.DataFrame())
            formatted_stock_data = format_stock_data_for_gemini(stock_data) if not stock_data.empty else "No historical data available."
            company_info = results.get("company_info", {})
            formatted_company_info = format_company_info_for_gemini(company_info) if company_info else "No company info available."
            company_news = results.get("company_news", [])
            formatted_company_news = format_company_news_for_gemini(company_news) if company_news else "No news available."
            current_stock_price = results.get("current_stock_price", None)
            formatted_metrics = calculate_metrics(stock_data, summarizer, company_info) if not stock_data.empty else {"Error": "No stock data for metrics"}
            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")
            collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news, google_results, formatted_metrics, google_snippet, rag_response)
            collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"
            answer = answer_question_with_data(f"{translation}", collected_data)
            return jsonify({"answer": answer})
        else:
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}
            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")
            collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)
            answer = answer_question_with_data(f"{user_input}", collected_data)
            return jsonify({"answer": answer})
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return jsonify({"error": "An error occurred while processing your request. Please try again later."}), 500

# Streamlit App
def send_question_to_api(question):
    url = 'http://localhost:5000/ask'
    headers = {'Content-Type': 'application/json'}
    data = {'question': question}
    response = requests.post(url, headers=headers, data=json.dumps(data))
    if response.status_code == 200:
        return response.json().get('answer')
    else:
        return f"Error: {response.status_code} - {response.text}"

def run_streamlit():
    st.title("Financial Data Chatbot Tester")
    st.write("Enter your question below and get a response from the chatbot.")
    if 'history' not in st.session_state:
        st.session_state.history = []
    user_input = st.text_input("Your question:", "")
    if st.button("Submit"):
        if user_input:
            with st.spinner('Getting the answer...'):
                answer = send_question_to_api(user_input)
                st.session_state.history.append((user_input, answer))
                st.success(answer)
        else:
            st.warning("Please enter a question before submitting.")
    if st.session_state.history:
        st.write("### History")
        for idx, (question, answer) in enumerate(st.session_state.history, 1):
            st.write(f"**Q{idx}:** {question}")
            st.write(f"**A{idx}:** {answer}")
            st.write("---")

if __name__ == '__main__':
    threading.Thread(target=lambda: app.run(host='0.0.0.0', port=5000)).start()
    run_streamlit()
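For reference, a minimal client-side sketch of the /ask contract exposed by the Flask thread (the question text is illustrative; it assumes the server is reachable on localhost:5000, mirroring send_question_to_api above):

import requests

# POST a JSON body with a "question" key; the endpoint replies with {"answer": ...}
resp = requests.post("http://localhost:5000/ask", json={"question": "What is Apple's YTD performance?"})
print(resp.json().get("answer"))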
app2.py
ADDED
@@ -0,0 +1,560 @@
import requests
from bs4 import BeautifulSoup
import yfinance as yf
import pandas as pd
from datetime import datetime, timedelta
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_google_genai import ChatGoogleGenerativeAI
from config import Config
import numpy as np
from typing import Optional, Tuple, List, Dict
from rag import get_answer

# Set up logging
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])

logger = logging.getLogger(__name__)

# Initialize the Gemini model
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID

def fetch_google_snippet(query: str) -> Optional[str]:
    try:
        search_url = f"https://www.google.com/search?q={query}"
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
        }
        response = requests.get(search_url, headers=headers)
        soup = BeautifulSoup(response.text, 'html.parser')
        snippet_classes = [
            'BNeawe iBp4i AP7Wnd',
            'BNeawe s3v9rd AP7Wnd',
            'BVG0Nb',
            'kno-rdesc'
        ]
        for cls in snippet_classes:
            snippet = soup.find('div', class_=cls)
            if snippet:
                return snippet.get_text()
        return "Snippet not found."
    except Exception as e:
        logger.error(f"Error fetching Google snippet: {e}")
        return None

class DataSummarizer:
    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[Dict]:
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        try:
            return df['close'].rolling(window=window).mean()
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            return 100 - (100 / (1 + rs))
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        try:
            return df['close'].ewm(span=window, adjust=False).mean()
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            return pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> Optional[pd.DataFrame]:
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            return pd.DataFrame({'MACD': macd, 'Signal Line': signal})
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            return log_returns.rolling(window=window).std() * np.sqrt(window)
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            return true_range.rolling(window=window).mean()
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        try:
            return (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            return df.loc[mask]
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            return ((closing_price - opening_price) / opening_price) * 100
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            return current_price / eps
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        try:
            return fetch_google_snippet(query)
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None

def extract_ticker_from_response(response: str) -> Optional[str]:
    try:
        if "is **" in response and "**." in response:
            return response.split("is **")[1].split("**.")[0].strip()
        return response.strip()
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None

def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    try:
        prompt = f"Detect the language for the following text: {query}"
        response = llm.invoke(prompt)
        detected_language = response.content.strip()

        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = llm.invoke(prompt)
            translated_query = response.content.strip()

        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = llm.invoke(prompt)
        detected_entity = response.content.strip()

        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = llm.invoke(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None

def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    try:
        stock = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)

        historical_data = stock.history(start=start_date, end=end_date)
        if historical_data.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        historical_data = historical_data.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )

        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in historical_data.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()

def fetch_current_stock_price(symbol: str) -> Optional[float]:
    try:
        stock = yf.Ticker(symbol)
        return stock.info['currentPrice']
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None

def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    try:
        if stock_data.empty:
            return "No historical data available."

        formatted_data = "Historical stock data for the last three years:\n\n"
        formatted_data += "Date | Open | High | Low | Close | Volume\n"
        formatted_data += "------------------------------------------------------\n"

        for index, row in stock_data.iterrows():
            formatted_data += f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | {row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n"

        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."

def fetch_company_info_yahoo(symbol: str) -> Dict:
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        stock = yf.Ticker(symbol)
        company_info = stock.info
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A")
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}

def format_company_info_for_gemini(company_info: Dict) -> str:
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        formatted_info = (f"\nCompany Information:\n"
                          f"Name: {company_info['name']}\n"
                          f"Sector: {company_info['sector']}\n"
                          f"Industry: {company_info['industry']}\n"
                          f"Market Cap: {company_info['marketCap']}\n"
                          f"Summary: {company_info['summary']}\n"
                          f"Website: {company_info['website']}\n"
                          f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}\n"
                          f"Phone: {company_info['phone']}\n")

        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."

def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    try:
        stock = yf.Ticker(symbol)
        news = stock.news
        if not news:
            raise ValueError(f"No news found for symbol: {symbol}")
        return news
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []

def format_company_news_for_gemini(news: List[Dict]) -> str:
    try:
        if not news:
            return "No news available."

        formatted_news = "Latest company news:\n\n"
        for article in news:
            formatted_news += (f"Title: {article['title']}\n"
                               f"Publisher: {article['publisher']}\n"
                               f"Link: {article['link']}\n"
                               f"Published: {article['providerPublishTime']}\n\n")

        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."

def send_to_gemini_for_summarization(content: str) -> str:
    try:
        # Join only list inputs; " ".join on a plain string would interleave
        # a space between every character.
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = llm.invoke(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."

def answer_question_with_data(question: str, data: Dict) -> str:
    try:
        data_str = ""
        for key, value in data.items():
            data_str += f"{key}:\n{value}\n\n"

        prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = llm.invoke(prompt)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."

def format_google_results(google_results: Optional[Dict], summarizer: DataSummarizer, query: str) -> str:
    try:
        if google_results:
            google_content = [summarizer.extract_content_from_item(item) for item in google_results.get('items', [])]
            formatted_google_content = "\n\n".join(google_content)
        else:
            formatted_google_content = "No additional news found through Google Search."

        snippet_query1 = f"{query} I want the answer only"
        snippet_query2 = f"{query}"

        google_snippet1 = summarizer.fetch_google_snippet(snippet_query1)
        google_snippet2 = summarizer.fetch_google_snippet(snippet_query2)

        google_snippet = google_snippet1 if google_snippet1 and google_snippet1 != "Snippet not found." else google_snippet2
        formatted_google_content += f"\n\nGoogle Snippet: {google_snippet}"

        return formatted_google_content
    except Exception as e:
        logger.error(f"Error formatting Google results: {e}")
        return "Error formatting Google results."

def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    try:
        moving_average = summarizer.calculate_moving_average(stock_data)
        rsi = summarizer.calculate_rsi(stock_data)
        ema = summarizer.calculate_ema(stock_data)
        bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
        macd = summarizer.calculate_macd(stock_data)
        volatility = summarizer.calculate_volatility(stock_data)
        atr = summarizer.calculate_atr(stock_data)
        obv = summarizer.calculate_obv(stock_data)
        yearly_summary = summarizer.calculate_yearly_summary(stock_data)
        ytd_performance = summarizer.calculate_ytd_performance(stock_data)

        formatted_metrics = {
            "Moving Average": moving_average.to_string(),
            "RSI": rsi.to_string(),
            "EMA": ema.to_string(),
            "Bollinger Bands": bollinger_bands.to_string(),
            "MACD": macd.to_string(),
            "Volatility": volatility.to_string(),
            "ATR": atr.to_string(),
            "OBV": obv.to_string(),
            "Yearly Summary": yearly_summary.to_string(),
            "YTD Performance": f"{ytd_performance:.2f}%"
        }
        # P/E ratio is only available when Yahoo reports a trailing EPS.
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}

def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 summarized_google_content: str, formatted_metrics: Dict[str, str]) -> Dict[str, str]:
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": summarized_google_content,
        "Calculations": formatted_metrics
    }
    collected_data.update(formatted_metrics)
    return collected_data

def translate_response(response: str, target_language: str) -> str:
    try:
        prompt = f"Translate the following text to {target_language}: {response}"
        translation = llm.invoke(prompt)
        return translation.content.strip()
    except Exception as e:
        logger.error(f"Error translating response: {e}")
        return response  # Return the original response if translation fails

def main():
    print("Welcome to the Financial Data Chatbot. How can I assist you today?")

    summarizer = DataSummarizer()
    conversation_history = []

    while True:
        user_input = input("You: ")

        if user_input.lower() in ['exit', 'quit', 'bye']:
            print("Goodbye! Have a great day!")
            break

        conversation_history.append(f"You: {user_input}")

        try:
            # Detect language, entity, translation, and stock ticker
            language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

            if language and entity and translation and stock_ticker:
                with ThreadPoolExecutor() as executor:
                    futures = {
                        executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                        executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                        executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                        executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                        executor.submit(summarizer.google_search, f"{user_input} latest financial news"): "google_results"
                    }
                    results = {futures[future]: future.result() for future in as_completed(futures)}

                stock_data = results["stock_data"]
                formatted_stock_data = format_stock_data_for_gemini(stock_data)
                company_info = results["company_info"]
                formatted_company_info = format_company_info_for_gemini(company_info)
                company_news = results["company_news"]
                formatted_company_news = format_company_news_for_gemini(company_news)
                current_stock_price = results["current_stock_price"]

                google_results = results["google_results"]
                formatted_google_content = format_google_results(google_results, summarizer, user_input)
                summarized_google_content = send_to_gemini_for_summarization(formatted_google_content)

                formatted_metrics = calculate_metrics(stock_data, summarizer, company_info)

                collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                              summarized_google_content, formatted_metrics)
                collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price else "N/A"

                rag_response = get_answer(user_input)
                collected_data["RAG Response"] = rag_response

                conversation_history.append(f"RAG Response: {rag_response}")
                history_context = "\n".join(conversation_history)

                answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)

                if language != "English":
                    answer = translate_response(answer, language)

                print(f"\nBot: {answer}")
                conversation_history.append(f"Bot: {answer}")

            else:
                response = "I'm sorry, I couldn't process your request. Could you please rephrase?"
                print(f"Bot: {response}")
                conversation_history.append(f"Bot: {response}")

        except Exception as e:
            logger.error(f"An error occurred: {e}")
            response = "An error occurred while processing your request. Please try again later."
            print(f"Bot: {response}")
            conversation_history.append(f"Bot: {response}")

if __name__ == "__main__":
    main()
bm25retriever.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df285be2ae20135ec5219dd34edf52abe2c630b6372f33f1502e48fd52042526
size 4215997
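This is a Git LFS pointer, so the BM25 index itself ships as a binary blob. A minimal loading sketch, assuming the pickle holds a langchain BM25Retriever built elsewhere in the repo (the query string is illustrative):

import pickle

# Unpickle the prebuilt BM25 retriever; requires the same library versions it was built with.
with open("bm25retriever.pkl", "rb") as f:
    bm25_retriever = pickle.load(f)

docs = bm25_retriever.get_relevant_documents("latest company earnings")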
chain.py
ADDED
@@ -0,0 +1,28 @@
from langchain.prompts import (
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
    MessagesPlaceholder
)
from langchain.chains import ConversationChain

class Chain:
    def __init__(self, llm, history=None):
        self.llm = llm
        # self.chain = self.get_conversational_chain()
        if history is not None:
            self.history = history

    def run_conversational_chain(self, prompt_template):
        ans = self.llm.invoke(prompt_template).content
        return ans

    def get_chain_with_history(self):
        system_msg_template = SystemMessagePromptTemplate.from_template(template="""Answer the question as truthfully as possible using the provided context,
        and if the answer is not contained within the text below, say 'I don't know'""")
        human_msg_template = HumanMessagePromptTemplate.from_template(template="{input}")
        prompt_template = ChatPromptTemplate.from_messages([system_msg_template, MessagesPlaceholder(variable_name="history"), human_msg_template])
        conversation = ConversationChain(memory=self.history, prompt=prompt_template, llm=self.llm, verbose=True)
        return conversation
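A minimal usage sketch for Chain (not part of the upload): get_chain_with_history expects self.history to be a memory object that fills the "history" placeholder, e.g. a ConversationBufferMemory with return_messages=True. The model wiring mirrors the Gemini setup used in the other modules.

from langchain.memory import ConversationBufferMemory
from langchain_google_genai import ChatGoogleGenerativeAI
from config import Config

llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest")
memory = ConversationBufferMemory(memory_key="history", return_messages=True)

chain = Chain(llm, history=memory)
conversation = chain.get_chain_with_history()
print(conversation.predict(input="What does a P/E ratio tell an investor?"))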
chat.py
ADDED
@@ -0,0 +1,667 @@
import requests
from bs4 import BeautifulSoup
import yfinance as yf
import pandas as pd
from datetime import datetime, timedelta
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_google_genai import ChatGoogleGenerativeAI
from config import Config
import numpy as np
from typing import Optional, Tuple, List, Dict
from rag import get_answer
import time
from tenacity import retry, stop_after_attempt, wait_exponential

# Set up logging
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])

logger = logging.getLogger(__name__)

# Initialize the Gemini model
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID


@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    return llm.invoke(prompt)


class DataSummarizer:
    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[str]:
        start_time = time.time()
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            search_results = response.json()
            logger.info("google_search took %.2f seconds", time.time() - start_time)

            # Summarize the search results using Gemini
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = df['close'].rolling(window=window).mean()
            logger.info("calculate_moving_average took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            logger.info("calculate_rsi took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            logger.info("calculate_ema took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            logger.info("calculate_bollinger_bands took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26, signal_window: int = 9) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            logger.info("calculate_macd took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            logger.info("calculate_volatility took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            logger.info("calculate_atr took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            logger.info("calculate_obv took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            logger.info("calculate_yearly_summary took %.2f seconds", time.time() - start_time)
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            logger.info("get_full_last_year took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        start_time = time.time()
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            logger.info("calculate_ytd_performance took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        start_time = time.time()
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            logger.info("calculate_pe_ratio took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            response = requests.get(search_url, headers=headers)
            soup = BeautifulSoup(response.text, 'html.parser')
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None


def extract_ticker_from_response(response: str) -> Optional[str]:
    start_time = time.time()
    try:
        if "is **" in response and "**." in response:
            result = response.split("is **")[1].split("**.")[0].strip()
            logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
            return result
        result = response.strip()
        logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
        return result
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None


def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    try:
        start_time = time.time()

        # Step 1: Detect Language
        prompt = f"Detect the language for the following text: {query}"
        response = invoke_llm(prompt)
        detected_language = response.content.strip()
        logger.info(f"Language detected: {detected_language}")

        # Step 2: Translate to English (if necessary)
        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = invoke_llm(prompt)
            translated_query = response.content.strip()
            logger.info(f"Translation completed: {translated_query}")
            print(f"Translation: {translated_query}")

        # Step 3: Detect Entity
        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = invoke_llm(prompt)
        detected_entity = response.content.strip()
        logger.info(f"Entity detected: {detected_entity}")
        print(f"Entity: {detected_entity}")

        if not detected_entity:
            logger.error("No entity detected")
            return detected_language, None, translated_query, None

        # Step 4: Get Stock Ticker
        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = invoke_llm(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        if not stock_ticker:
            logger.error("No stock ticker detected")
            return detected_language, detected_entity, translated_query, None

        logger.info("detect_translate_entity_and_ticker took %.2f seconds", time.time() - start_time)
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None


def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)

        historical_data = stock.history(start=start_date, end=end_date)
        if historical_data.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        historical_data = historical_data.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )

        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in historical_data.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        logger.info("fetch_stock_data_yahoo took %.2f seconds", time.time() - start_time)
        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()


def fetch_current_stock_price(symbol: str) -> Optional[float]:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        result = stock.info['currentPrice']
        logger.info("fetch_current_stock_price took %.2f seconds", time.time() - start_time)
        return result
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None


def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    start_time = time.time()
    try:
        if stock_data.empty:
            return "No historical data available."

        formatted_data = "Historical stock data for the last three years:\n\n"
        formatted_data += "Date | Open | High | Low | Close | Volume\n"
        formatted_data += "------------------------------------------------------\n"

        for index, row in stock_data.iterrows():
            formatted_data += f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | {row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n"

        logger.info("format_stock_data_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."


def fetch_company_info_yahoo(symbol: str) -> Dict:
    start_time = time.time()
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        stock = yf.Ticker(symbol)
        company_info = stock.info
        logger.info("fetch_company_info_yahoo took %.2f seconds", time.time() - start_time)
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A")
|
400 |
+
}
|
401 |
+
except Exception as e:
|
402 |
+
logger.error(f"Error fetching company info for {symbol}: {e}")
|
403 |
+
return {"error": str(e)}
|
404 |
+
|
405 |
+
|
406 |
+
def format_company_info_for_gemini(company_info: Dict) -> str:
|
407 |
+
start_time = time.time()
|
408 |
+
try:
|
409 |
+
if "error" in company_info:
|
410 |
+
return f"Error fetching company info: {company_info['error']}"
|
411 |
+
|
412 |
+
formatted_info = (f"\nCompany Information:\n"
|
413 |
+
f"Name: {company_info['name']}\n"
|
414 |
+
f"Sector: {company_info['sector']}\n"
|
415 |
+
f"Industry: {company_info['industry']}\n"
|
416 |
+
f"Market Cap: {company_info['marketCap']}\n"
|
417 |
+
f"Summary: {company_info['summary']}\n"
|
418 |
+
f"Website: {company_info['website']}\n"
|
419 |
+
f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}\n"
|
420 |
+
f"Phone: {company_info['phone']}\n")
|
421 |
+
|
422 |
+
logger.info("format_company_info_for_gemini took %.2f seconds", time.time() - start_time)
|
423 |
+
return formatted_info
|
424 |
+
except Exception as e:
|
425 |
+
logger.error(f"Error formatting company info for Gemini: {e}")
|
426 |
+
return "Error formatting company info."
|
427 |
+
|
428 |
+
|
429 |
+
def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
|
430 |
+
start_time = time.time()
|
431 |
+
try:
|
432 |
+
stock = yf.Ticker(symbol)
|
433 |
+
news = stock.news
|
434 |
+
if not news:
|
435 |
+
raise ValueError(f"No news found for symbol: {symbol}")
|
436 |
+
logger.info("fetch_company_news_yahoo took %.2f seconds", time.time() - start_time)
|
437 |
+
return news
|
438 |
+
except Exception as e:
|
439 |
+
logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
|
440 |
+
return []
|
441 |
+
|
442 |
+
|
443 |
+
def format_company_news_for_gemini(news: List[Dict]) -> str:
|
444 |
+
start_time = time.time()
|
445 |
+
try:
|
446 |
+
if not news:
|
447 |
+
return "No news available."
|
448 |
+
|
449 |
+
formatted_news = "Latest company news:\n\n"
|
450 |
+
for article in news:
|
451 |
+
formatted_news += (f"Title: {article['title']}\n"
|
452 |
+
f"Publisher: {article['publisher']}\n"
|
453 |
+
f"Link: {article['link']}\n"
|
454 |
+
f"Published: {article['providerPublishTime']}\n\n")
|
455 |
+
|
456 |
+
logger.info("format_company_news_for_gemini took %.2f seconds", time.time() - start_time)
|
457 |
+
return formatted_news
|
458 |
+
except Exception as e:
|
459 |
+
logger.error(f"Error formatting company news for Gemini: {e}")
|
460 |
+
return "Error formatting company news."
|
461 |
+
|
462 |
+
|
463 |
+
def send_to_gemini_for_summarization(content: str) -> str:
|
464 |
+
start_time = time.time()
|
465 |
+
try:
|
466 |
+
unified_content = " ".join(content)
|
467 |
+
prompt = f"Summarize the main points of this article.\n\n{unified_content}"
|
468 |
+
response = invoke_llm(prompt)
|
469 |
+
logger.info("send_to_gemini_for_summarization took %.2f seconds", time.time() - start_time)
|
470 |
+
return response.content.strip()
|
471 |
+
except Exception as e:
|
472 |
+
logger.error(f"Error sending content to Gemini for summarization: {e}")
|
473 |
+
return "Error summarizing content."
|
474 |
+
|
475 |
+
|
476 |
+
def answer_question_with_data(question: str, data: Dict) -> str:
|
477 |
+
start_time = time.time()
|
478 |
+
try:
|
479 |
+
data_str = ""
|
480 |
+
for key, value in data.items():
|
481 |
+
data_str += f"{key}:\n{value}\n\n"
|
482 |
+
|
483 |
+
prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
|
484 |
+
f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
|
485 |
+
f"Make your answer in the best form and professional.\n"
|
486 |
+
f"Don't say anything about the source of the data.\n"
|
487 |
+
f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
|
488 |
+
response = invoke_llm(prompt)
|
489 |
+
logger.info("answer_question_with_data took %.2f seconds", time.time() - start_time)
|
490 |
+
return response.content.strip()
|
491 |
+
except Exception as e:
|
492 |
+
logger.error(f"Error answering question with data: {e}")
|
493 |
+
return "Error answering question."
|
494 |
+
|
495 |
+
|
496 |
+
def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
|
497 |
+
start_time = time.time()
|
498 |
+
try:
|
499 |
+
moving_average = summarizer.calculate_moving_average(stock_data)
|
500 |
+
rsi = summarizer.calculate_rsi(stock_data)
|
501 |
+
ema = summarizer.calculate_ema(stock_data)
|
502 |
+
bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
|
503 |
+
macd = summarizer.calculate_macd(stock_data)
|
504 |
+
volatility = summarizer.calculate_volatility(stock_data)
|
505 |
+
atr = summarizer.calculate_atr(stock_data)
|
506 |
+
obv = summarizer.calculate_obv(stock_data)
|
507 |
+
yearly_summary = summarizer.calculate_yearly_summary(stock_data)
|
508 |
+
ytd_performance = summarizer.calculate_ytd_performance(stock_data)
|
509 |
+
|
510 |
+
eps = company_info.get('trailingEps', None)
|
511 |
+
if eps:
|
512 |
+
current_price = stock_data.iloc[-1]['close']
|
513 |
+
pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
|
514 |
+
formatted_metrics = {
|
515 |
+
"Moving Average": moving_average.to_string(),
|
516 |
+
"RSI": rsi.to_string(),
|
517 |
+
"EMA": ema.to_string(),
|
518 |
+
"Bollinger Bands": bollinger_bands.to_string(),
|
519 |
+
"MACD": macd.to_string(),
|
520 |
+
"Volatility": volatility.to_string(),
|
521 |
+
"ATR": atr.to_string(),
|
522 |
+
"OBV": obv.to_string(),
|
523 |
+
"Yearly Summary": yearly_summary.to_string(),
|
524 |
+
"YTD Performance": f"{ytd_performance:.2f}%",
|
525 |
+
"P/E Ratio": f"{pe_ratio:.2f}"
|
526 |
+
}
|
527 |
+
else:
|
528 |
+
formatted_metrics = {
|
529 |
+
"Moving Average": moving_average.to_string(),
|
530 |
+
"RSI": rsi.to_string(),
|
531 |
+
"EMA": ema.to_string(),
|
532 |
+
"Bollinger Bands": bollinger_bands.to_string(),
|
533 |
+
"MACD": macd.to_string(),
|
534 |
+
"Volatility": volatility.to_string(),
|
535 |
+
"ATR": atr.to_string(),
|
536 |
+
"OBV": obv.to_string(),
|
537 |
+
"Yearly Summary": yearly_summary.to_string(),
|
538 |
+
"YTD Performance": f"{ytd_performance:.2f}%"
|
539 |
+
}
|
540 |
+
|
541 |
+
logger.info("calculate_metrics took %.2f seconds", time.time() - start_time)
|
542 |
+
return formatted_metrics
|
543 |
+
except Exception as e:
|
544 |
+
logger.error(f"Error calculating metrics: {e}")
|
545 |
+
return {"Error": "Error calculating metrics"}
|
546 |
+
|
547 |
+
|
548 |
+
def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
|
549 |
+
google_results: str, formatted_metrics: Dict[str, str], google_snippet: str, rag_response: str) -> \
|
550 |
+
Dict[str, str]:
|
551 |
+
start_time = time.time()
|
552 |
+
collected_data = {
|
553 |
+
"Formatted Stock Data": formatted_stock_data,
|
554 |
+
"Formatted Company Info": formatted_company_info,
|
555 |
+
"Formatted Company News": formatted_company_news,
|
556 |
+
"Google Search Results": google_results,
|
557 |
+
"Google Snippet": google_snippet,
|
558 |
+
"RAG Response": rag_response,
|
559 |
+
"Calculations": formatted_metrics
|
560 |
+
}
|
561 |
+
collected_data.update(formatted_metrics)
|
562 |
+
logger.info("prepare_data took %.2f seconds", time.time() - start_time)
|
563 |
+
return collected_data
|
564 |
+
|
565 |
+
|
566 |
+
def main():
|
567 |
+
print("Welcome to the Financial Data Chatbot. How can I assist you today?")
|
568 |
+
|
569 |
+
summarizer = DataSummarizer()
|
570 |
+
conversation_history = []
|
571 |
+
|
572 |
+
while True:
|
573 |
+
user_input = input("You: ")
|
574 |
+
|
575 |
+
if user_input.lower() in ['exit', 'quit', 'bye']:
|
576 |
+
print("Goodbye! Have a great day!")
|
577 |
+
break
|
578 |
+
|
579 |
+
conversation_history.append(f"You: {user_input}")
|
580 |
+
|
581 |
+
try:
|
582 |
+
# Detect language, entity, translation, and stock ticker
|
583 |
+
language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)
|
584 |
+
|
585 |
+
logger.info(
|
586 |
+
f"Detected Language: {language}, Entity: {entity}, Translation: {translation}, Stock Ticker: {stock_ticker}")
|
587 |
+
|
588 |
+
if entity and stock_ticker:
|
589 |
+
with ThreadPoolExecutor() as executor:
|
590 |
+
futures = {
|
591 |
+
executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
|
592 |
+
executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
|
593 |
+
executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
|
594 |
+
executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
|
595 |
+
executor.submit(get_answer, user_input): "rag_response",
|
596 |
+
executor.submit(summarizer.google_search, user_input): "google_results",
|
597 |
+
executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
|
598 |
+
}
|
599 |
+
results = {futures[future]: future.result() for future in as_completed(futures)}
|
600 |
+
|
601 |
+
stock_data = results.get("stock_data", pd.DataFrame())
|
602 |
+
formatted_stock_data = format_stock_data_for_gemini(
|
603 |
+
stock_data) if not stock_data.empty else "No historical data available."
|
604 |
+
|
605 |
+
company_info = results.get("company_info", {})
|
606 |
+
formatted_company_info = format_company_info_for_gemini(
|
607 |
+
company_info) if company_info else "No company info available."
|
608 |
+
|
609 |
+
company_news = results.get("company_news", [])
|
610 |
+
formatted_company_news = format_company_news_for_gemini(
|
611 |
+
company_news) if company_news else "No news available."
|
612 |
+
|
613 |
+
current_stock_price = results.get("current_stock_price", None)
|
614 |
+
|
615 |
+
formatted_metrics = calculate_metrics(stock_data, summarizer,
|
616 |
+
company_info) if not stock_data.empty else {
|
617 |
+
"Error": "No stock data for metrics"}
|
618 |
+
|
619 |
+
google_results = results.get("google_results", "No additional news found through Google Search.")
|
620 |
+
google_snippet = results.get("google_snippet", "Snippet not found.")
|
621 |
+
|
622 |
+
rag_response = results.get("rag_response", "No response from RAG.")
|
623 |
+
|
624 |
+
collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
|
625 |
+
google_results, formatted_metrics, google_snippet, rag_response)
|
626 |
+
collected_data[
|
627 |
+
"Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"
|
628 |
+
|
629 |
+
conversation_history.append(f"RAG Response: {rag_response}")
|
630 |
+
history_context = "\n".join(conversation_history)
|
631 |
+
|
632 |
+
answer = answer_question_with_data(f"{history_context}\n\nUser's query: {translation}", collected_data)
|
633 |
+
|
634 |
+
print(f"\nBot: {answer}")
|
635 |
+
conversation_history.append(f"Bot: {answer}")
|
636 |
+
|
637 |
+
else:
|
638 |
+
with ThreadPoolExecutor() as executor:
|
639 |
+
futures = {
|
640 |
+
executor.submit(get_answer, user_input): "rag_response",
|
641 |
+
executor.submit(summarizer.google_search, user_input): "google_results",
|
642 |
+
executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
|
643 |
+
}
|
644 |
+
results = {futures[future]: future.result() for future in as_completed(futures)}
|
645 |
+
|
646 |
+
google_results = results.get("google_results", "No additional news found through Google Search.")
|
647 |
+
google_snippet = results.get("google_snippet", "Snippet not found.")
|
648 |
+
rag_response = results.get("rag_response", "No response from RAG.")
|
649 |
+
|
650 |
+
collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)
|
651 |
+
|
652 |
+
conversation_history.append(f"RAG Response: {rag_response}")
|
653 |
+
history_context = "\n".join(conversation_history)
|
654 |
+
|
655 |
+
answer = answer_question_with_data(f"{history_context}\n\nUser's query: {user_input}", collected_data)
|
656 |
+
|
657 |
+
print(f"\nBot: {answer}")
|
658 |
+
conversation_history.append(f"Bot: {answer}")
|
659 |
+
|
660 |
+
except Exception as e:
|
661 |
+
logger.error(f"An error occurred: {e}")
|
662 |
+
response = "An error occurred while processing your request. Please try again later."
|
663 |
+
print(f"Bot: {response}")
|
664 |
+
conversation_history.append(f"Bot: {response}")
|
665 |
+
|
666 |
+
if __name__ == "__main__":
|
667 |
+
main()
|
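For reference, extract_ticker_from_response targets Gemini replies that bold the ticker; a minimal sketch, where both example replies are made up and the calls assume the definitions above are in scope:

# Illustrative inputs only.
print(extract_ticker_from_response("The stock ticker symbol for Apple Inc. is **AAPL**."))  # -> AAPL
print(extract_ticker_from_response("AAPL"))  # fallback path: the whole reply, stripped -> AAPL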
chatflask.py
ADDED
@@ -0,0 +1,646 @@
from flask import Flask, request, jsonify
import requests
from bs4 import BeautifulSoup
import yfinance as yf
import pandas as pd
from datetime import datetime, timedelta
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_google_genai import ChatGoogleGenerativeAI
from config import Config
import numpy as np
from typing import Optional, Tuple, List, Dict
from rag import get_answer
import time
from tenacity import retry, stop_after_attempt, wait_exponential

# Initialize Flask app
app = Flask(__name__)

# Set up logging
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("app.log"),
                              logging.StreamHandler()])

logger = logging.getLogger(__name__)

# Initialize the Gemini model
llm = ChatGoogleGenerativeAI(api_key=Config.GEMINI_API_KEY, model="gemini-1.5-flash-latest", temperature=0.5)

# Configuration for Google Custom Search API
GOOGLE_API_KEY = Config.GOOGLE_API_KEY
SEARCH_ENGINE_ID = Config.SEARCH_ENGINE_ID


@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=8), reraise=True)
def invoke_llm(prompt):
    return llm.invoke(prompt)


class DataSummarizer:
    def __init__(self):
        pass

    def google_search(self, query: str) -> Optional[str]:
        start_time = time.time()
        try:
            url = "https://www.googleapis.com/customsearch/v1"
            params = {
                'key': GOOGLE_API_KEY,
                'cx': SEARCH_ENGINE_ID,
                'q': query
            }
            response = requests.get(url, params=params)
            response.raise_for_status()
            search_results = response.json()
            logger.info("google_search took %.2f seconds", time.time() - start_time)

            # Summarize the search results using Gemini
            items = search_results.get('items', [])
            content = "\n\n".join([f"{item.get('title', '')}\n{item.get('snippet', '')}" for item in items])
            prompt = f"Summarize the following search results:\n\n{content}"
            summary_response = invoke_llm(prompt)
            return summary_response.content.strip()
        except Exception as e:
            logger.error(f"Error during Google Search API request: {e}")
            return None

    def extract_content_from_item(self, item: Dict) -> Optional[str]:
        try:
            snippet = item.get('snippet', '')
            title = item.get('title', '')
            return f"{title}\n{snippet}"
        except Exception as e:
            logger.error(f"Error extracting content from item: {e}")
            return None

    def calculate_moving_average(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = df['close'].rolling(window=window).mean()
            logger.info("calculate_moving_average took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating moving average: {e}")
            return None

    def calculate_rsi(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            delta = df['close'].diff()
            gain = delta.where(delta > 0, 0).rolling(window=window).mean()
            loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
            rs = gain / loss
            result = 100 - (100 / (1 + rs))
            logger.info("calculate_rsi took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating RSI: {e}")
            return None

    def calculate_ema(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = df['close'].ewm(span=window, adjust=False).mean()
            logger.info("calculate_ema took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating EMA: {e}")
            return None

    def calculate_bollinger_bands(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            ma = df['close'].rolling(window=window).mean()
            std = df['close'].rolling(window=window).std()
            upper_band = ma + (std * 2)
            lower_band = ma - (std * 2)
            result = pd.DataFrame({'MA': ma, 'Upper Band': upper_band, 'Lower Band': lower_band})
            logger.info("calculate_bollinger_bands took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating Bollinger Bands: {e}")
            return None

    def calculate_macd(self, df: pd.DataFrame, short_window: int = 12, long_window: int = 26,
                       signal_window: int = 9) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            short_ema = df['close'].ewm(span=short_window, adjust=False).mean()
            long_ema = df['close'].ewm(span=long_window, adjust=False).mean()
            macd = short_ema - long_ema
            signal = macd.ewm(span=signal_window, adjust=False).mean()
            result = pd.DataFrame({'MACD': macd, 'Signal Line': signal})
            logger.info("calculate_macd took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating MACD: {e}")
            return None

    def calculate_volatility(self, df: pd.DataFrame, window: int = 20) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            log_returns = np.log(df['close'] / df['close'].shift(1))
            result = log_returns.rolling(window=window).std() * np.sqrt(window)
            logger.info("calculate_volatility took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating volatility: {e}")
            return None

    def calculate_atr(self, df: pd.DataFrame, window: int = 14) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            high_low = df['high'] - df['low']
            high_close = np.abs(df['high'] - df['close'].shift())
            low_close = np.abs(df['low'] - df['close'].shift())
            true_range = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
            result = true_range.rolling(window=window).mean()
            logger.info("calculate_atr took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating ATR: {e}")
            return None

    def calculate_obv(self, df: pd.DataFrame) -> Optional[pd.Series]:
        start_time = time.time()
        try:
            result = (np.sign(df['close'].diff()) * df['volume']).fillna(0).cumsum()
            logger.info("calculate_obv took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating OBV: {e}")
            return None

    def calculate_yearly_summary(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            df['year'] = pd.to_datetime(df['date']).dt.year
            yearly_summary = df.groupby('year').agg({
                'close': ['mean', 'max', 'min'],
                'volume': 'sum'
            })
            yearly_summary.columns = ['_'.join(col) for col in yearly_summary.columns]
            logger.info("calculate_yearly_summary took %.2f seconds", time.time() - start_time)
            return yearly_summary
        except Exception as e:
            logger.error(f"Error calculating yearly summary: {e}")
            return None

    def get_full_last_year(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
        start_time = time.time()
        try:
            today = datetime.today().date()
            last_year_start = datetime(today.year - 1, 1, 1).date()
            last_year_end = datetime(today.year - 1, 12, 31).date()
            mask = (df['date'] >= last_year_start) & (df['date'] <= last_year_end)
            result = df.loc[mask]
            logger.info("get_full_last_year took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error filtering data for the last year: {e}")
            return None

    def calculate_ytd_performance(self, df: pd.DataFrame) -> Optional[float]:
        start_time = time.time()
        try:
            today = datetime.today().date()
            year_start = datetime(today.year, 1, 1).date()
            mask = (df['date'] >= year_start) & (df['date'] <= today)
            ytd_data = df.loc[mask]
            opening_price = ytd_data.iloc[0]['open']
            closing_price = ytd_data.iloc[-1]['close']
            result = ((closing_price - opening_price) / opening_price) * 100
            logger.info("calculate_ytd_performance took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating YTD performance: {e}")
            return None

    def calculate_pe_ratio(self, current_price: float, eps: float) -> Optional[float]:
        start_time = time.time()
        try:
            if eps == 0:
                raise ValueError("EPS cannot be zero for P/E ratio calculation.")
            result = current_price / eps
            logger.info("calculate_pe_ratio took %.2f seconds", time.time() - start_time)
            return result
        except Exception as e:
            logger.error(f"Error calculating P/E ratio: {e}")
            return None

    def fetch_google_snippet(self, query: str) -> Optional[str]:
        try:
            search_url = f"https://www.google.com/search?q={query}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
            }
            response = requests.get(search_url, headers=headers)
            soup = BeautifulSoup(response.text, 'html.parser')
            snippet_classes = [
                'BNeawe iBp4i AP7Wnd',
                'BNeawe s3v9rd AP7Wnd',
                'BVG0Nb',
                'kno-rdesc'
            ]
            snippet = None
            for cls in snippet_classes:
                snippet = soup.find('div', class_=cls)
                if snippet:
                    break
            return snippet.get_text() if snippet else "Snippet not found."
        except Exception as e:
            logger.error(f"Error fetching Google snippet: {e}")
            return None


def extract_ticker_from_response(response: str) -> Optional[str]:
    start_time = time.time()
    try:
        if "is **" in response and "**." in response:
            result = response.split("is **")[1].split("**.")[0].strip()
            logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
            return result
        result = response.strip()
        logger.info("extract_ticker_from_response took %.2f seconds", time.time() - start_time)
        return result
    except Exception as e:
        logger.error(f"Error extracting ticker from response: {e}")
        return None


def detect_translate_entity_and_ticker(query: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    try:
        start_time = time.time()

        # Step 1: Detect Language
        prompt = f"Detect the language for the following text: {query}"
        response = invoke_llm(prompt)
        detected_language = response.content.strip()
        logger.info(f"Language detected: {detected_language}")

        # Step 2: Translate to English (if necessary)
        translated_query = query
        if detected_language != "English":
            prompt = f"Translate the following text to English: {query}"
            response = invoke_llm(prompt)
            translated_query = response.content.strip()
            logger.info(f"Translation completed: {translated_query}")
            print(f"Translation: {translated_query}")

        # Step 3: Detect Entity
        prompt = f"Detect the entity in the following text that is a company name: {translated_query}"
        response = invoke_llm(prompt)
        detected_entity = response.content.strip()
        logger.info(f"Entity detected: {detected_entity}")
        print(f"Entity: {detected_entity}")

        if not detected_entity:
            logger.error("No entity detected")
            return detected_language, None, translated_query, None

        # Step 4: Get Stock Ticker
        prompt = f"What is the stock ticker symbol for the company {detected_entity}?"
        response = invoke_llm(prompt)
        stock_ticker = extract_ticker_from_response(response.content.strip())

        if not stock_ticker:
            logger.error("No stock ticker detected")
            return detected_language, detected_entity, translated_query, None

        logger.info("detect_translate_entity_and_ticker took %.2f seconds", time.time() - start_time)
        return detected_language, detected_entity, translated_query, stock_ticker
    except Exception as e:
        logger.error(f"Error in detecting, translating, or extracting entity and ticker: {e}")
        return None, None, None, None


def fetch_stock_data_yahoo(symbol: str) -> pd.DataFrame:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        logger.info(f"Fetching data for symbol: {symbol}")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=3 * 365)

        historical_data = stock.history(start=start_date, end=end_date)
        if historical_data.empty:
            raise ValueError(f"No historical data found for symbol: {symbol}")

        historical_data = historical_data.rename(
            columns={"Open": "open", "High": "high", "Low": "low", "Close": "close", "Volume": "volume"}
        )

        historical_data.reset_index(inplace=True)
        historical_data['date'] = historical_data['Date'].dt.date
        historical_data = historical_data.drop(columns=['Date'])
        historical_data = historical_data[['date', 'open', 'high', 'low', 'close', 'volume']]

        if 'close' not in historical_data.columns:
            raise KeyError("The historical data must contain a 'close' column.")

        logger.info("fetch_stock_data_yahoo took %.2f seconds", time.time() - start_time)
        return historical_data
    except Exception as e:
        logger.error(f"Failed to fetch stock data for {symbol} from Yahoo Finance: {e}")
        return pd.DataFrame()


def fetch_current_stock_price(symbol: str) -> Optional[float]:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        result = stock.info['currentPrice']
        logger.info("fetch_current_stock_price took %.2f seconds", time.time() - start_time)
        return result
    except Exception as e:
        logger.error(f"Failed to fetch current stock price for {symbol}: {e}")
        return None


def format_stock_data_for_gemini(stock_data: pd.DataFrame) -> str:
    start_time = time.time()
    try:
        if stock_data.empty:
            return "No historical data available."

        formatted_data = "Historical stock data for the last three years:\n\n"
        formatted_data += "Date | Open | High | Low | Close | Volume\n"
        formatted_data += "------------------------------------------------------\n"

        for index, row in stock_data.iterrows():
            formatted_data += f"{row['date']} | {row['open']:.2f} | {row['high']:.2f} | {row['low']:.2f} | {row['close']:.2f} | {int(row['volume'])}\n"

        logger.info("format_stock_data_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_data
    except Exception as e:
        logger.error(f"Error formatting stock data for Gemini: {e}")
        return "Error formatting stock data."


def fetch_company_info_yahoo(symbol: str) -> Dict:
    start_time = time.time()
    try:
        if not symbol:
            return {"error": "Invalid symbol"}

        stock = yf.Ticker(symbol)
        company_info = stock.info
        logger.info("fetch_company_info_yahoo took %.2f seconds", time.time() - start_time)
        return {
            "name": company_info.get("longName", "N/A"),
            "sector": company_info.get("sector", "N/A"),
            "industry": company_info.get("industry", "N/A"),
            "marketCap": company_info.get("marketCap", "N/A"),
            "summary": company_info.get("longBusinessSummary", "N/A"),
            "website": company_info.get("website", "N/A"),
            "address": company_info.get("address1", "N/A"),
            "city": company_info.get("city", "N/A"),
            "state": company_info.get("state", "N/A"),
            "country": company_info.get("country", "N/A"),
            "phone": company_info.get("phone", "N/A"),
            # Passed through so calculate_metrics can compute the P/E ratio.
            "trailingEps": company_info.get("trailingEps")
        }
    except Exception as e:
        logger.error(f"Error fetching company info for {symbol}: {e}")
        return {"error": str(e)}


def format_company_info_for_gemini(company_info: Dict) -> str:
    start_time = time.time()
    try:
        if "error" in company_info:
            return f"Error fetching company info: {company_info['error']}"

        formatted_info = (f"\nCompany Information:\n"
                          f"Name: {company_info['name']}\n"
                          f"Sector: {company_info['sector']}\n"
                          f"Industry: {company_info['industry']}\n"
                          f"Market Cap: {company_info['marketCap']}\n"
                          f"Summary: {company_info['summary']}\n"
                          f"Website: {company_info['website']}\n"
                          f"Address: {company_info['address']}, {company_info['city']}, {company_info['state']}, {company_info['country']}\n"
                          f"Phone: {company_info['phone']}\n")

        logger.info("format_company_info_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_info
    except Exception as e:
        logger.error(f"Error formatting company info for Gemini: {e}")
        return "Error formatting company info."


def fetch_company_news_yahoo(symbol: str) -> List[Dict]:
    start_time = time.time()
    try:
        stock = yf.Ticker(symbol)
        news = stock.news
        if not news:
            raise ValueError(f"No news found for symbol: {symbol}")
        logger.info("fetch_company_news_yahoo took %.2f seconds", time.time() - start_time)
        return news
    except Exception as e:
        logger.error(f"Failed to fetch news for {symbol} from Yahoo Finance: {e}")
        return []


def format_company_news_for_gemini(news: List[Dict]) -> str:
    start_time = time.time()
    try:
        if not news:
            return "No news available."

        formatted_news = "Latest company news:\n\n"
        for article in news:
            formatted_news += (f"Title: {article['title']}\n"
                               f"Publisher: {article['publisher']}\n"
                               f"Link: {article['link']}\n"
                               f"Published: {article['providerPublishTime']}\n\n")

        logger.info("format_company_news_for_gemini took %.2f seconds", time.time() - start_time)
        return formatted_news
    except Exception as e:
        logger.error(f"Error formatting company news for Gemini: {e}")
        return "Error formatting company news."


def send_to_gemini_for_summarization(content) -> str:
    start_time = time.time()
    try:
        # Join list input into a single string; a plain string passes through unchanged.
        unified_content = content if isinstance(content, str) else " ".join(content)
        prompt = f"Summarize the main points of this article.\n\n{unified_content}"
        response = invoke_llm(prompt)
        logger.info("send_to_gemini_for_summarization took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error sending content to Gemini for summarization: {e}")
        return "Error summarizing content."


def answer_question_with_data(question: str, data: Dict) -> str:
    start_time = time.time()
    try:
        data_str = ""
        for key, value in data.items():
            data_str += f"{key}:\n{value}\n\n"

        prompt = (f"You are a financial advisor. Begin your answer by stating that and only give the answer after.\n"
                  f"Using the following data, answer this question: {question}\n\nData:\n{data_str}\n"
                  f"Make your answer in the best form and professional.\n"
                  f"Don't say anything about the source of the data.\n"
                  f"If you don't have the data to answer, say this data is not available yet. If the data is not available in the stock history data, say this was a weekend and there is no data for it.")
        response = invoke_llm(prompt)
        logger.info("answer_question_with_data took %.2f seconds", time.time() - start_time)
        return response.content.strip()
    except Exception as e:
        logger.error(f"Error answering question with data: {e}")
        return "Error answering question."


def calculate_metrics(stock_data: pd.DataFrame, summarizer: DataSummarizer, company_info: Dict) -> Dict[str, str]:
    start_time = time.time()
    try:
        moving_average = summarizer.calculate_moving_average(stock_data)
        rsi = summarizer.calculate_rsi(stock_data)
        ema = summarizer.calculate_ema(stock_data)
        bollinger_bands = summarizer.calculate_bollinger_bands(stock_data)
        macd = summarizer.calculate_macd(stock_data)
        volatility = summarizer.calculate_volatility(stock_data)
        atr = summarizer.calculate_atr(stock_data)
        obv = summarizer.calculate_obv(stock_data)
        yearly_summary = summarizer.calculate_yearly_summary(stock_data)
        ytd_performance = summarizer.calculate_ytd_performance(stock_data)

        formatted_metrics = {
            "Moving Average": moving_average.to_string(),
            "RSI": rsi.to_string(),
            "EMA": ema.to_string(),
            "Bollinger Bands": bollinger_bands.to_string(),
            "MACD": macd.to_string(),
            "Volatility": volatility.to_string(),
            "ATR": atr.to_string(),
            "OBV": obv.to_string(),
            "Yearly Summary": yearly_summary.to_string(),
            "YTD Performance": f"{ytd_performance:.2f}%"
        }

        # The P/E ratio is only included when earnings per share are available.
        eps = company_info.get('trailingEps', None)
        if eps:
            current_price = stock_data.iloc[-1]['close']
            pe_ratio = summarizer.calculate_pe_ratio(current_price, eps)
            formatted_metrics["P/E Ratio"] = f"{pe_ratio:.2f}"

        logger.info("calculate_metrics took %.2f seconds", time.time() - start_time)
        return formatted_metrics
    except Exception as e:
        logger.error(f"Error calculating metrics: {e}")
        return {"Error": "Error calculating metrics"}


def prepare_data(formatted_stock_data: str, formatted_company_info: str, formatted_company_news: str,
                 google_results: str, formatted_metrics: Dict[str, str], google_snippet: str,
                 rag_response: str) -> Dict[str, str]:
    start_time = time.time()
    collected_data = {
        "Formatted Stock Data": formatted_stock_data,
        "Formatted Company Info": formatted_company_info,
        "Formatted Company News": formatted_company_news,
        "Google Search Results": google_results,
        "Google Snippet": google_snippet,
        "RAG Response": rag_response,
        "Calculations": formatted_metrics
    }
    collected_data.update(formatted_metrics)
    logger.info("prepare_data took %.2f seconds", time.time() - start_time)
    return collected_data


@app.route('/ask', methods=['POST'])
def ask():
    try:
        user_input = request.json.get('question')
        logger.info(f"Received question: {user_input}")

        summarizer = DataSummarizer()

        # Detect language, entity, translation, and stock ticker
        language, entity, translation, stock_ticker = detect_translate_entity_and_ticker(user_input)

        logger.info(f"Detected Language: {language}, Entity: {entity}, Translation: {translation}, Stock Ticker: {stock_ticker}")

        if entity and stock_ticker:
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(fetch_stock_data_yahoo, stock_ticker): "stock_data",
                    executor.submit(fetch_company_info_yahoo, stock_ticker): "company_info",
                    executor.submit(fetch_company_news_yahoo, stock_ticker): "company_news",
                    executor.submit(fetch_current_stock_price, stock_ticker): "current_stock_price",
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}

            stock_data = results.get("stock_data", pd.DataFrame())
            formatted_stock_data = format_stock_data_for_gemini(stock_data) if not stock_data.empty else "No historical data available."

            company_info = results.get("company_info", {})
            formatted_company_info = format_company_info_for_gemini(company_info) if company_info else "No company info available."

            company_news = results.get("company_news", [])
            formatted_company_news = format_company_news_for_gemini(company_news) if company_news else "No news available."

            current_stock_price = results.get("current_stock_price", None)

            formatted_metrics = calculate_metrics(stock_data, summarizer, company_info) if not stock_data.empty else {"Error": "No stock data for metrics"}

            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")

            collected_data = prepare_data(formatted_stock_data, formatted_company_info, formatted_company_news,
                                          google_results, formatted_metrics, google_snippet, rag_response)
            collected_data["Current Stock Price"] = f"${current_stock_price:.2f}" if current_stock_price is not None else "N/A"

            answer = answer_question_with_data(f"{translation}", collected_data)

            return jsonify({"answer": answer})

        else:
            with ThreadPoolExecutor() as executor:
                futures = {
                    executor.submit(get_answer, user_input): "rag_response",
                    executor.submit(summarizer.google_search, user_input): "google_results",
                    executor.submit(summarizer.fetch_google_snippet, user_input): "google_snippet"
                }
                results = {futures[future]: future.result() for future in as_completed(futures)}

            google_results = results.get("google_results", "No additional news found through Google Search.")
            google_snippet = results.get("google_snippet", "Snippet not found.")
            rag_response = results.get("rag_response", "No response from RAG.")

            collected_data = prepare_data("", "", "", google_results, {}, google_snippet, rag_response)

            answer = answer_question_with_data(f"{user_input}", collected_data)

            return jsonify({"answer": answer})

    except Exception as e:
        logger.error(f"An error occurred: {e}")
        return jsonify({"error": "An error occurred while processing your request. Please try again later."}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
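Once chatflask.py is running, the /ask route can be exercised directly; a minimal sketch (the question text is illustrative):

import requests

# Route, port, request key, and response key all come from chatflask.py above.
resp = requests.post("http://localhost:5000/ask",
                     json={"question": "What is Apple's current stock price?"})
print(resp.json()["answer"])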
config.py
ADDED
@@ -0,0 +1,18 @@
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

class Config:
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
    ALPHA_VANTAGE_KEY = os.getenv("ALPHA_VANTAGE_KEY")
    YAHOO_FINANCE_API_KEY = os.getenv("YAHOO_FINANCE_API_KEY")
    FINNHUB_API_KEY = os.getenv("FINNHUB_API_KEY")
    POLYGON_API_KEY = os.getenv("POLYGON_API_KEY")
    SECRET_KEY = os.getenv("SECRET_KEY")
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    SEARCH_ENGINE_ID = os.getenv("SEARCH_ENGINE_ID")
    # Add any additional configuration variables here
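For local runs, these settings would typically be supplied via a .env file next to config.py; a sketch, where every value below is a placeholder, not a real key:

GEMINI_API_KEY=your-gemini-key
ALPHA_VANTAGE_KEY=your-alpha-vantage-key
YAHOO_FINANCE_API_KEY=your-yahoo-finance-key
FINNHUB_API_KEY=your-finnhub-key
POLYGON_API_KEY=your-polygon-key
SECRET_KEY=your-secret-key
GOOGLE_API_KEY=your-google-key
SEARCH_ENGINE_ID=your-search-engine-id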
embeddings.py
ADDED
@@ -0,0 +1,62 @@
import google.generativeai as genai
from dotenv import load_dotenv
import os
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_cohere import CohereEmbeddings
from langchain_openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings

load_dotenv()

class Embeddings:

    '''
    google, models/embedding-001
    openai, openai
    cohere, cohere
    hf, all-MiniLM-L6-v2
    hf, BAAI/bge-large-en-v1.5
    hf, Alibaba-NLP/gte-large-en-v1.5, True
    ...
    ...
    '''

    def __init__(self, emb, model, trust_remote=False, normalize=False):
        self.emb = emb
        self.model = model
        self.trust_remote = trust_remote
        self.normalize = normalize
        self.embedding = self.get_embedding()
        self.seq_len = self.get_emb_len()

    def get_emb_len(self):
        return len(self.embedding.embed_query('hi how are you'))

    def google_embedding(self):
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        embeddings = GoogleGenerativeAIEmbeddings(model=self.model)
        return embeddings

    def openai_embedding(self):
        embeddings_model = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))
        return embeddings_model

    def cohere_embedding(self):
        embeddings_model = CohereEmbeddings(cohere_api_key=os.getenv("COHERE_API_KEY"))
        return embeddings_model

    def hf_embedding(self):
        model_args = {'trust_remote_code': True} if self.trust_remote else {}
        encode_args = {'normalize_embeddings': True} if self.normalize else {}
        embedding = HuggingFaceEmbeddings(model_name=self.model, model_kwargs=model_args, encode_kwargs=encode_args)
        return embedding

    def get_embedding(self):
        if self.emb == 'google':
            return self.google_embedding()
        elif self.emb == 'openai':
            return self.openai_embedding()
        elif self.emb == 'cohere':
            return self.cohere_embedding()
        elif self.emb == 'hf':
            return self.hf_embedding()
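A minimal usage sketch for the 'hf' backend, assuming sentence-transformers is installed (the query text is illustrative):

from embeddings import Embeddings

emb = Embeddings('hf', 'all-MiniLM-L6-v2')  # backend/model pair taken from the class docstring
vector = emb.embedding.embed_query("What drives Apple's stock price?")
print(emb.seq_len)  # dimensionality of the embedding vector (384 for all-MiniLM-L6-v2)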
flasktest.py
ADDED
@@ -0,0 +1,49 @@
import streamlit as st
import requests
import json


def send_question_to_api(question):
    url = 'http://localhost:5000/ask'
    headers = {'Content-Type': 'application/json'}
    data = {'question': question}

    response = requests.post(url, headers=headers, data=json.dumps(data))

    if response.status_code == 200:
        return response.json().get('answer')
    else:
        return f"Error: {response.status_code} - {response.text}"


def main():
    st.title("Financial Data Chatbot Tester")

    st.write("Enter your question below and get a response from the chatbot.")

    # Initialize session state to store question history
    if 'history' not in st.session_state:
        st.session_state.history = []

    user_input = st.text_input("Your question:", "")

    if st.button("Submit"):
        if user_input:
            with st.spinner('Getting the answer...'):
                answer = send_question_to_api(user_input)
                st.session_state.history.append((user_input, answer))
                st.success(answer)
        else:
            st.warning("Please enter a question before submitting.")

    # Display the history of questions and answers
    if st.session_state.history:
        st.write("### History")
        for idx, (question, answer) in enumerate(st.session_state.history, 1):
            st.write(f"**Q{idx}:** {question}")
            st.write(f"**A{idx}:** {answer}")
            st.write("---")


if __name__ == '__main__':
    main()
index.html
ADDED
@@ -0,0 +1,70 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Chatbot</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            background-color: #f4f4f9;
            margin: 40px;
            text-align: center;
        }
        input[type="text"] {
            width: 300px;
            padding: 10px;
            font-size: 16px;
            margin-top: 20px;
            border: 2px solid #ccc;
            border-radius: 5px;
        }
        button {
            background-color: #4CAF50;
            color: white;
            padding: 10px 20px;
            margin-top: 10px;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background-color: #45a049;
        }
        p {
            margin-top: 20px;
            font-size: 18px;
            color: #333;
        }
    </style>
</head>
<body>
    <h1>Chatbot Interface</h1>
    <input type="text" id="question" placeholder="Ask a question...">
    <button onclick="askQuestion()">Ask</button>
    <p id="answer">Answer will appear here...</p>

    <script>
        async function askQuestion() {
            const questionInput = document.getElementById('question');
            const answerDisplay = document.getElementById('answer');
            const question = questionInput.value;

            const response = await fetch('/chat/', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ question: question })
            });
            if (response.ok) {
                const data = await response.json();
                answerDisplay.textContent = 'Answer: ' + data.answer;
            } else {
                answerDisplay.textContent = 'Error: Unable to fetch answer.';
            }
        }
    </script>
</body>
</html>
llm.py
ADDED
@@ -0,0 +1,45 @@
1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
2 |
+
import google.generativeai as genai
|
3 |
+
from langchain.chat_models import ChatOpenAI
|
4 |
+
from langchain_groq import ChatGroq
|
5 |
+
import vertexai
|
6 |
+
from langchain_google_vertexai import ChatVertexAI
|
7 |
+
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
import os
|
10 |
+
|
11 |
+
load_dotenv()
|
12 |
+
|
13 |
+
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
|
14 |
+
|
15 |
+
class LLM:
|
16 |
+
def __init__(self, llm, model=None):
|
17 |
+
if llm == 'gemini':
|
18 |
+
if model is None:
|
19 |
+
model = "gemini-pro"
|
20 |
+
self.llm = ChatGoogleGenerativeAI(model=model, temperature=0.3)
|
21 |
+
elif llm == 'vertex':
|
22 |
+
vertexai.init(project="website-254017", location="us-central1")
|
23 |
+
if model is None:
|
24 |
+
model = "gemini-1.5-pro-preview-0514"
|
25 |
+
self.llm = ChatVertexAI(model_name=model, temperature=0, max_tokens=8000)
|
26 |
+
elif llm == 'openai':
|
27 |
+
if model is None:
|
28 |
+
model = 'gpt-3.5-turbo-0125'
|
29 |
+
# ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0125")
|
30 |
+
self.llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model=model)
|
31 |
+
|
32 |
+
elif llm == 'mixtral':
|
33 |
+
model = "mixtral-8x7b-32768"
|
34 |
+
self.llm = ChatGroq(temperature=0, groq_api_key=os.getenv("GROK_API_KEY"), model_name=model)
|
35 |
+
|
36 |
+
elif llm == 'llama':
|
37 |
+
if model is None:
|
38 |
+
model = 'llama3-8b-8192'
|
39 |
+
self.llm = ChatGroq(temperature=0, groq_api_key=os.getenv("GROK_API_KEY"), model_name=model)
|
40 |
+
|
41 |
+
def get_llm(self):
|
42 |
+
return self.llm
|
43 |
+
|
44 |
+
|
45 |
+
|
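Note: a minimal usage sketch of the factory above; the model names mirror the class defaults, and the relevant API key must be present in .env for the chosen backend:

# Usage sketch; requires GOOGLE_API_KEY in .env for the 'gemini' backend.
from llm import LLM

gemini = LLM('gemini').get_llm()
print(gemini.invoke("Summarize the 60/40 portfolio in one sentence.").content)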
logging_config.py
ADDED
@@ -0,0 +1,38 @@
import logging.config

def setup_logging():
    logging_config = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            },
        },
        'handlers': {
            'console': {
                'level': 'DEBUG',
                'class': 'logging.StreamHandler',
                'formatter': 'standard',
            },
            'file': {
                'level': 'DEBUG',
                'class': 'logging.FileHandler',
                'filename': 'financial_adviser.log',
                'formatter': 'standard',
            },
        },
        'loggers': {
            '': {
                'handlers': ['console', 'file'],
                'level': 'DEBUG',
                'propagate': True,
            },
        },
    }

    logging.config.dictConfig(logging_config)

# Initialize the logger
setup_logging()
logger = logging.getLogger(__name__)
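Note: because setup_logging() runs at import time, importing the module is enough to configure both the console and file handlers:

# Usage sketch: importing logging_config configures console + file logging.
from logging_config import logger

logger.info("retriever initialized")
logger.debug("full prompt: %s", "...")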
main.py
ADDED
@@ -0,0 +1,100 @@
import streamlit as st
import os
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from data_extraction import Extraction
import nest_asyncio
from chunking import Chunker
from embeddings import Embeddings
from vectorstore import VectorDB
from retriever import Retriever, CreateBM25Retriever
from llm import LLM
from langchain_core.prompts import PromptTemplate
from chain import Chain
from streamlit_chat import message

if 'responses' not in st.session_state:
    st.session_state['responses'] = ["How can I assist you?"]

if 'requests' not in st.session_state:
    st.session_state['requests'] = []

if 'buffer_memory' not in st.session_state:
    st.session_state.buffer_memory = ConversationBufferWindowMemory(k=3, return_messages=True)

nest_asyncio.apply()
ext = Extraction('fast')
chnk = Chunker(chunk_size=1000, chunk_overlap=200)
emb = Embeddings("hf", "all-MiniLM-L6-v2")
_llm = LLM('vertex').get_llm()
ch = Chain(_llm, st.session_state.buffer_memory)
conversation = ch.get_chain_with_history()

def query_refiner(conversation, query):
    prompt = (
        "Given the following user query and historical user queries, rephrase the user's "
        "current query to form a meaningful and clear question. Previously the user has "
        f"asked the following:\n{conversation}\n\nUser's Current Query: {query}. "
        "What will be the refined query? Only provide the query without any extra details or explanations."
    )
    ans = _llm.invoke(prompt).content
    return ans

def get_conversation_string():
    conversation_string = ""
    for i in range(len(st.session_state['responses']) - 1):
        conversation_string += "Human: " + st.session_state['requests'][i] + "\n"
        # conversation_string += "Bot: " + st.session_state['responses'][i+1] + "\n"
    return conversation_string

def main():
    inp_dir = "./inputs"
    db = 'pinecone'
    db_dir = 'pineconedb'
    st.set_page_config("Chat PDF")
    st.header("Chat with PDF")

    response_container = st.container()
    textcontainer = st.container()
    ret = None
    with textcontainer:
        query = st.text_input("Query: ", key="input")
        if query:
            if ret is None:
                ret = Retriever(db, db_dir, emb.embedding, 'ensemble', 5)
            with st.spinner("typing..."):
                conversation_string = get_conversation_string()
                if len(st.session_state['responses']) != 0:
                    refined_query = query_refiner(conversation_string, query)
                else:
                    refined_query = query
                st.subheader("Refined Query:")
                st.write(refined_query)
                context, context_list = ret.get_context(refined_query)
                response = conversation.predict(input=f"Context:\n {context} \n\n Query:\n{query}")
                # response += '\n' + "Source: " + src
            st.session_state.requests.append(query)
            st.session_state.responses.append(response)

    with response_container:
        if st.session_state['responses']:
            for i in range(len(st.session_state['responses'])):
                message(st.session_state['responses'][i], key=str(i))
                if i < len(st.session_state['requests']):
                    message(st.session_state["requests"][i], is_user=True, key=str(i) + '_user')

    with st.sidebar:
        st.title("Menu:")
        pdf_docs = st.file_uploader("Upload your PDF Files and Click on the Submit & Process Button", accept_multiple_files=True)
        pdfs = []
        if pdf_docs:
            for pdf_file in pdf_docs:
                filename = pdf_file.name
                path = os.path.join(inp_dir, filename)
                with open(path, "wb") as f:
                    f.write(pdf_file.getvalue())
                pdfs.append(path)

            with st.spinner("Processing..."):
                texts, metas = ext.get_text(pdfs)
                docs = chnk.get_chunks(texts, metas)
                vs = VectorDB(db, emb.embedding, db_dir, docs=docs)
                bm = CreateBM25Retriever(docs)
                st.success("Done")

if __name__ == "__main__":
    main()
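Note: Chain comes from chain.py, which is part of this upload but not shown in this excerpt. A hypothetical sketch of what it plausibly wraps, covering both call sites seen here (conversation.predict in main.py and run_conversational_chain in rag.py); the exact implementation may differ:

# Hypothetical sketch of chain.py's Chain; the real file may differ.
from langchain.chains import ConversationChain

class Chain:
    def __init__(self, llm, memory):
        self.llm = llm
        self.memory = memory

    def get_chain_with_history(self):
        # Matches conversation.predict(input=...) as called in main().
        return ConversationChain(llm=self.llm, memory=self.memory)

    def run_conversational_chain(self, prompt):
        # Matches ch.run_conversational_chain(prompt_template) as called in rag.py.
        return self.llm.invoke(prompt).content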
rag.py
ADDED
@@ -0,0 +1,114 @@
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from embeddings import Embeddings
from chain import Chain
from llm import LLM
from retriever import Retriever
from functools import lru_cache
from tools import *
import re

emb = Embeddings("hf", "all-MiniLM-L6-v2")
llm = LLM('gemini').get_llm()
ch = Chain(llm, None)
ret = Retriever('pinecone', 'pinecone', emb.embedding, 'ensemble', 5)

is_arabic = False

@lru_cache()
def investment_banker(query):
    global is_arabic
    context, context_list = ret.get_context(query)
    if not is_arabic:
        prompt_template = f"""
        You are an investment banker and financial advisor.
        Answer the question as detailed as possible from the provided context and make sure to provide all the details.
        Answer only from the context. If the answer is not in the provided context, say "Answer not in context".\n\n
        Context:\n {context}\n\n
        Question: \n{query}\n

        Answer:
        """
    else:
        prompt_template = f"""
        You are an investment banker and financial advisor.
        Answer the question as detailed as possible from the provided context and make sure to provide all the details.
        Answer only from the context. If the answer is not in the provided context, say "Answer not in context".
        Return the answer in Arabic only.\n\n
        Context:\n {context}\n\n
        Question: \n{query}\n

        Answer:
        """
    response = ch.run_conversational_chain(prompt_template)
    is_arabic = False
    return response

def check_arabic(s):
    # True if the string contains any character from the Arabic Unicode block.
    arabic_pattern = re.compile(r'[\u0600-\u06FF]')
    return bool(arabic_pattern.search(s))

history = ""

@lru_cache()
def refine_query(query, conversation):
    prompt = f"""Given the following user query and historical user conversation with the banker.
    If the current user query is in Arabic, convert it to English and then proceed.
    If the conversation history is empty, return the current query as it is.
    If the query is a continuation of the previous conversation, then only rephrase the user's current query to form a meaningful and clear question.
    Otherwise return the user query as it is.
    Previously the user and banker had the following conversation: \n{conversation}\n\n User's Current Query: {query}.
    What will be the refined query? Only provide the query without any extra details or explanations."""
    ans = llm.invoke(prompt).content
    return ans


def get_answer(query):
    global history
    global is_arabic

    is_arabic = check_arabic(query)
    ref_query = refine_query(query, history)
    ans = investment_banker(ref_query)
    history += "Human: " + ref_query + "\n"
    history += "Banker: " + ans + "\n"

    return ans

if __name__ == "__main__":
    # get_answer requires a query; use a sample question for a quick smoke test.
    response = get_answer("What is a balanced portfolio?")
    print(response)

# app = FastAPI()

# class Query(BaseModel):
#     question: str

# @app.post("/chat/")
# async def chat(query: Query):
#     global history
#     global is_arabic

#     try:
#         is_arabic = check_arabic(query.question)
#         ref_query = refine_query(query.question, history)

#         print(query.question, ref_query)
#         print(is_arabic)
#         ans = investment_banker(ref_query)
#         history += "Human: " + ref_query + "\n"
#         history += "Banker: " + ans + "\n"
#         return {"question": query.question, "answer": ans}
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=str(e))


# @app.get("/", response_class=HTMLResponse)
# async def read_index():
#     with open('index.html', 'r') as f:
#         return HTMLResponse(content=f.read())
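Note: with the commented-out FastAPI section re-enabled, this module serves both index.html and the /chat/ endpoint that the page's fetch call targets. A minimal launch sketch; host and port are assumptions, and it presumes the `app` above has been uncommented:

# Launch sketch, assuming the FastAPI `app` in rag.py is uncommented.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("rag:app", host="0.0.0.0", port=8000)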
requirements.txt
ADDED
@@ -0,0 +1,30 @@
pydantic
langchain
yfinance
langchain_google_genai
langchain_openai
langchain_cohere
google-generativeai
langchain_groq
python-dotenv
vertexai
langchain_pinecone
qdrant_client
uvicorn
fastapi          # imported by rag.py
langchain-community
langchain_google_vertexai
sentence-transformers
rank_bm25
matplotlib
pandas
numpy
requests
spacy
transformers
torch
sentencepiece
streamlit
streamlit-chat   # imported by main.py as streamlit_chat
nest_asyncio     # imported by main.py
flask
bs4
tenacity
loguru
retriever.py
ADDED
@@ -0,0 +1,53 @@
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import FAISS, Chroma, Qdrant
from qdrant_client import QdrantClient
from langchain_pinecone import PineconeVectorStore
import os
from dotenv import load_dotenv
import pickle

load_dotenv()

class CreateBM25Retriever:
    """Build a BM25 retriever from the chunked documents and persist it to disk."""
    def __init__(self, docs):
        self.bm25_retriever = BM25Retriever.from_documents(docs)
        with open('bm25retriever.pkl', 'wb') as outp:
            pickle.dump(self.bm25_retriever, outp, pickle.HIGHEST_PROTOCOL)

class Retriever:
    """Wrap a vector store as a retriever; optionally ensemble it with BM25."""
    def __init__(self, db, per_dir, embeddings, strategy, k, collection_name="mydocuments"):
        self.db = db
        self.strategy = strategy
        self.per_dir = per_dir
        if self.db == 'faiss':
            self.db_ = FAISS.load_local(self.per_dir, embeddings, allow_dangerous_deserialization=True)
        elif self.db == 'chroma':
            self.db_ = Chroma(persist_directory=self.per_dir, embedding_function=embeddings)
        elif self.db == 'qdrant':
            self.db_ = Qdrant(client=QdrantClient(path=self.per_dir), collection_name=collection_name, embeddings=embeddings)
        elif self.db == 'pinecone':
            self.db_ = PineconeVectorStore(pinecone_api_key=os.getenv("PINECONE_API_KEY"), index_name=collection_name, embedding=embeddings)
        self.retriever = self.db_.as_retriever(search_kwargs={"k": k})

        if strategy == 'ensemble':
            # Combine the persisted BM25 retriever with the vector retriever (40/60 weighting).
            with open('bm25retriever.pkl', 'rb') as inp:
                self.bm25_retriever = pickle.load(inp)
            self.bm25_retriever.k = k
            self.retriever = EnsembleRetriever(retrievers=[self.bm25_retriever, self.retriever],
                                               weights=[0.4, 0.6])

    def get_docs(self, query):
        return self.retriever.get_relevant_documents(query)

    def get_context(self, query):
        docs = self.get_docs(query)
        context = ""
        context_list = []
        # src = []
        for txt in docs:
            context += '\n\n' + txt.page_content + "\n" + "Source: " + txt.metadata['source']
            context_list.append(txt.page_content)
            # src.append(txt.metadata['source'])
        # src = max(set(src), key=src.count)
        return context, context_list
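Note: a usage sketch tying the two classes together, mirroring the setup in main.py; it assumes bm25retriever.pkl already exists (from a prior CreateBM25Retriever run) and PINECONE_API_KEY is set:

# Usage sketch: query the Pinecone + BM25 ensemble built above.
from embeddings import Embeddings
from retriever import Retriever

emb = Embeddings("hf", "all-MiniLM-L6-v2")
ret = Retriever('pinecone', 'pinecone', emb.embedding, 'ensemble', k=5)
context, chunks = ret.get_context("What was last quarter's revenue?")
print(context[:500])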
tools.py
ADDED
@@ -0,0 +1,188 @@
from pydantic import BaseModel, Field
from datetime import datetime, timedelta
import yfinance as yf
from langchain.prompts import MessagesPlaceholder, ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate, AIMessagePromptTemplate
from langchain.tools import BaseTool
from typing import Optional, Type
from typing import List
from functools import lru_cache


@lru_cache()
def get_stock_price(symbol):
    ticker = yf.Ticker(symbol)
    todays_data = ticker.history(period='1d')
    # Positional indexing: the history frame is indexed by date, not integers.
    price = round(todays_data['Close'].iloc[0], 2)
    currency = ticker.info['currency']
    return price, currency

@lru_cache()
def get_stock_data_yahoo(ticker):
    stock = yf.Ticker(ticker)
    data = stock.history(period="1y")
    return data

@lru_cache()
def get_company_profile_yahoo(ticker):
    stock = yf.Ticker(ticker)
    info = stock.info
    profile = {
        "name": info.get("shortName"),
        "sector": info.get("sector"),
        "industry": info.get("industry"),
        "marketCap": info.get("marketCap"),
        "website": info.get("website"),
        "description": info.get("longBusinessSummary"),
    }
    return profile

@lru_cache()
def get_company_news_yahoo(ticker):
    stock = yf.Ticker(ticker)
    news = stock.news
    return news

@lru_cache()
def get_price_change_percent(symbol, days_ago):
    ticker = yf.Ticker(symbol)
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days_ago)

    # Convert dates to string format that yfinance can accept
    start_date = start_date.strftime('%Y-%m-%d')
    end_date = end_date.strftime('%Y-%m-%d')

    historical_data = ticker.history(start=start_date, end=end_date)

    old_price = historical_data['Close'].iloc[0]
    new_price = historical_data['Close'].iloc[-1]

    percent_change = ((new_price - old_price) / old_price) * 100
    return round(percent_change, 2)

@lru_cache()
def calculate_performance(symbol, days_ago):
    # Same computation as get_price_change_percent; kept for the multi-stock tool below.
    ticker = yf.Ticker(symbol)
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days_ago)
    start_date = start_date.strftime('%Y-%m-%d')
    end_date = end_date.strftime('%Y-%m-%d')
    historical_data = ticker.history(start=start_date, end=end_date)
    old_price = historical_data['Close'].iloc[0]
    new_price = historical_data['Close'].iloc[-1]
    percent_change = ((new_price - old_price) / old_price) * 100
    return round(percent_change, 2)

@lru_cache()
def get_best_performing(stocks, days_ago):
    # `stocks` must be hashable (e.g. a tuple, not a list) for lru_cache to work.
    best_stock = None
    best_performance = None
    for stock in stocks:
        try:
            performance = calculate_performance(stock, days_ago)
            if best_performance is None or performance > best_performance:
                best_stock = stock
                best_performance = performance
        except Exception as e:
            print(f"Could not calculate performance for {stock}: {e}")
    return best_stock, best_performance

class StockPriceCheckInput(BaseModel):
    """Input for Stock price check."""

    stockticker: str = Field(..., description="Ticker symbol for stock or index")

class StockPriceTool(BaseTool):
    name = "get_stock_ticker_price"
    description = "Useful for when you need to find out the price of the stock today. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        price_response, currency = get_stock_price(stockticker)
        return f"{currency} {price_response}"

    def _arun(self, stockticker: str):
        raise NotImplementedError("This tool does not support async")

    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput

class PrevYearStockTool(BaseTool):
    name = "get_past_year_stock_data"
    description = "Useful for when you need to find out the past 1 year performance of a stock. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        price_response = get_stock_data_yahoo(stockticker)
        return price_response

    def _arun(self, stockticker: str):
        raise NotImplementedError("This tool does not support async")

    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput

class StockNewsTool(BaseTool):
    name = "get_news_about_stock"
    description = "Useful for when you need recent news related to a stock. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        price_response = get_company_news_yahoo(stockticker)
        return price_response

    def _arun(self, stockticker: str):
        raise NotImplementedError("This tool does not support async")

    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput

class StockProfileTool(BaseTool):
    name = "get_profile_of_stock"
    description = "Useful for when you need details or profile of a stock. You should input the stock ticker used on the yfinance API"

    def _run(self, stockticker: str):
        price_response = get_company_profile_yahoo(stockticker)
        return price_response

    def _arun(self, stockticker: str):
        raise NotImplementedError("This tool does not support async")

    args_schema: Optional[Type[BaseModel]] = StockPriceCheckInput

class StockChangePercentageCheckInput(BaseModel):
    """Input for Stock ticker check. for percentage check"""

    stockticker: str = Field(..., description="Ticker symbol for stock or index")
    days_ago: int = Field(..., description="Int number of days to look back")

class StockPercentageChangeTool(BaseTool):
    name = "get_price_change_percent"
    description = "Useful for when you need to find out the performance or percentage change in a stock's value. You should input the stock ticker used on the yfinance API and also input the number of days to check the change over"

    def _run(self, stockticker: str, days_ago: int):
        price_change_response = get_price_change_percent(stockticker, days_ago)
        return price_change_response

    def _arun(self, stockticker: str, days_ago: int):
        raise NotImplementedError("This tool does not support async")

    args_schema: Optional[Type[BaseModel]] = StockChangePercentageCheckInput

class StockBestPerformingInput(BaseModel):
    """Input for comparing multiple stock tickers over a period."""

    stocktickers: List[str] = Field(..., description="Ticker symbols for stocks or indices")
    days_ago: int = Field(..., description="Int number of days to look back")

class StockGetBestPerformingTool(BaseTool):
    name = "get_best_performing"
    description = "Useful for when you need to compare the performance of multiple stocks over a period. You should input a list of stock tickers used on the yfinance API and also input the number of days to check the change over"

    def _run(self, stocktickers: List[str], days_ago: int):
        # Convert the list to a tuple so the lru_cache on get_best_performing accepts it.
        price_change_response = get_best_performing(tuple(stocktickers), days_ago)
        return price_change_response

    def _arun(self, stockticker: List[str], days_ago: int):
        raise NotImplementedError("This tool does not support async")

    args_schema: Optional[Type[BaseModel]] = StockBestPerformingInput
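Note: these are plain LangChain BaseTool subclasses, so they can be handed to any tool-calling agent. A minimal wiring sketch using the Gemini model from llm.py; the agent type is an assumption (structured-chat handles the multi-input tools), since the repo's actual agent setup lives elsewhere:

# Hypothetical wiring sketch; the repo's actual agent setup may differ.
from langchain.agents import initialize_agent, AgentType
from llm import LLM
from tools import StockPriceTool, StockPercentageChangeTool, StockGetBestPerformingTool

llm = LLM('gemini').get_llm()
tools = [StockPriceTool(), StockPercentageChangeTool(), StockGetBestPerformingTool()]

agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,  # supports multi-input tools
    verbose=True,
)
print(agent.run("Which performed better over the last 90 days: AAPL or MSFT?"))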