Niansuh commited on
Commit
58893fb
·
verified ·
1 Parent(s): b794454

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +124 -49
main.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
2
  import secrets
 
3
  import requests
4
  from fastapi import FastAPI, Depends, HTTPException, Header, Request
5
- from sqlalchemy import create_engine, Column, Integer, String, Boolean, DateTime, func
6
  from sqlalchemy.ext.declarative import declarative_base
7
- from sqlalchemy.orm import sessionmaker, Session
8
  from pydantic import BaseModel
9
  from dotenv import load_dotenv
10
  from fastapi.templating import Jinja2Templates
@@ -13,7 +14,7 @@ from fastapi.staticfiles import StaticFiles
13
  # Load environment variables from .env file
14
  load_dotenv()
15
 
16
- # Environment variables for MySQL and main API settings
17
  MYSQL_USER = os.getenv("MYSQL_USER")
18
  MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
19
  MYSQL_HOST = os.getenv("MYSQL_HOST")
@@ -24,32 +25,44 @@ MODEL_NAME = os.getenv("MODEL_NAME", "Image-Generator")
24
 
25
  DATABASE_URL = f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}/{MYSQL_DB}"
26
 
27
- # SQLAlchemy setup
28
  engine = create_engine(DATABASE_URL)
29
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
30
  Base = declarative_base()
31
 
32
- # Updated User model with credits field
 
 
33
  class User(Base):
34
  __tablename__ = "users"
35
  id = Column(Integer, primary_key=True, index=True)
36
  username = Column(String(50), unique=True, index=True, nullable=False)
37
  hashed_password = Column(String(128), nullable=False)
38
- api_key = Column(String(64), unique=True, index=True, nullable=False)
39
  is_admin = Column(Boolean, default=False)
40
  credits = Column(Integer, default=0)
41
  created_at = Column(DateTime(timezone=True), server_default=func.now())
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Create tables
44
  Base.metadata.create_all(bind=engine)
45
 
46
- app = FastAPI(title="API Key Platform")
 
 
 
47
 
48
- # Mount static files and templates
49
  app.mount("/static", StaticFiles(directory="static"), name="static")
50
  templates = Jinja2Templates(directory="templates")
51
 
52
- # Dependency: Database session
53
  def get_db():
54
  db = SessionLocal()
55
  try:
@@ -57,11 +70,12 @@ def get_db():
57
  finally:
58
  db.close()
59
 
60
- # Utility: Generate a unique API key
61
  def generate_api_key() -> str:
62
  return secrets.token_hex(16)
63
 
64
- # Pydantic models
 
 
65
  class UserCreate(BaseModel):
66
  username: str
67
  password: str
@@ -69,52 +83,91 @@ class UserCreate(BaseModel):
69
  class UserOut(BaseModel):
70
  id: int
71
  username: str
72
- api_key: str
73
  is_admin: bool
74
  credits: int
 
75
 
76
  class Config:
77
  orm_mode = True
78
 
79
- # Dependency: Validate API key from header and return current user
80
- def get_current_user(x_api_key: str = Header(...), db: Session = Depends(get_db)) -> User:
81
- user = db.query(User).filter(User.api_key == x_api_key).first()
82
- if not user:
83
- raise HTTPException(status_code=401, detail="Invalid API Key")
84
- return user
 
 
 
 
 
 
 
 
 
85
 
86
- # --- Endpoints ---
 
 
 
 
 
87
 
88
- # 1. User Registration (generates a unique API key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  @app.post("/register", response_model=UserOut)
90
  def register(user: UserCreate, db: Session = Depends(get_db)):
91
  if db.query(User).filter(User.username == user.username).first():
92
  raise HTTPException(status_code=400, detail="Username already exists")
93
- new_api_key = generate_api_key()
94
- # In production, use proper password hashing
95
- new_user = User(username=user.username, hashed_password=user.password, api_key=new_api_key, is_admin=False)
96
  db.add(new_user)
