aidevhund commited on
Commit
47f3474
·
verified ·
1 Parent(s): 9016b28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -72
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- from PIL import Image, ImageDraw
5
  import warnings
6
- import io
7
  import torch
8
  import base64
 
9
  import os
10
- import matplotlib.pyplot as plt
11
- from scipy import stats
12
  from sklearn.cluster import DBSCAN
13
  from transformers import (
14
  AutoModelForObjectDetection,
@@ -16,6 +15,8 @@ from transformers import (
16
  pipeline
17
  )
18
  from huggingface_hub import InferenceClient
 
 
19
 
20
  warnings.filterwarnings("ignore")
21
 
@@ -34,6 +35,7 @@ torch.set_num_threads(os.cpu_count() or 8)
34
  # Model configurations
35
  DETECTION_MODEL = "facebook/detr-resnet-50"
36
  LLM_MODEL_NAME = "meta-llama/Meta-Llama-3-70B-Instruct"
 
37
 
38
  # Initialize models
39
  detection_processor = DetrImageProcessor.from_pretrained(DETECTION_MODEL)
@@ -47,6 +49,9 @@ You are a senior cryptocurrency trading analyst with 15 years experience. Analyz
47
  Technical Elements Detected:
48
  {technical_analysis}
49
 
 
 
 
50
  User Query: {question}
51
 
52
  Provide detailed professional analysis covering:
@@ -67,6 +72,43 @@ def adaptive_resize(image):
67
  scale = MAX_SIZE / max(height, width)
68
  return image.resize((int(width*scale), int(height*scale)), Image.LANCZOS)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def enhance_contrast(img):
71
  lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
72
  l, a, b = cv2.split(lab)
@@ -75,47 +117,13 @@ def enhance_contrast(img):
75
  merged = cv2.merge([limg, a, b])
76
  return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
77
 
78
- def detect_lines(image):
79
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
80
- edges = cv2.Canny(gray, *CANNY_THRESHOLDS)
81
- lines = cv2.HoughLinesP(edges, 1, np.pi/180, *HOUGH_PARAMS)
82
- return lines
83
-
84
- def cluster_lines(lines):
85
- if lines is None:
86
- return []
87
- points = lines.reshape(-1, 2)
88
- clustering = DBSCAN(eps=DBSCAN_EPS, min_samples=MIN_SAMPLES).fit(points)
89
- return clustering.labels_
90
-
91
- def calculate_slope(line):
92
- x1, y1, x2, y2 = line
93
- return (y2 - y1) / (x2 - x1) if (x2 - x1) != 0 else np.inf
94
-
95
- def detect_key_levels(image):
96
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
97
- hist = cv2.calcHist([gray], [0], None, [256], [0,256])
98
- peaks, _ = find_peaks(hist.flatten(), distance=10, prominence=50)
99
- return [p for p in peaks if 10 < p < 240]
100
-
101
- def analyze_volume_profile(image):
102
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
103
- return cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]
104
-
105
- def detect_candlesticks(image):
106
- edges = cv2.Canny(image, 50, 150)
107
- contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
108
- candles = []
109
- for cnt in contours:
110
- x,y,w,h = cv2.boundingRect(cnt)
111
- if 5 < w < 50 and 10 < h < 200:
112
- candles.append((x,y,w,h))
113
- return candles
114
-
115
  def detect_chart_elements(image):
116
  image_np = np.array(image)
117
  enhanced = enhance_contrast(image_np)
118
 
 
 
 
119
  # Deep Learning Detection
120
  inputs = detection_processor(images=Image.fromarray(enhanced), return_tensors="pt")
121
  with torch.no_grad():
@@ -137,6 +145,10 @@ def detect_chart_elements(image):
137
 
138
  draw = ImageDraw.Draw(image)
139
 
 
 
 
 
140
  # Process DL detections
141
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
142
  box = [round(i, 2) for i in box.tolist()]
@@ -145,34 +157,32 @@ def detect_chart_elements(image):
145
  draw.rectangle(box, outline="#FF0000", width=3)
146
  draw.text((box[0], box[1]), f"{label_name} ({score:.2f})", fill="#FF0000")
147
 
148
- # Traditional CV Detection
149
- lines = detect_lines(enhanced)
 
 
 
 
150
  if lines is not None:
151
- clusters = cluster_lines(lines)
152
- for i, line in enumerate(lines):
153
  x1, y1, x2, y2 = line[0]
154
- slope = calculate_slope(line[0])
155
- length = np.sqrt((x2-x1)**2 + (y2-y1)**2)
156
 
157
- if abs(slope) < 0.1 and length > 100:
158
- elements['support_resistance'].append(f"Key Level at y={y1}")
 
 
 
 
159
  draw.line((x1, y1, x2, y2), fill="#00FF00", width=3)
160
- elif 0.1 < abs(slope) < 5:
161
- elements['trendlines'].append(f"Trendline ({'Up' if slope < 0 else 'Down'})")
162
  draw.line((x1, y1, x2, y2), fill="#0000FF", width=3)
 
