File size: 2,564 Bytes
6221b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# inference.py
# Streamlit app that runs SRCNN, VDSR and EDSR checkpoints on an uploaded image
# and shows the three outputs side by side.
import os
# Set before `import torch` so it is in the environment when the native libraries load.
# NOTE(review): presumably works around Intel OpenMP's "OMP: Error #15 ... already
# initialized" abort when two OpenMP runtimes get loaded (e.g. torch + MKL/numpy).
# This only suppresses the check; it can mask real conflicts — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from models.srcnn import SRCNN
from models.vdsr import VDSR
from models.edsr import EDSR

@st.cache_resource
def load_model(model_name, checkpoint_dir='checkpoints'):
    """Build a super-resolution model and load its best checkpoint onto the CPU.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``. Repeated calls with the same
    arguments return the *same* model instance. Treat it as read-only and use it
    for inference only.

    Args:
        model_name: ``'SRCNN'``, ``'VDSR'`` or ``'EDSR'`` (case-insensitive).
        checkpoint_dir: Directory containing ``<name>_best.pth`` files, where
            ``<name>`` is ``model_name`` lower-cased.

    Returns:
        The model with weights loaded, switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the known architectures.
            Previously any unknown name silently built an EDSR.
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint does not match the architecture.
    """
    architectures = {'SRCNN': SRCNN, 'VDSR': VDSR, 'EDSR': EDSR}
    try:
        model_cls = architectures[model_name.upper()]
    except KeyError:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(architectures)}"
        ) from None
    model = model_cls()

    checkpoint_path = os.path.join(checkpoint_dir, f'{model_name.lower()}_best.pth')
    # NOTE(review): torch.load unpickles the file. Checkpoints here are local,
    # trusted files; never point this at user-supplied paths on torch versions
    # where weights_only defaults to False.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model

def process_image(image, model):
    """Run a luminance-only super-resolution model on ``image``.

    The image is converted to YCbCr and only the Y (luminance) channel goes
    through ``model``. The original Cb/Cr chroma channels are merged back
    afterwards.

    Args:
        image: PIL image. It is converted to RGB first, so any mode Pillow can
            convert to RGB is accepted. Alpha is dropped, since YCbCr has none.
        model: Torch module called on a ``(1, 1, H, W)`` float tensor in [0, 1].
            Assumed to return ``(1, 1, H', W')`` — TODO confirm for each model.
            NOTE(review): classic SRCNN/VDSR expect an input that was already
            bicubic-upscaled; nothing here upscales first.

    Returns:
        RGB PIL image. If the model changed the spatial size, the chroma
        channels are bicubically resized to match the new luminance.
    """
    # Normalise the mode first: not every mode (palette, 16-bit, LA, ...)
    # converts directly to YCbCr.
    ycbcr = image.convert('RGB').convert('YCbCr')
    y, cb, cr = ycbcr.split()

    # 'L' image (H, W) -> float tensor (1, H, W) in [0, 1] -> add batch axis.
    input_tensor = transforms.ToTensor()(y).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # Remove only the batch and channel axes. A bare squeeze() would also
    # collapse a height or width of 1 and break the image shape.
    output = output.squeeze(0).squeeze(0).clamp(0, 1).cpu().numpy()
    # Round instead of truncating. Otherwise float error (e.g. 2.9999998)
    # shifts pixels down one level, and even an identity model is lossy.
    output_y = Image.fromarray((output * 255.0).round().astype(np.uint8))

    # Image.merge requires all bands to have the same size.
    if output_y.size != y.size:
        cb = cb.resize(output_y.size, Image.BICUBIC)
        cr = cr.resize(output_y.size, Image.BICUBIC)

    output_ycbcr = Image.merge('YCbCr', [output_y, cb, cr])
    return output_ycbcr.convert('RGB')

def main():
    """Render the comparison page.

    The page lets the user upload an image and shows the input plus the
    SRCNN, VDSR and EDSR outputs in three side-by-side columns.
    """
    st.title("Super Resolution Model Comparison")
    st.write("Upload a low-resolution image to compare SRCNN, VDSR, and EDSR models")

    uploaded_file = st.file_uploader("Choose an image", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        # Nothing uploaded yet: only the header and uploader are shown.
        return

    try:
        input_image = Image.open(uploaded_file)
        # Image.open is lazy. Force decoding here so a corrupt upload is
        # reported once, instead of failing inside every model column.
        input_image.load()
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded file as an image: {err}")
        return

    st.subheader("Input Image")
    st.image(input_image, caption="Original Image")

    model_names = ('SRCNN', 'VDSR', 'EDSR')
    for column, model_name in zip(st.columns(len(model_names)), model_names):
        with column:
            st.subheader(model_name)
            try:
                model = load_model(model_name)
                output = process_image(input_image, model)
            except (OSError, RuntimeError, ValueError) as err:
                # Missing checkpoint, weight/shape mismatch, or band-size
                # mismatch: report it in this column and keep rendering the
                # other models.
                st.error(f"{model_name} failed: {err}")
                continue
            st.image(output, caption=f"{model_name} Output")

# `streamlit run inference.py` executes this file with __name__ == "__main__",
# so the page is rendered there. Merely importing the module (e.g. to reuse
# load_model / process_image) does not render anything.
if __name__ == "__main__":
    main()