# Author: dhairyashah
# Last commit: 096a01b "Update app.py" (verified)
# NOTE: keep `spaces` first; on Hugging Face ZeroGPU it presumably must be
# imported before CUDA is initialised by torch — confirm before moving it.
import spaces
import base64
import os
import uuid

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
from flask import Flask, request, jsonify
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from werkzeug.utils import secure_filename
app = Flask(__name__)

# Configuration: uploads are written to a folder relative to the working
# directory, only the listed video extensions are accepted, and request
# bodies are capped at 16 MiB via Flask's MAX_CONTENT_LENGTH setting.
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'webm'}
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Device configuration: first CUDA GPU when available, otherwise CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Configure MTCNN with adjusted thresholds
# Face detector. This file only ever calls mtcnn.detect() on it (in
# process_frame). NOTE(review): margin and post_process presumably only affect
# MTCNN's face-extraction path (forward/extract_face), not detect() — confirm
# before relying on them.
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE,
thresholds=[0.7, 0.8, 0.8], # Adjust these thresholds for P-Net, R-Net, O-Net
margin=20, min_face_size=50).to(DEVICE).eval()
# Real/fake classifier: InceptionResnetV1 initialised from VGGFace2 weights with
# a single output unit; process_frame turns that output into a score with
# torch.sigmoid.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
# Fine-tuned weights. The checkpoint tensors are mapped to CPU first; the model
# is moved to DEVICE afterwards.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# consider weights_only=True if the checkpoint holds only tensors/primitives.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
# Possibly redundant with device=DEVICE in the constructor, but harmless.
model.to(DEVICE)
model.eval()
# GradCAM setup
# Heat-maps are built from the activations of the last submodule of
# block8.branch1; ClassifierOutputTarget(0) explains output index 0, i.e. the
# model's single output unit.
target_layers = [model.block8.branch1[-1]]
cam = GradCAM(model=model, target_layers=target_layers)
targets = [ClassifierOutputTarget(0)]
def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared
    case-insensitively; a name with no dot at all is rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
def filter_low_quality_detections(detection, min_size=(50, 50), min_prob=0.9):
    """Return the first detected face box that is confident and large enough.

    Args:
        detection: The tuple returned by ``mtcnn.detect``; element 0 holds the
            boxes (``None`` when nothing was found) and element 1 the matching
            probabilities. Each box is ``(x1, y1, x2, y2)``.
        min_size: Minimum ``(width, height)`` a box must have.
        min_prob: Minimum detection probability; boxes scoring below it are
            skipped. Defaults to 0.9, the previously hard-coded cut-off.

    Returns:
        The first box, in detector order, that passes both checks, or
        ``None`` when there is no detection or no box qualifies.
    """
    if detection is None or detection[0] is None:
        return None
    boxes, probs = detection[0], detection[1]
    min_width, min_height = min_size
    for box, prob in zip(boxes, probs):
        if prob < min_prob:  # Filter out detections with low confidence
            continue
        if (box[2] - box[0] < min_width) or (box[3] - box[1] < min_height):
            continue
        return box
    return None
