Eladlev commited on
Commit
e245801
·
verified ·
1 Parent(s): 3fa69a2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +262 -0
app.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Entrypoint for Gradio, see https://gradio.app/
3
+ """
4
+
5
+ import platform
6
+ import asyncio
7
+ import base64
8
+ import os
9
+ from datetime import datetime
10
+ from enum import StrEnum
11
+ from functools import partial
12
+ from pathlib import Path
13
+ from typing import cast, Dict
14
+
15
+ import gradio as gr
16
+ from anthropic import APIResponse
17
+ from anthropic.types import TextBlock
18
+ from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
19
+ from anthropic.types.tool_use_block import ToolUseBlock
20
+
21
+ from computer_use_demo.loop import (
22
+ PROVIDER_TO_DEFAULT_MODEL_NAME,
23
+ APIProvider,
24
+ sampling_loop,
25
+ sampling_loop_sync,
26
+ )
27
+
28
+ from computer_use_demo.tools import ToolResult
29
+
30
+
31
+ CONFIG_DIR = Path("~/.anthropic").expanduser()
32
+ API_KEY_FILE = CONFIG_DIR / "api_key"
33
+
34
+ WARNING_TEXT = "⚠️ Security Alert: Never provide access to sensitive accounts or data, as malicious web content can hijack Claude's behavior"
35
+
36
+
37
+ class Sender(StrEnum):
38
+ USER = "user"
39
+ BOT = "assistant"
40
+ TOOL = "tool"
41
+
42
+
43
+ def setup_state(state):
44
+ if "messages" not in state:
45
+ state["messages"] = []
46
+ if "api_key" not in state:
47
+ # Try to load API key from file first, then environment
48
+ state["api_key"] = load_from_storage("api_key") or os.getenv("ANTHROPIC_API_KEY", "")
49
+ if not state["api_key"]:
50
+ print("API key not found. Please set it in the environment or storage.")
51
+ if "provider" not in state:
52
+ state["provider"] = os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
53
+ if "provider_radio" not in state:
54
+ state["provider_radio"] = state["provider"]
55
+ if "model" not in state:
56
+ _reset_model(state)
57
+ if "auth_validated" not in state:
58
+ state["auth_validated"] = False
59
+ if "responses" not in state:
60
+ state["responses"] = {}
61
+ if "tools" not in state:
62
+ state["tools"] = {}
63
+ if "only_n_most_recent_images" not in state:
64
+ state["only_n_most_recent_images"] = 3 # 10
65
+ if "custom_system_prompt" not in state:
66
+ state["custom_system_prompt"] = load_from_storage("system_prompt") or ""
67
+ # remove if want to use default system prompt
68
+ device_os_name = "Windows" if platform.platform == "Windows" else "Mac" if platform.platform == "Darwin" else "Linux"
69
+ state["custom_system_prompt"] += f"\n\nNOTE: you are operating a {device_os_name} machine"
70
+ if "hide_images" not in state:
71
+ state["hide_images"] = False
72
+
73
+
74
+ def _reset_model(state):
75
+ state["model"] = PROVIDER_TO_DEFAULT_MODEL_NAME[cast(APIProvider, state["provider"])]
76
+
77
+
78
+ async def main(state):
79
+ """Render loop for Gradio"""
80
+ setup_state(state)
81
+ return "Setup completed"
82
+
83
+
84
+ def validate_auth(provider: APIProvider, api_key: str | None):
85
+ if provider == APIProvider.ANTHROPIC:
86
+ if not api_key:
87
+ return "Enter your Anthropic API key to continue."
88
+ if provider == APIProvider.BEDROCK:
89
+ import boto3
90
+
91
+ if not boto3.Session().get_credentials():
92
+ return "You must have AWS credentials set up to use the Bedrock API."
93
+ if provider == APIProvider.VERTEX:
94
+ import google.auth
95
+ from google.auth.exceptions import DefaultCredentialsError
96
+
97
+ if not os.environ.get("CLOUD_ML_REGION"):
98
+ return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
99
+ try:
100
+ google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
101
+ except DefaultCredentialsError:
102
+ return "Your google cloud credentials are not set up correctly."
103
+
104
+
105
+ def load_from_storage(filename: str) -> str | None:
106
+ """Load data from a file in the storage directory."""
107
+ try:
108
+ file_path = CONFIG_DIR / filename
109
+ if file_path.exists():
110
+ data = file_path.read_text().strip()
111
+ if data:
112
+ return data
113
+ except Exception as e:
114
+ print(f"Debug: Error loading {filename}: {e}")
115
+ return None
116
+
117
+
118
+ def save_to_storage(filename: str, data: str) -> None:
119
+ """Save data to a file in the storage directory."""
120
+ try:
121
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
122
+ file_path = CONFIG_DIR / filename
123
+ file_path.write_text(data)
124
+ # Ensure only user can read/write the file
125
+ file_path.chmod(0o600)
126
+ except Exception as e:
127
+ print(f"Debug: Error saving {filename}: {e}")
128
+
129
+
130
+ def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict):
131
+ response_id = datetime.now().isoformat()
132
+ response_state[response_id] = response
133
+
134
+
135
+ def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict):
136
+ tool_state[tool_id] = tool_output
137
+
138
+
139
+ def _render_message(sender: Sender, message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, state):
140
+ is_tool_result = not isinstance(message, str) and (
141
+ isinstance(message, ToolResult)
142
+ or message.__class__.__name__ == "ToolResult"
143
+ or message.__class__.__name__ == "CLIResult"
144
+ )
145
+ if not message or (
146
+ is_tool_result
147
+ and state["hide_images"]
148
+ and not hasattr(message, "error")
149
+ and not hasattr(message, "output")
150
+ ):
151
+ return
152
+ if is_tool_result:
153
+ message = cast(ToolResult, message)
154
+ if message.output:
155
+ return message.output
156
+ if message.error:
157
+ return f"Error: {message.error}"
158
+ if message.base64_image and not state["hide_images"]:
159
+ return base64.b64decode(message.base64_image)
160
+ elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
161
+ return message.text
162
+ elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
163
+ return f"Tool Use: {message.name}\nInput: {message.input}"
164
+ else:
165
+ return message
166
+ # open new tab, open google sheets inside, then create a new blank spreadsheet
167
+
168
+ def process_input(user_input, state):
169
+ # Ensure the state is properly initialized
170
+ setup_state(state)
171
+
172
+ # Append the user input to the messages in the state
173
+ state["messages"].append(
174
+ {
175
+ "role": Sender.USER,
176
+ "content": [TextBlock(type="text", text=user_input)],
177
+ }
178
+ )
179
+
180
+ # Run the sampling loop synchronously and yield messages
181
+ for message in sampling_loop(state):
182
+ yield message
183
+
184
+
185
+ def accumulate_messages(*args, **kwargs):
186
+ """
187
+ Wrapper function to accumulate messages from sampling_loop_sync.
188
+ """
189
+ accumulated_messages = []
190
+
191
+ for message in sampling_loop_sync(*args, **kwargs):
192
+ # Check if the message is already in the accumulated messages
193
+ if message not in accumulated_messages:
194
+ accumulated_messages.append(message)
195
+ # Yield the accumulated messages as a list
196
+ yield accumulated_messages
197
+
198
+
199
+ def sampling_loop(state):
200
+ # Ensure the API key is present
201
+ if not state.get("api_key"):
202
+ raise ValueError("API key is missing. Please set it in the environment or storage.")
203
+
204
+ # Call the sampling loop and yield messages
205
+ for message in accumulate_messages(
206
+ system_prompt_suffix=state["custom_system_prompt"],
207
+ model=state["model"],
208
+ provider=state["provider"],
209
+ messages=state["messages"],
210
+ output_callback=partial(_render_message, Sender.BOT, state=state),
211
+ tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
212
+ api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
213
+ api_key=state["api_key"],
214
+ only_n_most_recent_images=state["only_n_most_recent_images"],
215
+ ):
216
+ yield message
217
+
218
+
219
+ with gr.Blocks() as demo:
220
+ state = gr.State({}) # Use Gradio's state management
221
+
222
+ gr.Markdown("# Claude Computer Use Demo")
223
+
224
+ if not os.getenv("HIDE_WARNING", False):
225
+ gr.Markdown(WARNING_TEXT)
226
+
227
+ with gr.Row():
228
+ provider = gr.Dropdown(
229
+ label="API Provider",
230
+ choices=[option.value for option in APIProvider],
231
+ value="anthropic",
232
+ interactive=True,
233
+ )
234
+ model = gr.Textbox(label="Model", value="claude-3-5-sonnet-20241022")
235
+ api_key = gr.Textbox(
236
+ label="Anthropic API Key",
237
+ type="password",
238
+ value="",
239
+ interactive=True,
240
+ )
241
+ only_n_images = gr.Slider(
242
+ label="Only send N most recent images",
243
+ minimum=0,
244
+ value=3, # 10
245
+ interactive=True,
246
+ )
247
+ custom_prompt = gr.Textbox(
248
+ label="Custom System Prompt Suffix",
249
+ value="",
250
+ interactive=True,
251
+ )
252
+ hide_images = gr.Checkbox(label="Hide screenshots", value=False)
253
+
254
+ api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key)
255
+ chat_input = gr.Textbox(label="Type a message to send to Claude...")
256
+ # chat_output = gr.Textbox(label="Chat Output", interactive=False)
257
+ chatbot = gr.Chatbot(label="Chatbot History", autoscroll=True)
258
+
259
+ # Pass state as an input to the function
260
+ chat_input.submit(process_input, [chat_input, state], chatbot)
261
+
262
+ demo.launch(share=True)