# trading-ai / app.py
# Author: rahilv — commit b01dea4 ("Update app.py", verified)
import datetime
import html
import os

import pandas as pd
import requests
import streamlit as st
import torch
import yfinance as yf
from plotly import graph_objs as go
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the Hugging Face model and tokenizer
@st.cache_resource
def load_model():
    """Load the fine-tuned financial-sentiment classifier and its tokenizer.

    Wrapped in ``st.cache_resource`` so the weights are reused across
    Streamlit reruns instead of being reloaded on every interaction.

    Returns:
        A ``(model, tokenizer)`` tuple for the same Hugging Face repository.
    """
    repo_id = "rahilv/financial-sentiment-model-roberta-3"
    classifier = AutoModelForSequenceClassification.from_pretrained(repo_id)
    return classifier, AutoTokenizer.from_pretrained(repo_id)
# Module-level classifier and tokenizer used by classify_sentiment().
sentiment_model, sentiment_tokenizer = load_model()
# Predicted class id -> sentiment name.
# NOTE(review): assumes the model's output ids are ordered
# bullish/bearish/neutral — confirm against the model's config.
LABEL_MAP = dict(enumerate(("bullish", "bearish", "neutral")))
# Function to fetch stock news
def get_stock_news(ticker):
    """Fetch NewsAPI ``/v2/everything`` articles matching *ticker*.

    Args:
        ticker: Search term sent as the NewsAPI ``q`` query parameter.

    Returns:
        The list under the response's ``articles`` key, or ``[]`` when the
        request fails. On failure an error message is shown in the UI.
    """
    # SECURITY: this key is committed to source control. Supply it via the
    # NEWSAPI_KEY environment variable instead (and rotate the committed
    # key); the literal is kept only as a fallback so the app keeps working.
    api_key = os.environ.get("NEWSAPI_KEY", "d651109fae5346cbbb6812912c801e73")
    url = 'https://newsapi.org/v2/everything'
    try:
        # `params` lets requests URL-encode the user-typed ticker, so input
        # containing '&', '#', spaces, etc. cannot corrupt or inject query
        # parameters. The key travels in a header rather than the URL so it
        # never appears in exception messages shown to the user.
        response = requests.get(
            url,
            params={'q': ticker},
            headers={'X-Api-Key': api_key},
            timeout=10,  # never hang the Streamlit script on a stalled socket
        )
    except requests.RequestException as exc:
        st.error(f"Error fetching news: {exc}")
        return []
    if response.status_code == 200:
        return response.json().get('articles', [])
    st.error(f"Error fetching news: {response.status_code}")
    return []
