Raaniel commited on
Commit
d09f6cd
1 Parent(s): e86092a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -5,21 +5,37 @@ import torchvision
5
  from PIL import Image
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download
8
-
9
- # Configure device
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
-
12
- transform = torchvision.transforms.Compose([
13
- torchvision.transforms.Resize((224, 224)),
14
- torchvision.transforms.ToTensor(),
15
- torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.225, 0.225, 0.225])
16
- ])
17
 
18
  REPO_ID = "Raaniel/model-smoke"
19
  MODEL_FILE_NAME = "model_smoke.pt"
 
 
 
 
20
  checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE_NAME)
21
 
22
- model = torch.jit.load(checkpoint_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  model = model.to(device)
24
 
25
  classes = ["chmury", 'inne', "dym"]
 
5
  from PIL import Image
6
  import gradio as gr
7
  from huggingface_hub import hf_hub_download
8
+ import torch.nn as nn
9
+ import timm
 
 
 
 
 
 
 
10
 
11
  REPO_ID = "Raaniel/model-smoke"
12
  MODEL_FILE_NAME = "model_smoke.pt"
13
+ USE_CUDA = torch.cuda.is_available()
14
+ num_classes = 3
15
+
16
+ # Download the model
17
  checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE_NAME)
18
 
19
+ # Load the checkpoint
20
+ state = torch.load(checkpoint_path, map_location=torch.device('cuda' if USE_CUDA else 'cpu'))
21
+
22
+ # Create the model and modify it
23
+ model = timm.create_model('mobilenetv3_small_050', pretrained=True)
24
+ num_features = model.classifier.in_features
25
+
26
+ # Additional linear and dropout layers
27
+ model.classifier = nn.Sequential(
28
+ nn.Linear(num_features, 256), # Additional linear layer
29
+ nn.ReLU(inplace=True),
30
+ nn.Dropout(0.5),
31
+ nn.Linear(256, num_classes) # Final classification layer
32
+ )
33
+
34
+ # Load the model weights
35
+ model.load_state_dict(state['weights'])
36
+
37
+ # Move model to the appropriate device
38
+ device = torch.device('cuda' if USE_CUDA else 'cpu')
39
  model = model.to(device)
40
 
41
  classes = ["chmury", 'inne', "dym"]