aidevhund commited on
Commit
51ec336
·
verified ·
1 Parent(s): dec5444

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -72
app.py CHANGED
@@ -23,10 +23,19 @@ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
23
  yolo_model = YOLO(YOLO_MODEL_PATH) # Load YOLOv8 model
24
 
25
  # System Prompt Template
26
- SYSTEM_PROMPT = """You are a professional financial analyst specializing in cryptocurrency technical analysis.
27
- Analyze the following chart elements and provide detailed report with:
 
28
 
29
- 1. Trend analysis (primary, secondary trends)
 
 
 
 
 
 
 
 
30
  2. Key support/resistance levels
31
  3. Detected chart patterns
32
  4. Candlestick formations
@@ -34,9 +43,8 @@ Analyze the following chart elements and provide detailed report with:
34
  6. Trading signals with confidence levels
35
  7. Risk management suggestions
36
 
37
- Format response in markdown with clear sections. Use professional trading terminology.
38
- Current Elements Detected: {elements}
39
- User Question: {question}"""
40
 
41
  # Chart Analysis Functions
42
  def preprocess_image(image):
@@ -58,36 +66,38 @@ def preprocess_image(image):
58
  return enhanced
59
 
60
  def detect_chart_elements(image):
61
- # Perform object detection on the chart image
62
- results = yolo_model(image) # Object detection using YOLOv8
63
 
64
  elements = {
65
  'support_resistance': [],
66
  'trendlines': [],
67
  'patterns': [],
68
  'candlesticks': [],
69
- 'key_levels': []
70
  }
71
 
72
- for result in results:
73
- for detection in result.boxes:
74
- label = detection.cls # Class label for the detected object
75
- confidence = detection.conf # Confidence score
76
-
77
- if confidence > 0.5: # Filter by confidence level (adjust threshold as needed)
78
- if 'support' in label:
79
- elements['support_resistance'].append(label)
80
- elif 'trendline' in label:
81
- elements['trendlines'].append(label)
82
- elif 'pattern' in label:
83
- elements['patterns'].append(label)
84
- elif 'candlestick' in label:
85
- elements['candlesticks'].append(label)
86
-
87
  return elements
88
 
89
  def generate_llm_response(elements, question):
90
- prompt = SYSTEM_PROMPT.format(elements=elements, question=question)
 
 
 
 
 
 
91
 
92
  inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
