brichett commited on
Commit
cab69db
·
verified ·
1 Parent(s): 9211fe0

Update src/gradio_server.py

Browse files
Files changed (1) hide show
  1. src/gradio_server.py +24 -17
src/gradio_server.py CHANGED
@@ -1,12 +1,11 @@
1
  import gradio as gr
2
  import os
3
  import sys
4
- from typing import Annotated
5
-
6
- from fastapi import FastAPI, Form, UploadFile
7
  from pydantic import BaseModel
8
  from hamilton import driver
9
  from pandas import DataFrame
 
10
 
11
  # Add the src directory to the Python path
12
  sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
@@ -19,13 +18,23 @@ from decouple import config
19
 
20
  app = FastAPI()
21
 
22
- config = {"loader": "pd",
23
- "embedding_service": "openai",
24
- "api_key": config("OPENAI_API_KEY"),
25
- "model_name": "text-embedding-ada-002",
26
- "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
27
- "ner_public_url": config("NER_PUBLIC_URL")
28
- } # or "pd"
 
 
 
 
 
 
 
 
 
 
29
 
30
  dr = (
31
  driver.Builder()
@@ -49,13 +58,11 @@ class PolicyEnforcementRequest(BaseModel):
49
  violation_context: dict
50
 
51
  class RadicalizationDetectionResponse(BaseModel):
52
- """Response to the /detect endpoint"""
53
  values: dict
54
 
55
  class PolicyEnforcementResponse(BaseModel):
56
- """Response to the /generate_policy_enforcement endpoint"""
57
  values: dict
58
-
59
  @app.post("/detect_radicalization")
60
  def detect_radicalization(
61
  request: RadicalizationDetectionRequest
@@ -65,8 +72,6 @@ def detect_radicalization(
65
  final_vars=["detect_glorification"],
66
  inputs={"project_root": ".", "user_input": request.user_text}
67
  )
68
- print(results)
69
- print(type(results))
70
  if isinstance(results, DataFrame):
71
  results = results.to_dict(orient="dict")
72
  return RadicalizationDetectionResponse(values=results)
@@ -79,8 +84,6 @@ def generate_policy_enforcement(
79
  final_vars=["get_enforcement_decision"],
80
  inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
81
  )
82
- print(results)
83
- print(type(results))
84
  if isinstance(results, DataFrame):
85
  results = results.to_dict(orient="dict")
86
  return PolicyEnforcementResponse(values=results)
@@ -118,3 +121,7 @@ iface2 = gr.Interface(
118
 
119
  # Combine the interfaces in a Tabbed interface
120
  iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])
 
 
 
 
 
1
  import gradio as gr
2
  import os
3
  import sys
4
+ from fastapi import FastAPI
 
 
5
  from pydantic import BaseModel
6
  from hamilton import driver
7
  from pandas import DataFrame
8
+ from fastapi.middleware.cors import CORSMiddleware
9
 
10
  # Add the src directory to the Python path
11
  sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
 
18
 
19
  app = FastAPI()
20
 
21
+ # Enable CORS for Gradio to communicate with FastAPI
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"],
25
+ allow_credentials=True,
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+
30
+ config = {
31
+ "loader": "pd",
32
+ "embedding_service": "openai",
33
+ "api_key": config("OPENAI_API_KEY"),
34
+ "model_name": "text-embedding-ada-002",
35
+ "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
36
+ "ner_public_url": config("NER_PUBLIC_URL"),
37
+ }
38
 
39
  dr = (
40
  driver.Builder()
 
58
  violation_context: dict
59
 
60
  class RadicalizationDetectionResponse(BaseModel):
 
61
  values: dict
62
 
63
  class PolicyEnforcementResponse(BaseModel):
 
64
  values: dict
65
+
66
  @app.post("/detect_radicalization")
67
  def detect_radicalization(
68
  request: RadicalizationDetectionRequest
 
72
  final_vars=["detect_glorification"],
73
  inputs={"project_root": ".", "user_input": request.user_text}
74
  )
 
 
75
  if isinstance(results, DataFrame):
76
  results = results.to_dict(orient="dict")
77
  return RadicalizationDetectionResponse(values=results)
 
84
  final_vars=["get_enforcement_decision"],
85
  inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
86
  )
 
 
87
  if isinstance(results, DataFrame):
88
  results = results.to_dict(orient="dict")
89
  return PolicyEnforcementResponse(values=results)
 
121
 
122
  # Combine the interfaces in a Tabbed interface
123
  iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])
124
+
125
+ if __name__ == "__main__":
126
+ # Launch Gradio interface (no need to launch Uvicorn separately)
127
+ iface_combined.launch(server_name="0.0.0.0", server_port=7860)