amaye15 committed on
Commit
170fa35
1 Parent(s): b09c427

deploy app

Browse files
Files changed (2) hide show
  1. app.py +59 -0
  2. requirements.txt +66 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+
6
# Hub checkpoints: preprocessing comes from the base SwinV2 checkpoint,
# while the classification weights come from the orientation model.
PROCESSOR_CHECKPOINT = "microsoft/swinv2-base-patch4-window16-256"
MODEL_CHECKPOINT = "amaye15/SwinV2-Base-Image-Orientation-Fixer"

# Both objects are created once at import time and reused by every request.
feature_extractor = AutoImageProcessor.from_pretrained(PROCESSOR_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(MODEL_CHECKPOINT)
13
+
14
+
15
def predict_image(image):
    """Classify one image and return every label with its probability.

    Args:
        image: The image to classify. The Gradio interface in this file
            passes a PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping each label in ``model.config.id2label`` to its
        softmax probability (a Python float), ordered from most to least
        likely.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque
        # failure from inside the image processor.
        raise gr.Error("Please upload an image before submitting.")

    # Preprocess into the tensor dict the model expects.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image yields a batch of one; index row 0 explicitly rather
    # than squeeze(), which would also collapse a one-class output to 0-d.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()

    result = {
        model.config.id2label[idx]: prob
        for idx, prob in enumerate(probabilities)
    }
    # Highest probability first.
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
32
+
33
+
34
# Markdown text passed to the interface as its description.
description = """
### Overview
This application is a web-based interface built using Gradio that allows users to upload images and receive class predictions with probabilities.
It utilizes a pre-trained SwinV2 model from Hugging Face.

### How It Works
1. **Image Upload**: Users upload an image which is then processed and classified by the model.
2. **Feature Extraction**: The image is preprocessed using a feature extractor that converts it into a format suitable for the model.
3. **Prediction**: The model predicts the class probabilities using a softmax function on the output logits.
4. **Results**: The results are displayed as a sorted list of classes with their corresponding probabilities, showing the most likely class first.

Enjoy exploring the capabilities of this advanced image classification model!
"""

# Input widget hands the upload to predict_image as a PIL image; the output
# widget renders the returned label -> probability mapping with no cap on
# the number of classes shown.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=None)

# Wire the prediction function to the widgets.
iface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs=label_output,
    title="Image Orientation",
    description=description,
)

# Start serving the app.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.4.0
4
+ certifi==2024.8.30
5
+ charset-normalizer==3.3.2
6
+ click==8.1.7
7
+ contourpy==1.3.0
8
+ cycler==0.12.1
9
+ fastapi==0.112.2
10
+ ffmpy==0.4.0
11
+ filelock==3.15.4
12
+ fonttools==4.53.1
13
+ fsspec==2024.6.1
14
+ gradio==4.42.0
15
+ gradio_client==1.3.0
16
+ h11==0.14.0
17
+ httpcore==1.0.5
18
+ httpx==0.27.2
19
+ huggingface-hub==0.24.6
20
+ idna==3.8
21
+ importlib_resources==6.4.4
22
+ Jinja2==3.1.4
23
+ kiwisolver==1.4.5
24
+ markdown-it-py==3.0.0
25
+ MarkupSafe==2.1.5
26
+ matplotlib==3.9.2
27
+ mdurl==0.1.2
28
+ mpmath==1.3.0
29
+ networkx==3.3
30
+ numpy==2.1.0
31
+ orjson==3.10.7
32
+ packaging==24.1
33
+ pandas==2.2.2
34
+ pdf2image==1.17.0
35
+ pillow==10.4.0
36
+ pydantic==2.8.2
37
+ pydantic_core==2.20.1
38
+ pydub==0.25.1
39
+ Pygments==2.18.0
40
+ pyparsing==3.1.4
41
+ python-dateutil==2.9.0.post0
42
+ python-multipart==0.0.9
43
+ pytz==2024.1
44
+ PyYAML==6.0.2
45
+ regex==2024.7.24
46
+ requests==2.32.3
47
+ rich==13.8.0
48
+ ruff==0.6.3
49
+ safetensors==0.4.4
50
+ semantic-version==2.10.0
51
+ shellingham==1.5.4
52
+ six==1.16.0
53
+ sniffio==1.3.1
54
+ starlette==0.38.4
55
+ sympy==1.13.2
56
+ tokenizers==0.19.1
57
+ tomlkit==0.12.0
58
+ torch==2.4.0
59
+ tqdm==4.66.5
60
+ transformers==4.44.2
61
+ typer==0.12.5
62
+ typing_extensions==4.12.2
63
+ tzdata==2024.1
64
+ urllib3==2.2.2
65
+ uvicorn==0.30.6
66
+ websockets==12.0