aidevhund commited on
Commit
24c2743
·
verified ·
1 Parent(s): 51ec336

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -66
app.py CHANGED
@@ -2,39 +2,59 @@ import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import torch
5
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForObjectDetection, AutoProcessor
 
 
 
 
 
 
6
  from PIL import Image, ImageDraw
7
  import matplotlib.pyplot as plt
8
  import pandas as pd
9
  import warnings
10
- from ultralytics import YOLO # Import YOLO from the ultralytics package
 
11
 
12
  warnings.filterwarnings("ignore")
13
 
14
  # Constants
15
- MODEL_NAME = "google/flan-t5-large"
16
- YOLO_MODEL_PATH = "yolov8n.pt" # Load the pre-trained YOLOv8 model (you can change this to your desired model)
17
  MAX_WIDTH = 800
18
  MAX_HEIGHT = 600
19
 
20
- # Load models
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
23
- yolo_model = YOLO(YOLO_MODEL_PATH) # Load YOLOv8 model
 
 
24
 
25
- # System Prompt Template
26
- SYSTEM_PROMPT = """
 
 
 
 
 
 
 
 
 
 
 
 
27
  You are a professional financial analyst specializing in cryptocurrency technical analysis.
28
- Analyze the following chart elements and provide a detailed report based on the user's question:
 
 
 
 
 
 
29
 
30
- ### Chart Elements:
31
- Support/Resistance: {support_resistance}
32
- Trendlines: {trendlines}
33
- Patterns: {patterns}
34
- Candlestick formations: {candlesticks}
35
 
36
- ### User Question:
37
- {question}
38
  1. Trend analysis (primary and secondary trends)
39
  2. Key support/resistance levels
40
  3. Detected chart patterns
@@ -43,16 +63,21 @@ Candlestick formations: {candlesticks}
43
  6. Trading signals with confidence levels
44
  7. Risk management suggestions
45
 
46
- Format the response in markdown with clear sections, using professional trading terminology.
47
  """
48
 
49
- # Chart Analysis Functions
 
 
 
 
 
50
  def preprocess_image(image):
51
  img = np.array(image)
52
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
53
- height, width = img.shape[:2]
54
 
55
  # Resize if necessary
 
56
  if width > MAX_WIDTH or height > MAX_HEIGHT:
57
  img = cv2.resize(img, (MAX_WIDTH, MAX_HEIGHT), interpolation=cv2.INTER_AREA)
58
 
@@ -60,14 +85,21 @@ def preprocess_image(image):
60
  lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
61
  l, a, b = cv2.split(lab)
62
  clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
63
- limg = cv2.merge([clahe.apply(l),a,b])
64
  enhanced = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
65
 
66
- return enhanced
67
 
68
  def detect_chart_elements(image):
69
- # Detect technical patterns using YOLO model
70
- detections = chart_detector(image)
 
 
 
 
 
 
 
71
 
72
  elements = {
73
  'support_resistance': [],
@@ -76,65 +108,89 @@ def detect_chart_elements(image):
76
  'candlesticks': [],
77
  }
78
 
79
- # Process model detections
80
- for det in detections:
81
- label = det['label']
82
- if 'support' in label.lower() or 'resistance' in label.lower():
83
- elements['support_resistance'].append(label)
84
- elif 'trendline' in label.lower():
85
- elements['trendlines'].append(label)
86
- elif 'pattern' in label.lower():
87
- elements['patterns'].append(label.split('_')[0]) # Assuming the pattern name is before '_'
88
- elif 'candlestick' in label.lower():
89
- elements['candlesticks'].append(label)
 
 
 
 
 
 
 
 
90
 
91
- return elements
92
 
93
  def generate_llm_response(elements, question):
94
  prompt = SYSTEM_PROMPT.format(
95
- support_resistance="; ".join(elements['support_resistance']),
96
- trendlines="; ".join(elements['trendlines']),
97
- patterns="; ".join(elements['patterns']),
98
- candlesticks="; ".join(elements['candlesticks']),
99
  question=question
100
  )
101
 