163
 
164
- # Volume Profile Analysis
165
- volume_profile = analyze_volume_profile(enhanced)
166
- contours, _ = cv2.findContours(volume_profile, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
167
- for cnt in contours:
168
- if cv2.contourArea(cnt) > 1000:
169
- x,y,w,h = cv2.boundingRect(cnt)
170
- elements['support_resistance'].append(f"Volume Cluster at {y+h//2}")
171
- draw.rectangle([x,y,x+w,y+h], outline="#FFA500", width=2)
172
-
173
- return image, elements
174
 
175
- def generate_technical_report(elements):
176
  report = []
177
  if elements['support_resistance']:
178
  report.append("**Key Levels**: " + ", ".join(elements['support_resistance'][:5]))
@@ -180,19 +190,35 @@ def generate_technical_report(elements):
180
  report.append("**Trend Analysis**: " + ", ".join(elements['trendlines']))
181
  if elements['patterns']:
182
  report.append("**Chart Patterns**: " + ", ".join(elements['patterns']))
 
 
 
 
 
183
  return "\n".join(report)
184
 
185
  def respond(message, history, image):
186
- if image is None:
187
- return "Please upload a cryptocurrency chart for analysis."
 
188
 
189
  try:
190
- processed_img = adaptive_resize(image)
191
- annotated_img, elements = detect_chart_elements(processed_img)
192
- tech_report = generate_technical_report(elements)
 
 
 
 
 
 
 
 
 
193
 
194
  full_prompt = SYSTEM_PROMPT.format(
195
  technical_analysis=tech_report,
 
196
  question=message
197
  )
198
 
@@ -204,12 +230,14 @@ def respond(message, history, image):
204
  seed=42
205
  )
206
 
207
- img_base64 = base64.b64encode(annotated_img.tobytes()).decode('utf-8')
208
- img_html = f'<div style="border: 2px solid #4CAF50; padding: 10px; margin-bottom: 20px;">' \
209
- f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%;">' \
210
- f'</div>'
 
 
211
 
212
- return f"{img_html}\n{response.split('<|assistant|>')[-1].strip()}"
213
 
214
  except Exception as e:
215
  return f"⚠️ Advanced Analysis Error: {str(e)}"
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ from PIL import Image, ImageDraw, ImageFont
5
  import warnings
 
6
  import torch
7
  import base64
8
+ import io
9
  import os
10
+ import pytesseract
 
11
  from sklearn.cluster import DBSCAN
12
  from transformers import (
13
  AutoModelForObjectDetection,
 
15
  pipeline
16
  )
17
  from huggingface_hub import InferenceClient
18
+ import matplotlib.pyplot as plt
19
+ from scipy import stats
20
 
21
  warnings.filterwarnings("ignore")
22
 
 
35
  # Model configurations
36
  DETECTION_MODEL = "facebook/detr-resnet-50"
37
  LLM_MODEL_NAME = "meta-llama/Meta-Llama-3-70B-Instruct"
38
+ OCR_CONFIG = r'--oem 3 --psm 6 -c tessedit_char_whitelist=0123456789.$€£¥%'
39
 
40
  # Initialize models
41
  detection_processor = DetrImageProcessor.from_pretrained(DETECTION_MODEL)
 
49
  Technical Elements Detected:
50
  {technical_analysis}
51
 
52
+ Price Axis Information:
53
+ {price_info}
54
+
55
  User Query: {question}
56
 
57
  Provide detailed professional analysis covering:
 
72
  scale = MAX_SIZE / max(height, width)
73
  return image.resize((int(width*scale), int(height*scale)), Image.LANCZOS)
74
 