# Function to analyze sentiment
def classify_sentiment(news_title):
    """Return the predicted sentiment label for a single headline.

    Steps:
    - Tokenize the headline, truncated and padded to exactly 128 tokens.
    - Run it through the module-level classifier without gradient tracking.
    - Map the highest-scoring class id to its name via ``LABEL_MAP``.
    """
    encoded = sentiment_tokenizer(
        news_title,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    with torch.no_grad():
        logits = sentiment_model(**encoded).logits
    best_class = logits.argmax(dim=-1).item()
    return LABEL_MAP[best_class]
# Function to fetch stock data from Yahoo Finance
def fetch_stock_data(ticker, days=365):
    """Return price history for *ticker* covering the last *days* days.

    Args:
        ticker: Symbol passed to ``yfinance.Ticker``.
        days: Look-back window in calendar days. The default of 365 keeps
            the original one-year window.

    Returns:
        The DataFrame produced by ``Ticker.history``. It may be empty when
        Yahoo Finance has no data for the symbol.
    """
    today = datetime.date.today()
    start_date = today - datetime.timedelta(days=days)
    # yfinance treats `end` as exclusive, so request up to tomorrow to
    # include today's bar. The newest headlines are usually dated today and
    # would otherwise have no price point to be plotted on.
    end_date = today + datetime.timedelta(days=1)
    stock = yf.Ticker(ticker)
    return stock.history(start=start_date, end=end_date)
def _sentiment_color(sentiment):
    """Return the display color for a sentiment label (gray for neutral/other)."""
    return {'bullish': 'green', 'bearish': 'red'}.get(sentiment, 'gray')


# Streamlit UI
st.title("Stock Analysis and Sentiment App")
# User input for stock ticker symbol. Whitespace is stripped so a blank or
# padded entry does not trigger API calls with a meaningless query.
ticker_symbol = st.text_input("Enter a Stock Ticker Symbol (e.g., AAPL, TSLA, GOOGL):").strip()
if ticker_symbol:
    display_ticker = ticker_symbol.upper()
    st.subheader(f"Analysis for {display_ticker}")
    # Fetch news
    articles = get_stock_news(ticker_symbol)
    if articles:
        # Classify every headline up front, independently of the price data.
        # The news list and recommendation then still work when Yahoo
        # Finance returns nothing for this ticker.
        news_points = []
        for article in articles:
            title = article.get('title')
            if not title:
                # A missing/null title cannot be tokenized; skip the article.
                continue
            date = (article.get('publishedAt') or '')[:10]  # 'YYYY-MM-DD' prefix
            news_points.append({'date': date, 'title': title, 'sentiment': classify_sentiment(title)})

        st.write("## Stock Price Chart")
        # Fetch and plot stock data
        stock_data = fetch_stock_data(ticker_symbol)
        if not stock_data.empty:
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=stock_data.index, y=stock_data['Close'], mode='lines', name='Close Price'))
            # Map calendar date -> (index timestamp, close). Article dates then
            # match trading days without relying on pandas string lookups
            # against a timezone-aware index. Markers are placed on the real
            # timestamp so they line up with the price line.
            close_by_day = {
                ts.date().isoformat(): (ts, close)
                for ts, close in stock_data['Close'].items()
            }
            for point in news_points:
                match = close_by_day.get(point['date'])
                if match is None:
                    # Published on a non-trading day or outside the price window.
                    continue
                ts, close = match
                fig.add_trace(go.Scatter(
                    x=[ts],
                    y=[close],
                    mode='markers',
                    marker=dict(color=_sentiment_color(point['sentiment']), size=10),
                    name=point['sentiment'],
                    hovertext=[point['title']],
                    showlegend=False,  # avoid one "trace N" legend entry per article
                ))
            fig.update_layout(title=f"{display_ticker} Stock Price & News Sentiment", xaxis_title="Date", yaxis_title="Price")
            st.plotly_chart(fig)
        else:
            st.write("No stock data available.")

        st.write("## News Analysis")
        for point in news_points:
            color = _sentiment_color(point['sentiment'])
            # Headlines come from a third-party API and are rendered with
            # unsafe_allow_html, so escape them to prevent HTML/script injection.
            st.markdown(
                f"<div style='border-left: 5px solid {color}; padding: 10px; margin: 10px 0;'>"
                f"<b>{html.escape(point['title'])}</b><br>{html.escape(point['date'])}</div>",
                unsafe_allow_html=True,
            )

        # Recommendation: majority of bullish vs bearish headlines; ties -> hold.
        sentiments = [p['sentiment'] for p in news_points]
        bullish_count = sentiments.count('bullish')
        bearish_count = sentiments.count('bearish')
        if bullish_count > bearish_count:
            recommendation = "buy"
        elif bearish_count > bullish_count:
            recommendation = "sell"
        else:
            recommendation = "hold"
        color_map = {"buy": "green", "sell": "red", "hold": "gray"}
        st.markdown(f"### Recommendation: <span style='color: {color_map[recommendation]}; font-size: 1.5em;'>{recommendation.upper()}</span>", unsafe_allow_html=True)
    else:
        st.write("No news articles found.")