@spaces.GPU
def process_frame(frame):
    """Classify the first usable face in a BGR frame as real or fake.

    The face is found with ``mtcnn.detect`` and ``filter_low_quality_detections``,
    cropped, resized to 256x256, scored by ``model``, and explained with a
    Grad-CAM heat-map from ``cam``.

    Args:
        frame: A BGR image, as produced by ``cv2.VideoCapture.read``.

    Returns:
        ``(prediction, score, visualization)``:

        - ``prediction`` is ``"fake"`` when the sigmoid score is >= 0.5,
          otherwise ``"real"``.
        - ``score`` is that sigmoid output as a Python float.
        - ``visualization`` is the Grad-CAM overlay from ``show_cam_on_image``
          (RGB, since ``use_rgb=True``).

        Returns ``(None, None, None)`` when no usable face is found.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    face_box = filter_low_quality_detections(mtcnn.detect(rgb_frame))
    if face_box is None:
        return None, None, None

    # Detector boxes are not clipped to the image and often extend slightly
    # past the border for faces near an edge. Clamp them rather than discarding
    # the frame (facenet_pytorch's extract_face clamps the same way). Clamping
    # both ends also prevents negative indices from wrapping around in slicing.
    frame_h, frame_w = rgb_frame.shape[:2]
    x1, y1, x2, y2 = map(int, face_box)
    x1, x2 = min(max(x1, 0), frame_w), min(max(x2, 0), frame_w)
    y1, y2 = min(max(y1, 0), frame_h), min(max(y2, 0), frame_h)

    face = rgb_frame[y1:y2, x1:x2]
    if face.size == 0:
        return None, None, None

    face = cv2.resize(face, (256, 256))
    # HWC uint8 -> 1xCxHxW float32 in [0, 1] on the model's device.
    face = torch.from_numpy(face).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    face = face.to(torch.float32) / 255.0

    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
    score = output.item()
    prediction = "fake" if score >= 0.5 else "real"

    # Grad-CAM backpropagates through the model, so it runs outside no_grad.
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    return prediction, score, visualization
@spaces.GPU
def analyze_video(video_path, sample_rate=30, top_n=5, detection_threshold=0.5):
    """Run frame-level deepfake detection on a video and summarise the results.

    Frames 0, ``sample_rate``, ``2 * sample_rate``, ... are passed to
    ``process_frame``. Frames in which no usable face is found are skipped.

    Args:
        video_path: Path of a video file readable by ``cv2.VideoCapture``.
        sample_rate: Analyse one frame out of every ``sample_rate`` frames.
        top_n: Maximum number of analysed frames to return, highest score
            first.
        detection_threshold: Only echoed back in the result. The real/fake
            cut-off that is actually applied is the 0.5 in ``process_frame``.

    Returns:
        ``None`` if no frame produced a prediction (this includes a file that
        could not be opened). Otherwise a dict with these keys:

        - ``fake_percentage``: share of analysed frames predicted "fake", in %.
        - ``is_likely_deepfake``: ``fake_percentage >= 60``.
        - ``top_frames``: up to ``top_n`` frame dicts (``frame_number``,
          ``prediction``, ``confidence``, ``visualization``), sorted by
          ``confidence`` in descending order.
        - ``model_confidence``: 1 minus the population variance of the
          per-frame scores.
        - ``total_frames_analyzed``, ``average_confidence_score`` and
          ``detection_threshold``.
    """
    cap = cv2.VideoCapture(video_path)
    frame_count = 0
    fake_count = 0
    frames_info = []
    confidence_scores = []
    try:
        while cap.isOpened():
            sampled = frame_count % sample_rate == 0
            # read() is grab() + retrieve(). Skipped frames only need grab(),
            # which advances the stream without retrieving them as an image.
            ret, frame = cap.read() if sampled else (cap.grab(), None)
            if not ret:
                break
            if sampled:
                prediction, confidence, visualization = process_frame(frame)
                if prediction is not None:
                    confidence_scores.append(confidence)
                    if prediction == "fake":
                        fake_count += 1
                    frames_info.append({
                        'frame_number': frame_count,
                        'prediction': prediction,
                        'confidence': confidence,
                        'visualization': visualization
                    })
            frame_count += 1
    finally:
        # Release the capture even if process_frame raises, so the handle on
        # the uploaded file is not leaked.
        cap.release()

    if not confidence_scores:
        return None

    total_processed = len(confidence_scores)
    fake_percentage = (fake_count / total_processed) * 100
    average_confidence = sum(confidence_scores) / total_processed
    # Scores are sigmoid outputs in [0, 1], so this lies in [0.75, 1]; it is
    # higher when the per-frame scores agree with each other.
    model_confidence = 1 - (
        sum((score - average_confidence) ** 2 for score in confidence_scores) / total_processed
    )
    frames_info.sort(key=lambda info: info['confidence'], reverse=True)
    return {
        'fake_percentage': fake_percentage,
        'is_likely_deepfake': fake_percentage >= 60,
        'top_frames': frames_info[:top_n],
        'model_confidence': model_confidence,
        'total_frames_analyzed': total_processed,
        'average_confidence_score': average_confidence,
        'detection_threshold': detection_threshold
    }
@app.route('/analyze', methods=['POST'])
def analyze_video_api():
    """POST /analyze: run deepfake analysis on an uploaded video.

    Expects a multipart form field ``video`` whose filename extension passes
    ``allowed_file``. The upload is stored in ``UPLOAD_FOLDER`` only for the
    duration of the request and is always deleted afterwards.

    Returns:
        - 200 with the ``analyze_video`` result; each top frame's
          ``visualization`` is a base64-encoded PNG.
        - 400 for a missing or invalid upload, or when no frame could be
          processed.
        - 500 with the error message if saving, analysis or encoding raised.
    """
    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400
    file = request.files['video']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not (file and allowed_file(file.filename)):
        return jsonify({'error': f'Invalid file type: {file.filename}'}), 400

    # Random prefix: concurrent uploads sharing a filename must not overwrite
    # (or delete) each other's files.
    filename = f"{uuid.uuid4().hex}_{secure_filename(file.filename)}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    try:
        file.save(filepath)
        result = analyze_video(filepath)
        if not result:
            return jsonify({'error': 'No frames could be processed'}), 400
        # Convert numpy images to base64-encoded PNG strings.
        for frame in result['top_frames']:
            # The overlay is RGB (show_cam_on_image(..., use_rgb=True)), but
            # cv2.imencode expects BGR; without this the colours come out swapped.
            bgr = cv2.cvtColor(frame['visualization'], cv2.COLOR_RGB2BGR)
            ok, png = cv2.imencode('.png', bgr)
            if not ok:
                raise RuntimeError('PNG encoding of visualization failed')
            frame['visualization'] = base64.b64encode(png).decode('utf-8')
        return jsonify(result), 200
    except Exception as e:
        app.logger.exception('Video analysis failed')
        return jsonify({'error': str(e)}), 500
    finally:
        # Single cleanup point for every exit path. The file may not exist if
        # file.save itself failed.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
if __name__ == '__main__':
    # Development server, listening on all interfaces.
    host, port = '0.0.0.0', 7860
    app.run(host=host, port=port)