93
  outputs = model.generate(
@@ -102,60 +112,30 @@ def generate_llm_response(elements, question):
102
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
103
  return response
104
 
105
- # Gradio Interface
106
- def analyze_chart(image, question, chat_history):
107
  # Preprocess image
108
  processed_img = preprocess_image(image)
109
 
110
- # Detect elements
111
  elements = detect_chart_elements(processed_img)
112
 
113
- # Generate analysis
114
- analysis = generate_llm_response(str(elements), question)
115
-
116
- # Update chat
117
- chat_history.append((question, analysis))
118
 
119
- # Create annotated image
120
- draw = ImageDraw.Draw(image)
121
- elements_str = "\n".join([f"{k}: {v}" for k,v in elements.items()])
122
 
123
- return analysis, chat_history, image
124
 
125
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {background-color: #f5f5f5}") as demo:
126
- with gr.Row():
127
- gr.Markdown("# Professional Crypto Chart Analyzer 📊", elem_id="title")
128
-
129
- with gr.Row():
130
- with gr.Column(scale=1):
131
- image_input = gr.Image(label="Upload Chart", type="pil", elem_id="upload-box")
132
- submit_btn = gr.Button("Analyze", variant="primary")
133
-
134
- with gr.Column(scale=2):
135
- chatbot = gr.Chatbot(
136
- label="Analysis Chat",
137
- bubble_full_width=False,
138
- avatar_images=(
139
- "user.png",
140
- "assistant.png"
141
- ),
142
- height=600
143
- )
144
- question_input = gr.Textbox(
145
- label="Ask Technical Questions",
146
- placeholder="Enter your analysis question...",
147
- lines=3
148
- )
149
-
150
- with gr.Row():
151
- analysis_output = gr.Markdown(label="Technical Analysis Report")
152
- annotated_output = gr.Image(label="Annotated Chart", width=800)
153
-
154
- submit_btn.click(
155
- analyze_chart,
156
- [image_input, question_input, chatbot],
157
- [analysis_output, chatbot, annotated_output]
158
- )
159
 
160
  if __name__ == "__main__":
161
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
23
  yolo_model = YOLO(YOLO_MODEL_PATH) # Load YOLOv8 model
24
 
25
  # System Prompt Template
26
+ SYSTEM_PROMPT = """
27
+ You are a professional financial analyst specializing in cryptocurrency technical analysis.
28
+ Analyze the following chart elements and provide a detailed report based on the user's question:
29
 
30
+ ### Chart Elements:
31
+ Support/Resistance: {support_resistance}
32
+ Trendlines: {trendlines}
33
+ Patterns: {patterns}
34
+ Candlestick formations: {candlesticks}
35
+
36
+ ### User Question:
37
+ {question}
38
+ 1. Trend analysis (primary and secondary trends)
39
  2. Key support/resistance levels
40
  3. Detected chart patterns
41
  4. Candlestick formations
 
43
  6. Trading signals with confidence levels
44
  7. Risk management suggestions
45
 
46
+ Format the response in markdown with clear sections, using professional trading terminology.
47
+ """
 
48
 
49
  # Chart Analysis Functions
50
  def preprocess_image(image):
 
66
  return enhanced
67
 
68
  def detect_chart_elements(image):
69
+ # Detect technical patterns using YOLO model
70
+ detections = chart_detector(image)
71
 
72
  elements = {
73
  'support_resistance': [],
74
  'trendlines': [],
75
  'patterns': [],
76
  'candlesticks': [],
 
77
  }
78
 
79
+ # Process model detections
80
+ for det in detections:
81
+ label = det['label']
82
+ if 'support' in label.lower() or 'resistance' in label.lower():
83
+ elements['support_resistance'].append(label)
84
+ elif 'trendline' in label.lower():
85
+ elements['trendlines'].append(label)
86
+ elif 'pattern' in label.lower():
87
+ elements['patterns'].append(label.split('_')[0]) # Assuming the pattern name is before '_'
88
+ elif 'candlestick' in label.lower():
89
+ elements['candlesticks'].append(label)
90
+
 
 
 
91
  return elements
92
 
93
  def generate_llm_response(elements, question):
94
+ prompt = SYSTEM_PROMPT.format(
95
+ support_resistance="; ".join(elements['support_resistance']),
96
+ trendlines="; ".join(elements['trendlines']),
97
+ patterns="; ".join(elements['patterns']),
98
+ candlesticks="; ".join(elements['candlesticks']),
99
+ question=question
100
+ )
101
 
102
  inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
103
  outputs = model.generate(
 
112
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
113
  return response
114
 
115
+ # Gradio Interface Functions
116
+ def respond(image, question, chat_history):
117
  # Preprocess image
118
  processed_img = preprocess_image(image)
119
 
120
+ # Detect elements in the chart
121
  elements = detect_chart_elements(processed_img)
122
 
123
+ # Generate analysis response
124
+ analysis = generate_llm_response(elements, question)
 
 
 
125
 
126
+ # Update chat history with the user's question and the model's analysis response
127
+ chat_history.append(("User: " + question, "Assistant: " + analysis))
 
128
 
129
+ return chat_history
130
 
131
+ # Gradio Interface
132
+ demo = gr.ChatInterface(
133
+ fn=respond,
134
+ chatbot=gr.Chatbot(show_copy_button=True, layout="panel"),
135
+ title="Crypto Trading Assistant",
136
+ theme="Nymbo/Nymbo_Theme",
137
+ textbox=gr.Textbox(label="Ask Technical Questions", placeholder="Enter your analysis question...")
138
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  if __name__ == "__main__":
141
+ demo.launch()