97
  db.commit()
98
  db.refresh(new_user)
 
 
 
 
99
  return new_user
100
 
101
- # 2. User Panel: Get current user info
102
  @app.get("/user/me", response_model=UserOut)
103
  def read_user_me(current_user: User = Depends(get_current_user)):
104
  return current_user
105
 
106
- # 3. Admin Panel: List all users (admin-only)
107
- @app.get("/admin/users", response_model=list[UserOut])
108
- def list_users(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
109
- if not current_user.is_admin:
110
- raise HTTPException(status_code=403, detail="Not authorized")
111
- users = db.query(User).all()
112
- return users
113
 
114
- # 4. Proxy Endpoint to access the main API
115
- class RequestPayload(BaseModel):
116
- prompt: str
 
 
 
 
 
117
 
 
 
 
 
 
 
118
  @app.post("/generate")
119
  def generate_image(payload: RequestPayload, current_user: User = Depends(get_current_user)):
120
  headers = {
@@ -130,16 +183,19 @@ def generate_image(payload: RequestPayload, current_user: User = Depends(get_cur
130
  raise HTTPException(status_code=response.status_code, detail="Error from main API")
131
  return response.json()
132
 
133
- # 5. New endpoint for users to test their API key
134
- @app.get("/user/test_api")
135
- def test_api(current_user: User = Depends(get_current_user)):
136
- return {"message": "API is working", "username": current_user.username, "credits": current_user.credits}
137
 
138
- # 6. New endpoint for admin to add credits to a user account
139
- class CreditPayload(BaseModel):
140
- username: str
141
- credits: int
 
 
 
142
 
 
143
  @app.post("/admin/add_credit")
144
  def add_credit(payload: CreditPayload, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
145
  if not current_user.is_admin:
@@ -150,16 +206,35 @@ def add_credit(payload: CreditPayload, current_user: User = Depends(get_current_
150
  user.credits += payload.credits
151
  db.commit()
152
  db.refresh(user)
153
- return {"message": f"Added {payload.credits} credits to user {user.username}. Total credits: {user.credits}"}
154
 
155
- # 7. Render Admin Panel UI
156
- @app.get("/admin/ui")
157
- def admin_ui(request: Request, current_user: User = Depends(get_current_user)):
158
  if not current_user.is_admin:
159
  raise HTTPException(status_code=403, detail="Not authorized")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  return templates.TemplateResponse("admin.html", {"request": request})
161
 
162
- # 8. Render User Panel UI
163
  @app.get("/user/ui")
164
- def user_ui(request: Request, current_user: User = Depends(get_current_user)):
165
  return templates.TemplateResponse("user.html", {"request": request})
 
1
  import os
2
  import secrets
3
+ import datetime
4
  import requests
5
  from fastapi import FastAPI, Depends, HTTPException, Header, Request
6
+ from sqlalchemy import create_engine, Column, Integer, String, Boolean, DateTime, ForeignKey, func
7
  from sqlalchemy.ext.declarative import declarative_base
8
+ from sqlalchemy.orm import sessionmaker, Session, relationship
9
  from pydantic import BaseModel
10
  from dotenv import load_dotenv
11
  from fastapi.templating import Jinja2Templates
 
14
  # Load environment variables from .env file
15
  load_dotenv()
16
 
17
+ # Environment variables
18
  MYSQL_USER = os.getenv("MYSQL_USER")
19
  MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
20
  MYSQL_HOST = os.getenv("MYSQL_HOST")
 
25
 
26
  DATABASE_URL = f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}/{MYSQL_DB}"
27
 
 
28
  engine = create_engine(DATABASE_URL)
29
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
30
  Base = declarative_base()
31
 
32
+ # ---------------------------
33
+ # Database Models
34
+ # ---------------------------
35
  class User(Base):
36
  __tablename__ = "users"
37
  id = Column(Integer, primary_key=True, index=True)
38
  username = Column(String(50), unique=True, index=True, nullable=False)
39
  hashed_password = Column(String(128), nullable=False)
 
40
  is_admin = Column(Boolean, default=False)
41
  credits = Column(Integer, default=0)
42
  created_at = Column(DateTime(timezone=True), server_default=func.now())
43
+ # Relationship to API keys
44
+ api_keys = relationship("APIKey", back_populates="owner", cascade="all, delete-orphan")
45
+
46
+ class APIKey(Base):
47
+ __tablename__ = "api_keys"
48
+ id = Column(Integer, primary_key=True, index=True)
49
+ key = Column(String(64), unique=True, index=True, nullable=False)
50
+ user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
51
+ expiry_date = Column(DateTime, nullable=True)
52
+ active = Column(Boolean, default=True)
53
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
54
+ owner = relationship("User", back_populates="api_keys")
55
 
 
56
  Base.metadata.create_all(bind=engine)
57
 
58
+ # ---------------------------
59
+ # FastAPI App & Config
60
+ # ---------------------------
61
+ app = FastAPI(title="Real API Key Platform")
62
 
 
63
  app.mount("/static", StaticFiles(directory="static"), name="static")
64
  templates = Jinja2Templates(directory="templates")
65
 
 
66
  def get_db():
67
  db = SessionLocal()
68
  try:
 
70
  finally:
71
  db.close()
72
 
 
73
  def generate_api_key() -> str:
74
  return secrets.token_hex(16)
75
 
76
+ # ---------------------------
77
+ # Pydantic Models
78
+ # ---------------------------
79
  class UserCreate(BaseModel):
80
  username: str
81
  password: str
 
83
  class UserOut(BaseModel):
84
  id: int
85
  username: str
 
86
  is_admin: bool
87
  credits: int
88
+ created_at: datetime.datetime
89
 
90
  class Config:
91
  orm_mode = True
92
 
93
+ class APIKeyOut(BaseModel):
94
+ id: int
95
+ key: str
96
+ expiry_date: datetime.datetime | None = None
97
+ active: bool
98
+ created_at: datetime.datetime
99
+
100
+ class Config:
101
+ orm_mode = True
102
+
103
+ class GenerateKeyPayload(BaseModel):
104
+ expiry_date: datetime.datetime | None = None
105
+
106
+ class RequestPayload(BaseModel):
107
+ prompt: str
108
 
109
+ class CreditPayload(BaseModel):
110
+ username: str
111
+ credits: int
112
+
113
+ class DeactivateKeyPayload(BaseModel):
114
+ key: str
115
 
116
+ # ---------------------------
117
+ # Authentication Dependency
118
+ # ---------------------------
119
+ def get_current_user(x_api_key: str = Header(...), db: Session = Depends(get_db)) -> User:
120
+ api_key_obj = db.query(APIKey).filter(APIKey.key == x_api_key, APIKey.active == True).first()
121
+ if not api_key_obj:
122
+ raise HTTPException(status_code=401, detail="Invalid or inactive API Key")
123
+ if api_key_obj.expiry_date and datetime.datetime.utcnow() > api_key_obj.expiry_date:
124
+ raise HTTPException(status_code=401, detail="API Key expired")
125
+ return api_key_obj.owner
126
+
127
+ # ---------------------------
128
+ # Endpoints
129
+ # ---------------------------
130
+
131
+ # Registration: Create a new user and generate a primary API key (no expiry)
132
  @app.post("/register", response_model=UserOut)
133
  def register(user: UserCreate, db: Session = Depends(get_db)):
134
  if db.query(User).filter(User.username == user.username).first():
135
  raise HTTPException(status_code=400, detail="Username already exists")
136
+ new_user = User(username=user.username, hashed_password=user.password, is_admin=False)
 
 
137
  db.add(new_user)
138
  db.commit()
139
  db.refresh(new_user)
140
+ primary_key = APIKey(key=generate_api_key(), user_id=new_user.id, expiry_date=None, active=True)
141
+ db.add(primary_key)
142
+ db.commit()
143
+ db.refresh(primary_key)
144
  return new_user
145
 
146
+ # User panel: Get current user details
147
  @app.get("/user/me", response_model=UserOut)
148
  def read_user_me(current_user: User = Depends(get_current_user)):
149
  return current_user
150
 
151
+ # List all API keys belonging to the current user
152
+ @app.get("/user/api_keys", response_model=list[APIKeyOut])
153
+ def get_user_api_keys(current_user: User = Depends(get_current_user)):
154
+ return current_user.api_keys
 
 
 
155
 
156
+ # Allow user to generate a new API key (with optional expiry date)
157
+ @app.post("/user/generate_key", response_model=APIKeyOut)
158
+ def generate_key(payload: GenerateKeyPayload, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
159
+ new_key = APIKey(key=generate_api_key(), user_id=current_user.id, expiry_date=payload.expiry_date, active=True)
160
+ db.add(new_key)
161
+ db.commit()
162
+ db.refresh(new_key)
163
+ return new_key
164
 
165
+ # Test endpoint for users
166
+ @app.get("/user/test_api")
167
+ def test_api(current_user: User = Depends(get_current_user)):
168
+ return {"message": "API is working", "username": current_user.username, "credits": current_user.credits}
169
+
170
+ # Proxy endpoint: Forwards request to main API using the secured main API key
171
  @app.post("/generate")
172
  def generate_image(payload: RequestPayload, current_user: User = Depends(get_current_user)):
173
  headers = {
 
183
  raise HTTPException(status_code=response.status_code, detail="Error from main API")
184
  return response.json()
185
 
186
+ # ---------------------------
187
+ # Admin Endpoints
188
+ # ---------------------------
 
189
 
190
+ # List all users (admin-only)
191
+ @app.get("/admin/users", response_model=list[UserOut])
192
+ def list_users(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
193
+ if not current_user.is_admin:
194
+ raise HTTPException(status_code=403, detail="Not authorized")
195
+ users = db.query(User).all()
196
+ return users
197
 
198
+ # Add credits to a user's account (admin-only)
199
  @app.post("/admin/add_credit")
200
  def add_credit(payload: CreditPayload, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
201
  if not current_user.is_admin:
 
206
  user.credits += payload.credits
207
  db.commit()
208
  db.refresh(user)
209
+ return {"message": f"Added {payload.credits} credits to {user.username}. Total credits: {user.credits}"}
210
 
211
+ # List all API keys in the system (admin-only)
212
+ @app.get("/admin/api_keys", response_model=list[APIKeyOut])
213
+ def list_all_api_keys(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
214
  if not current_user.is_admin:
215
  raise HTTPException(status_code=403, detail="Not authorized")
216
+ keys = db.query(APIKey).all()
217
+ return keys
218
+
219
+ # Deactivate an API key (admin-only)
220
+ @app.post("/admin/deactivate_key")
221
+ def deactivate_key(payload: DeactivateKeyPayload, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
222
+ if not current_user.is_admin:
223
+ raise HTTPException(status_code=403, detail="Not authorized")
224
+ key_obj = db.query(APIKey).filter(APIKey.key == payload.key).first()
225
+ if not key_obj:
226
+ raise HTTPException(status_code=404, detail="API Key not found")
227
+ key_obj.active = False
228
+ db.commit()
229
+ return {"message": f"API Key {payload.key} deactivated."}
230
+
231
+ # ---------------------------
232
+ # UI Endpoints (Panels)
233
+ # ---------------------------
234
+ @app.get("/admin/ui")
235
+ def admin_ui(request: Request):
236
  return templates.TemplateResponse("admin.html", {"request": request})
237
 
 
238
  @app.get("/user/ui")
239
+ def user_ui(request: Request):
240
  return templates.TemplateResponse("user.html", {"request": request})