75
+ def extract_price_info(image):
76
+ img_np = np.array(image)
77
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
78
+ data = pytesseract.image_to_data(gray, config=OCR_CONFIG, output_type=pytesseract.Output.DICT)
79
+
80
+ price_levels = []
81
+ price_rects = []
82
+
83
+ for i, text in enumerate(data['text']):
84
+ if text.strip() and any(c.isdigit() or c in '$€£¥%' for c in text):
85
+ x = data['left'][i]
86
+ y = data['top'][i]
87
+ w = data['width'][i]
88
+ h = data['height'][i]
89
+ price_rects.append((x, y, w, h))
90
+
91
+ try:
92
+ price = float(text.replace('$','').replace('%','').strip())
93
+ price_levels.append((y + h//2, price))
94
+ except:
95
+ continue
96
+
97
+ return price_levels, price_rects
98
+
99
+ def map_y_to_price(y_pos, price_levels):
100
+ if not price_levels:
101
+ return None
102
+
103
+ y_values = [y for y, _ in price_levels]
104
+ prices = [p for _, p in price_levels]
105
+
106
+ try:
107
+ slope, intercept, _, _, _ = stats.linregress(y_values, prices)
108
+ return round(intercept + slope * y_pos, 2)
109
+ except:
110
+ return None
111
+
112
  def enhance_contrast(img):
113
  lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
114
  l, a, b = cv2.split(lab)
 
117
  merged = cv2.merge([limg, a, b])
118
  return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def detect_chart_elements(image):
121
  image_np = np.array(image)
122
  enhanced = enhance_contrast(image_np)
123
 
124
+ # OCR for price information
125
+ price_levels, price_rects = extract_price_info(image)
126
+
127
  # Deep Learning Detection
128
  inputs = detection_processor(images=Image.fromarray(enhanced), return_tensors="pt")
129
  with torch.no_grad():
 
145
 
146
  draw = ImageDraw.Draw(image)
147
 
148
+ # Draw price levels
149
+ for x, y, w, h in price_rects:
150
+ draw.rectangle([x, y, x+w, y+h], outline="#4CAF50", width=1)
151
+
152
  # Process DL detections
153
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
154
  box = [round(i, 2) for i in box.tolist()]
 
157
  draw.rectangle(box, outline="#FF0000", width=3)
158
  draw.text((box[0], box[1]), f"{label_name} ({score:.2f})", fill="#FF0000")
159
 
160
+ # Trendline and support/resistance detection
161
+ lines = cv2.HoughLinesP(
162
+ cv2.Canny(cv2.cvtColor(enhanced, cv2.COLOR_RGB2GRAY), *CANNY_THRESHOLDS),
163
+ 1, np.pi/180, *HOUGH_PARAMS
164
+ )
165
+
166
  if lines is not None:
167
+ for line in lines:
 
168
  x1, y1, x2, y2 = line[0]
169
+ slope = (y2 - y1) / (x2 - x1) if (x2 - x1) != 0 else np.inf
 
170
 
171
+ price1 = map_y_to_price(y1, price_levels)
172
+ price2 = map_y_to_price(y2, price_levels)
173
+
174
+ if abs(slope) < 0.1: # Horizontal line
175
+ label = f"Key Level: {price1:.2f}" if price1 else f"Y={y1}"
176
+ elements['support_resistance'].append(label)
177
  draw.line((x1, y1, x2, y2), fill="#00FF00", width=3)
178
+ draw.text((x1+5, y1+5), label, fill="#00FF00")
179
+ else: # Trendline
180
  draw.line((x1, y1, x2, y2), fill="#0000FF", width=3)
181
+ elements['trendlines'].append(f"Trendline ({'Bullish' if slope < 0 else 'Bearish'})")
182
 
183
+ return image, elements, price_levels
 
 
 
 
 
 
 
 
 
184
 
185
+ def generate_technical_report(elements, price_levels):
186
  report = []
187
  if elements['support_resistance']:
188
  report.append("**Key Levels**: " + ", ".join(elements['support_resistance'][:5]))
 
190
  report.append("**Trend Analysis**: " + ", ".join(elements['trendlines']))
191
  if elements['patterns']:
192
  report.append("**Chart Patterns**: " + ", ".join(elements['patterns']))
193
+
194
+ if price_levels:
195
+ prices = [p for _, p in price_levels]
196
+ report.append(f"**Detected Price Range**: ${min(prices):.2f} - ${max(prices):.2f}")
197
+
198
  return "\n".join(report)
199
 
200
  def respond(message, history, image):
201
+ # Handle initial greeting
202
+ if not history:
203
+ return "Merhaba! Hoş geldiniz. Size nasıl yardımcı olabilirim? Crypto analiz için lütfen grafik yükleyin, genel sorularınızı direkt sorabilirsiniz."
204
 
205
  try:
206
+ tech_report = ""
207
+ annotated_img = None
208
+ price_info = ""
209
+
210
+ if image is not None:
211
+ processed_img = adaptive_resize(image)
212
+ annotated_img, elements, price_levels = detect_chart_elements(processed_img)
213
+ tech_report = generate_technical_report(elements, price_levels)
214
+
215
+ if price_levels:
216
+ prices = [f"${p:.2f}" for _, p in price_levels]
217
+ price_info = f"Detected Price Levels: {', '.join(prices)}"
218
 
219
  full_prompt = SYSTEM_PROMPT.format(
220
  technical_analysis=tech_report,
221
+ price_info=price_info,
222
  question=message
223
  )
224
 
 
230
  seed=42
231
  )
232
 
233
+ if annotated_img:
234
+ img_base64 = base64.b64encode(annotated_img.tobytes()).decode('utf-8')
235
+ img_html = f'<div style="border: 2px solid #4CAF50; padding: 10px; margin-bottom: 20px;">' \
236
+ f'<img src="data:image/png;base64,{img_base64}" style="max-width: 100%;">' \
237
+ f'</div>'
238
+ return f"{img_html}\n{response.split('<|assistant|>')[-1].strip()}"
239
 
240
+ return response.split('<|assistant|>')[-1].strip()
241
 
242
  except Exception as e:
243
  return f"⚠️ Advanced Analysis Error: {str(e)}"