102
- inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
103
- outputs = model.generate(
104
- input_ids=inputs.input_ids,
105
- attention_mask=inputs.attention_mask,
106
- max_length=1500,
107
- temperature=0.3,
108
- num_beams=5,
109
- early_stopping=True
110
  )
111
 
112
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
113
- return response
114
 
115
- # Gradio Interface Functions
116
- def respond(image, question, chat_history):
117
- # Preprocess image
118
- processed_img = preprocess_image(image)
119
 
120
- # Detect elements in the chart
121
- elements = detect_chart_elements(processed_img)
 
122
 
123
- # Generate analysis response
124
- analysis = generate_llm_response(elements, question)
125
 
126
- # Update chat history with the user's question and the model's analysis response
127
- chat_history.append(("User: " + question, "Assistant: " + analysis))
 
128
 
129
- return chat_history
130
 
131
- # Gradio Interface
132
  demo = gr.ChatInterface(
133
  fn=respond,
134
- chatbot=gr.Chatbot(show_copy_button=True, layout="panel"),
135
- title="Crypto Trading Assistant",
 
 
 
 
 
 
136
  theme="Nymbo/Nymbo_Theme",
137
- textbox=gr.Textbox(label="Ask Technical Questions", placeholder="Enter your analysis question...")
 
 
 
 
 
 
 
 
138
  )
139
 
140
  if __name__ == "__main__":
 
2
  import cv2
3
  import numpy as np
4
  import torch
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ AutoModelForObjectDetection,
9
+ AutoProcessor,
10
+ pipeline
11
+ )
12
  from PIL import Image, ImageDraw
13
  import matplotlib.pyplot as plt
14
  import pandas as pd
15
  import warnings
16
+ import io
17
+ import base64
18
 
19
  warnings.filterwarnings("ignore")
20
 
21
  # Constants
 
 
22
  MAX_WIDTH = 800
23
  MAX_HEIGHT = 600
24
 
25
+ # Load models (Update these with your actual model paths/names)
26
+ LLM_MODEL_NAME = "meta-llama/Llama-3-70B-Instruct"
27
+ DETECTION_MODEL = "facebook/detr-resnet-50"
28
+
29
+ # Initialize device
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+ # Load detection model and processor
33
+ detection_processor = AutoProcessor.from_pretrained(DETECTION_MODEL)
34
+ detection_model = AutoModelForObjectDetection.from_pretrained(DETECTION_MODEL).to(device)
35
+
36
+ # Load LLM components
37
+ llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
38
+ llm_model = AutoModelForCausalLM.from_pretrained(
39
+ LLM_MODEL_NAME,
40
+ torch_dtype=torch.bfloat16,
41
+ device_map="auto"
42
+ )
43
+
44
+ # System Prompt Template for LLAMA
45
+ SYSTEM_PROMPT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
46
  You are a professional financial analyst specializing in cryptocurrency technical analysis.
47
+ Analyze the following chart elements and provide a detailed report:
48
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
49
+ Chart Elements Detected:
50
+ - Support/Resistance: {support_resistance}
51
+ - Trendlines: {trendlines}
52
+ - Patterns: {patterns}
53
+ - Candlestick formations: {candlesticks}
54
 
55
+ User Question: {question}
 
 
 
 
56
 
57
+ Provide analysis covering:
 
58
  1. Trend analysis (primary and secondary trends)
59
  2. Key support/resistance levels
60
  3. Detected chart patterns
 
63
  6. Trading signals with confidence levels
64
  7. Risk management suggestions
65
 
