Update app.py
app.py
CHANGED
@@ -2,39 +2,59 @@ import gradio as gr
 import cv2
 import numpy as np
 import torch
-from transformers import
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoModelForObjectDetection,
+    AutoProcessor,
+    pipeline
+)
 from PIL import Image, ImageDraw
 import matplotlib.pyplot as plt
 import pandas as pd
 import warnings
+import io
+import base64

 warnings.filterwarnings("ignore")

 # Constants
-MODEL_NAME = "google/flan-t5-large"
-YOLO_MODEL_PATH = "yolov8n.pt" # Load the pre-trained YOLOv8 model (you can change this to your desired model)
 MAX_WIDTH = 800
 MAX_HEIGHT = 600

-# Load models
+# Load models (Update these with your actual model paths/names)
+LLM_MODEL_NAME = "meta-llama/Llama-3-70B-Instruct"
+DETECTION_MODEL = "facebook/detr-resnet-50"
+
+# Initialize device
+device = "cuda" if torch.cuda.is_available() else "cpu"

+# Load detection model and processor
+detection_processor = AutoProcessor.from_pretrained(DETECTION_MODEL)
+detection_model = AutoModelForObjectDetection.from_pretrained(DETECTION_MODEL).to(device)
+
+# Load LLM components
+llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
+llm_model = AutoModelForCausalLM.from_pretrained(
+    LLM_MODEL_NAME,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
+)
+
+# System Prompt Template for LLAMA
+SYSTEM_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 You are a professional financial analyst specializing in cryptocurrency technical analysis.
-Analyze the following chart elements and provide a detailed report
-Support/Resistance: {support_resistance}
-Trendlines: {trendlines}
-Patterns: {patterns}
-Candlestick formations: {candlesticks}
+Analyze the following chart elements and provide a detailed report:
+<|eot_id|><|start_header_id|>user<|end_header_id|>
+Chart Elements Detected:
+- Support/Resistance: {support_resistance}
+- Trendlines: {trendlines}
+- Patterns: {patterns}
+- Candlestick formations: {candlesticks}

-{question}
+User Question: {question}

+Provide analysis covering:
 1. Trend analysis (primary and secondary trends)
 2. Key support/resistance levels
 3. Detected chart patterns
@@ -43,16 +63,21 @@ Candlestick formations: {candlesticks}
 6. Trading signals with confidence levels
 7. Risk management suggestions

-Format
+Format response in markdown with clear sections using professional trading terminology.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 """

+# Helper functions
+def image_to_base64(img):
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
 def preprocess_image(image):
     img = np.array(image)
     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-    height, width = img.shape[:2]

     # Resize if necessary
+    height, width = img.shape[:2]
     if width > MAX_WIDTH or height > MAX_HEIGHT:
         img = cv2.resize(img, (MAX_WIDTH, MAX_HEIGHT), interpolation=cv2.INTER_AREA)

@@ -60,14 +85,21 @@ def preprocess_image(image):
     lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
     l, a, b = cv2.split(lab)
     clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
-    limg = cv2.merge([clahe.apply(l),a,b])
+    limg = cv2.merge([clahe.apply(l), a, b])
     enhanced = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)

-    return enhanced
+    return Image.fromarray(cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB))

 def detect_chart_elements(image):
+    inputs = detection_processor(images=image, return_tensors="pt").to(device)
+    outputs = detection_model(**inputs)
+
+    target_sizes = torch.tensor([image.size[::-1]]).to(device)
+    results = detection_processor.post_process_object_detection(
+        outputs,
+        target_sizes=target_sizes,
+        threshold=0.8
+    )[0]

     elements = {
         'support_resistance': [],
@@ -76,65 +108,89 @@
         'candlesticks': [],
     }

+    # Draw annotations and categorize elements
+    draw = ImageDraw.Draw(image)
+    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        box = [round(i, 2) for i in box.tolist()]
+        label_name = detection_model.config.id2label[label.item()]
+
+        # Draw bounding box
+        draw.rectangle(box, outline="red", width=2)
+        draw.text((box[0], box[1]), f"{label_name} ({round(score.item(), 2)})", fill="red")
+
+        # Categorize elements (customize these mappings based on your detection model's labels)
+        if "support" in label_name.lower() or "resistance" in label_name.lower():
+            elements['support_resistance'].append(label_name)
+        elif "trendline" in label_name.lower():
+            elements['trendlines'].append(label_name)
+        elif "pattern" in label_name.lower():
+            elements['patterns'].append(label_name)
+        elif "candlestick" in label_name.lower():
+            elements['candlesticks'].append(label_name)

-    return elements
+    return image, elements

 def generate_llm_response(elements, question):
     prompt = SYSTEM_PROMPT.format(
-        support_resistance="
-        trendlines="
-        patterns="
-        candlesticks="
+        support_resistance=", ".join(elements['support_resistance']),
+        trendlines=", ".join(elements['trendlines']),
+        patterns=", ".join(elements['patterns']),
+        candlesticks=", ".join(elements['candlesticks']),
         question=question
     )

-    inputs =
-    outputs =
+    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
+    outputs = llm_model.generate(
+        **inputs,
+        max_new_tokens=1500,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True,
+        pad_token_id=llm_tokenizer.eos_token_id
     )

-    response =
-    return response
+    response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response.split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()

-# Gradio Interface
-def respond(
+# Gradio Interface
+def respond(message, history, image):
+    if image is None:
+        return "Please upload a chart image first."

+    # Preprocess and analyze image
+    processed_img = preprocess_image(image)
+    annotated_img, elements = detect_chart_elements(processed_img)

-    # Generate analysis
-    analysis = generate_llm_response(elements,
+    # Generate analysis
+    analysis = generate_llm_response(elements, message)

+    # Convert annotated image to base64
+    img_base64 = image_to_base64(annotated_img)
+    img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 800px; margin-bottom: 20px;">'

-    return
+    return f"{img_html}\n{analysis}"

+# Create interface
 demo = gr.ChatInterface(
     fn=respond,
+    additional_inputs=[gr.Image(label="Upload Chart", type="pil")],
+    chatbot=gr.Chatbot(
+        show_copy_button=True,
+        layout="panel",
+        bubble_full_width=False,
+        sanitize_html=False
+    ),
+    title="Crypto Trading Assistant Pro",
     theme="Nymbo/Nymbo_Theme",
-    textbox=gr.Textbox(
+    textbox=gr.Textbox(
+        label="Ask Technical Questions",
+        placeholder="Upload chart image and ask analysis questions...",
+        container=False
+    ),
+    examples=[
+        ["Is this a bullish reversal pattern?", "chart1.png"],
+        ["What are the key support levels?", "chart2.jpg"]
+    ]
 )

 if __name__ == "__main__":
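
Since the change rewires every stage (preprocessing, detection, and generation), a quick way to sanity-check the new pipeline is to call the three functions directly, without launching the Gradio app. The following is a minimal sketch, not part of the commit: it assumes the app.py above is importable from the working directory (importing it loads both the DETR and Llama-3-70B checkpoints, so it needs the corresponding memory and access to the gated Llama weights), and "chart1.png" is a hypothetical local chart screenshot borrowed from the examples= list.

# test_pipeline.py -- minimal smoke-test sketch (assumptions noted above)
from PIL import Image

from app import preprocess_image, detect_chart_elements, generate_llm_response

# Hypothetical sample image; any RGB candlestick-chart screenshot works.
chart = Image.open("chart1.png").convert("RGB")

# Stage 1: resize to the MAX_WIDTH x MAX_HEIGHT bounds and apply CLAHE contrast enhancement.
processed = preprocess_image(chart)

# Stage 2: run object detection, draw red boxes, and bucket labels by keyword.
annotated, elements = detect_chart_elements(processed)
print("Detected elements:", elements)

# Stage 3: format the detections into the Llama-3 prompt and generate the report.
report = generate_llm_response(elements, "What are the key support levels?")
print(report)

# Keep the annotated chart for visual inspection.
annotated.save("chart1_annotated.png")

One caveat: the stock facebook/detr-resnet-50 checkpoint is trained on COCO object categories, so the keyword buckets in detect_chart_elements will stay empty until DETECTION_MODEL points at a chart-specific detector, which is what the "customize these mappings" comment in the code already anticipates.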