earnliners commited on
Commit
eb44f1d
·
verified ·
1 Parent(s): d858238

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -1
app.py CHANGED
@@ -9,7 +9,234 @@ from sklearn.ensemble import RandomForestClassifier
9
  import shap
10
  import matplotlib.pyplot as plt
11
 
12
- [Previous code remains unchanged until the end of run_analysis() function]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Run analysis automatically on page load
15
  run_analysis()
 
9
  import shap
10
  import matplotlib.pyplot as plt
11
 
12
+ class DataFetcher:
13
+ """Fetches historical financial data using yfinance."""
14
+ def __init__(self, ticker, nb_days):
15
+ self.ticker = ticker
16
+ self.nb_days = nb_days
17
+ self.data = None
18
+
19
+ def fetch_data(self):
20
+ """Fetches historical data for the specified ticker and number of days."""
21
+ end_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
22
+ start_date = end_date - timedelta(days=self.nb_days)
23
+ end_date = end_date + timedelta(days=1)
24
+ self.data = yf.download(self.ticker, start=start_date, end=end_date, interval="1h")
25
+ return self.data
26
+
27
+ class FinancialDataProcessor:
28
+ """Processes financial data to calculate returns, scenarios, and probabilities."""
29
+ def __init__(self, data):
30
+ self.data = data.copy()
31
+
32
+ def _flatten_columns(self):
33
+ """Flattens MultiIndex columns into a single level."""
34
+ if isinstance(self.data.columns, pd.MultiIndex):
35
+ self.data.columns = [f"{col[0]}_{col[1]}" if col[1] else col[0] for col in self.data.columns]
36
+
37
+ def calculate_returns(self):
38
+ """Calculates logarithmic returns, scenarios, and adjusted returns."""
39
+ self._flatten_columns()
40
+
41
+ close_column = [col for col in self.data.columns if 'Close' in col]
42
+ if not close_column:
43
+ raise ValueError("The 'Close' column is missing in the dataset.")
44
+
45
+ self.data.rename(columns={close_column[0]: 'Close'}, inplace=True)
46
+ self.data = self.data[self.data['Close'] > 0].copy()
47
+
48
+ self.data['LogReturn'] = np.log(self.data['Close'] / self.data['Close'].shift(1))
49
+ self.data.replace([np.inf, -np.inf], np.nan, inplace=True)
50
+ self.data.dropna(subset=['LogReturn'], inplace=True)
51
+
52
+ self.data['Scenario'] = np.where(self.data['LogReturn'] > 0, 'Buy', 'Sell')
53
+ self.data['AdjustedLogReturn'] = np.where(
54
+ self.data['Scenario'] == 'Sell', -self.data['LogReturn'], self.data['LogReturn']
55
+ )
56
+ self.data['AnnualizedReturn'] = self.data['AdjustedLogReturn'] * 252
57
+
58
+ return self
59
+
60
+ def calculate_probabilities(self):
61
+ """Calculates Buy% and Sell% using hyperbolic tangent."""
62
+ self.data['Buy%'] = (1 + np.tanh(self.data['LogReturn'])) / 2
63
+ self.data['Sell%'] = (1 - np.tanh(self.data['LogReturn'])) / 2
64
+ return self.data
65
+
66
+ def apply_pca_calculations(self, pca_result):
67
+ """Applies PCA-based calculations to the data."""
68
+ pca_result = pca_result[pca_result['PC1'] > 0].copy()
69
+
70
+ pca_result['PCA_LogReturn'] = np.log(pca_result['PC1'] / pca_result['PC1'].shift(1))
71
+ pca_result.replace([np.inf, -np.inf], np.nan, inplace=True)
72
+ pca_result.dropna(subset=['PCA_LogReturn'], inplace=True)
73
+
74
+ pca_result['PCA_Scenario'] = np.where(pca_result['PCA_LogReturn'] > 0, 'Buy', 'Sell')
75
+ pca_result['PCA_Buy%'] = (1 + np.tanh(pca_result['PCA_LogReturn'])) / 2
76
+ pca_result['PCA_Sell%'] = (1 - np.tanh(pca_result['PCA_LogReturn'])) / 2
77
+
78
+ self.data = self.data.merge(pca_result, left_index=True, right_index=True)
79
+ return self.data
80
+
81
+ class PCATransformer:
82
+ """Applies PCA to reduce dimensionality and extract features."""
83
+ def __init__(self, n_components=1):
84
+ self.n_components = n_components
85
+ self.scaler = StandardScaler()
86
+ self.pca = PCA(n_components=n_components)
87
+
88
+ def fit_transform(self, data):
89
+ numeric_data = data.select_dtypes(include=[np.number])
90
+ scaled_data = self.scaler.fit_transform(numeric_data)
91
+ pca_result = self.pca.fit_transform(scaled_data)
92
+ return pd.DataFrame(pca_result, columns=[f'PC{i+1}' for i in range(self.n_components)], index=data.index)
93
+
94
+ class StrategyBuilder:
95
+ """Builds and refines the trading strategy using machine learning and SHAP."""
96
+ def __init__(self, data):
97
+ self.data = data.copy()
98
+
99
+ def train_model(self, target_column):
100
+ X = self.data.select_dtypes(include=[np.number])
101
+ y = self.data[target_column]
102
+ y_encoded = LabelEncoder().fit_transform(y)
103
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
104
+ model.fit(X, y_encoded)
105
+ return model, X, y_encoded
106
+
107
+ def compute_shapley_values(self, model, X):
108
+ explainer = shap.TreeExplainer(model)
109
+ return explainer.shap_values(X)
110
+
111
+ def analyze_feature_importance(self, shap_values, feature_names):
112
+ """Analyzes feature importance based on SHAP values."""
113
+ if isinstance(shap_values, list):
114
+ shap_values = shap_values[1]
115
+
116
+ if len(shap_values.shape) == 3:
117
+ shap_values = shap_values[:, :, 1]
118
+
119
+ mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
120
+
121
+ if len(mean_abs_shap) != len(feature_names):
122
+ raise ValueError("Mismatch between SHAP values and feature names.")
123
+
124
+ feature_importance = pd.DataFrame({
125
+ 'Feature': feature_names,
126
+ 'Mean_Abs_SHAP': mean_abs_shap
127
+ }).sort_values(by='Mean_Abs_SHAP', ascending=False)
128
+
129
+ return feature_importance
130
+
131
+ def refine_thresholds(self, feature_importance, buy_threshold=0.5, sell_threshold=0.5):
132
+ top_features = feature_importance.head(3)['Feature'].tolist()
133
+ for feature in top_features:
134
+ if 'Buy%' in feature or 'PCA_Buy%' in feature:
135
+ buy_threshold *= 1.1
136
+ elif 'Sell%' in feature or 'PCA_Sell%' in feature:
137
+ sell_threshold *= 1.1
138
+ return buy_threshold, sell_threshold
139
+
140
+ class Backtester:
141
+ """Backtests the trading strategy on historical data."""
142
+ def __init__(self, data):
143
+ self.data = data.copy()
144
+
145
+ def backtest(self, buy_threshold=0.5, sell_threshold=0.5):
146
+ portfolio_value = 10000
147
+ position = None
148
+ entry_price = None
149
+ portfolio_values = []
150
+
151
+ for i in range(1, len(self.data)):
152
+ last_row = self.data.iloc[i]
153
+ if (last_row['PCA_Scenario'] == 'Buy' and last_row['PCA_Buy%'] > buy_threshold) or \
154
+ (last_row['Scenario'] == 'Buy' and last_row['Buy%'] > buy_threshold):
155
+ if position != 'Buy':
156
+ position = 'Buy'
157
+ entry_price = last_row['Close']
158
+ elif (last_row['PCA_Scenario'] == 'Sell' and last_row['PCA_Sell%'] > sell_threshold) or \
159
+ (last_row['Scenario'] == 'Sell' and last_row['Sell%'] > sell_threshold):
160
+ if position != 'Sell':
161
+ position = 'Sell'
162
+ entry_price = last_row['Close']
163
+
164
+ if position == 'Buy':
165
+ portfolio_value *= (last_row['Close'] / entry_price)
166
+ elif position == 'Sell':
167
+ portfolio_value *= (entry_price / last_row['Close'])
168
+
169
+ portfolio_values.append(portfolio_value)
170
+
171
+ return portfolio_values, position, entry_price
172
+
173
+ def run_analysis():
174
+ """Runs the complete trading analysis."""
175
+ try:
176
+ fetcher = DataFetcher(ticker="BTC-USD", nb_days=50)
177
+ data = fetcher.fetch_data()
178
+
179
+ processor = FinancialDataProcessor(data)
180
+ processed_data = processor.calculate_returns().calculate_probabilities()
181
+
182
+ pca_transformer = PCATransformer(n_components=1)
183
+ pca_result = pca_transformer.fit_transform(processed_data)
184
+ processed_data = processor.apply_pca_calculations(pca_result)
185
+
186
+ strategy_builder = StrategyBuilder(processed_data)
187
+ model, X, y_encoded = strategy_builder.train_model(target_column='PCA_Scenario')
188
+ shap_values = strategy_builder.compute_shapley_values(model, X)
189
+
190
+ feature_importance = strategy_builder.analyze_feature_importance(shap_values, X.columns)
191
+ buy_threshold, sell_threshold = strategy_builder.refine_thresholds(feature_importance)
192
+
193
+ backtester = Backtester(processed_data)
194
+ portfolio_values, final_position, entry_price = backtester.backtest(buy_threshold, sell_threshold)
195
+
196
+ last_row = processed_data.iloc[-1]
197
+
198
+ # Display results
199
+ col1, col2 = st.columns(2)
200
+
201
+ with col1:
202
+ st.subheader("Current Position")
203
+ st.metric("Portfolio Value", f"${portfolio_values[-1]:.2f}")
204
+ st.metric("Position", final_position or "No position")
205
+ if final_position:
206
+ st.metric("Entry Price", f"${entry_price:.2f}")
207
+ st.metric("Latest Close", f"${last_row['Close']:.2f}")
208
+
209
+ with col2:
210
+ st.subheader("Decision Metrics")
211
+ st.metric("Buy%", f"{last_row['Buy%']:.4f}")
212
+ st.metric("Sell%", f"{last_row['Sell%']:.4f}")
213
+ st.metric("PCA Buy%", f"{last_row['PCA_Buy%']:.4f}")
214
+ st.metric("PCA Sell%", f"{last_row['PCA_Sell%']:.4f}")
215
+
216
+ # Plot portfolio value
217
+ st.subheader("Portfolio Value Over Time")
218
+ fig, ax = plt.subplots(figsize=(12, 6))
219
+ ax.plot(processed_data.index[1:], portfolio_values, label='Portfolio Value', color='blue')
220
+ ax.set_title('Portfolio Value Over Time (Backtest)')
221
+ ax.set_xlabel('Date')
222
+ ax.set_ylabel('Portfolio Value ($)')
223
+ ax.grid(True)
224
+ ax.legend()
225
+ st.pyplot(fig)
226
+
227
+ # Feature importance
228
+ st.subheader("Feature Importance")
229
+ st.dataframe(feature_importance)
230
+
231
+ return True
232
+
233
+ except Exception as e:
234
+ st.error(f"An error occurred: {str(e)}")
235
+ return False
236
+
237
+ # Page configuration
238
+ st.set_page_config(page_title="Crypto Trading Bot", layout="wide")
239
+ st.title("Crypto Trading Analysis Bot")
240
 
241
  # Run analysis automatically on page load
242
  run_analysis()