66
+ Format response in markdown with clear sections using professional trading terminology.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
67
  """
68
 
69
+ # Helper functions
70
+ def image_to_base64(img):
71
+ buffered = io.BytesIO()
72
+ img.save(buffered, format="PNG")
73
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
74
+
75
  def preprocess_image(image):
76
  img = np.array(image)
77
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
 
78
 
79
  # Resize if necessary
80
+ height, width = img.shape[:2]
81
  if width > MAX_WIDTH or height > MAX_HEIGHT:
82
  img = cv2.resize(img, (MAX_WIDTH, MAX_HEIGHT), interpolation=cv2.INTER_AREA)
83
 
 
85
  lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
86
  l, a, b = cv2.split(lab)
87
  clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
88
+ limg = cv2.merge([clahe.apply(l), a, b])
89
  enhanced = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
90
 
91
+ return Image.fromarray(cv2.cvtColor(enhanced, cv2.COLOR_BGR2RGB))
92
 
93
  def detect_chart_elements(image):
94
+ inputs = detection_processor(images=image, return_tensors="pt").to(device)
95
+ outputs = detection_model(**inputs)
96
+
97
+ target_sizes = torch.tensor([image.size[::-1]]).to(device)
98
+ results = detection_processor.post_process_object_detection(
99
+ outputs,
100
+ target_sizes=target_sizes,
101
+ threshold=0.8
102
+ )[0]
103
 
104
  elements = {
105
  'support_resistance': [],
 
108
  'candlesticks': [],
109
  }
110
 
111
+ # Draw annotations and categorize elements
112
+ draw = ImageDraw.Draw(image)
113
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
114
+ box = [round(i, 2) for i in box.tolist()]
115
+ label_name = detection_model.config.id2label[label.item()]
116
+
117
+ # Draw bounding box
118
+ draw.rectangle(box, outline="red", width=2)
119
+ draw.text((box[0], box[1]), f"{label_name} ({round(score.item(), 2)})", fill="red")
120
+
121
+ # Categorize elements (customize these mappings based on your detection model's labels)
122
+ if "support" in label_name.lower() or "resistance" in label_name.lower():
123
+ elements['support_resistance'].append(label_name)
124
+ elif "trendline" in label_name.lower():
125
+ elements['trendlines'].append(label_name)
126
+ elif "pattern" in label_name.lower():
127
+ elements['patterns'].append(label_name)
128
+ elif "candlestick" in label_name.lower():
129
+ elements['candlesticks'].append(label_name)
130
 
131
+ return image, elements
132
 
133
  def generate_llm_response(elements, question):
134
  prompt = SYSTEM_PROMPT.format(
135
+ support_resistance=", ".join(elements['support_resistance']),
136
+ trendlines=", ".join(elements['trendlines']),
137
+ patterns=", ".join(elements['patterns']),
138
+ candlesticks=", ".join(elements['candlesticks']),
139
  question=question
140
  )
141
 
142
+ inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
143
+ outputs = llm_model.generate(
144
+ **inputs,
145
+ max_new_tokens=1500,
146
+ temperature=0.7,
147
+ top_p=0.9,
148
+ do_sample=True,
149
+ pad_token_id=llm_tokenizer.eos_token_id
150
  )
151
 
152
+ response = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
153
+ return response.split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip()
154
 
155
+ # Gradio Interface
156
+ def respond(message, history, image):
157
+ if image is None:
158
+ return "Please upload a chart image first."
159
 
160
+ # Preprocess and analyze image
161
+ processed_img = preprocess_image(image)
162
+ annotated_img, elements = detect_chart_elements(processed_img)
163
 
164
+ # Generate analysis
165
+ analysis = generate_llm_response(elements, message)
166
 
167
+ # Convert annotated image to base64
168
+ img_base64 = image_to_base64(annotated_img)
169
+ img_html = f'<img src="data:image/png;base64,{img_base64}" style="max-width: 800px; margin-bottom: 20px;">'
170
 
171
+ return f"{img_html}\n{analysis}"
172
 
173
+ # Create interface
174
  demo = gr.ChatInterface(
175
  fn=respond,
176
+ additional_inputs=[gr.Image(label="Upload Chart", type="pil")],
177
+ chatbot=gr.Chatbot(
178
+ show_copy_button=True,
179
+ layout="panel",
180
+ bubble_full_width=False,
181
+ sanitize_html=False
182
+ ),
183
+ title="Crypto Trading Assistant Pro",
184
  theme="Nymbo/Nymbo_Theme",
185
+ textbox=gr.Textbox(
186
+ label="Ask Technical Questions",
187
+ placeholder="Upload chart image and ask analysis questions...",
188
+ container=False
189
+ ),
190
+ examples=[
191
+ ["Is this a bullish reversal pattern?", "chart1.png"],
192
+ ["What are the key support levels?", "chart2.jpg"]
193
+ ]
194
  )
195
 
196
  if __name__ == "__main__":