luoamy3 commited on
Commit
b387f18
·
verified ·
1 Parent(s): 2ce2ff4

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +82 -0
README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ lat mean = 39.951614360789364
2
+
3
+ lat std = 0.0007384844437841076
4
+
5
+ lon mean = -75.19140262762761
6
+
7
+ lon std = 0.0007284591160342192
8
+
9
+ **To load model:**
10
+ ```
11
+ from huggingface_hub import hf_hub_download
12
+ import torch
13
+
14
+ repo_id = "thestalkers/ImageToGPSproject_base_resnet18_v2"
15
+ filename = "resnet_gps_regressor_complete.pth"
16
+
17
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
18
+
19
+ # Load the model using torch
20
+ model_test = torch.load(model_path)
21
+ model_test.eval() # Set the model to evaluation mode
22
+ ```
23
+
24
+ **Load a hf dataset:**
25
+ ```
26
+ from datasets import load_dataset, Image
27
+
28
+ dataset_test = load_dataset("gydou/released_img", split="train")
29
+
30
+ inference_transform = transforms.Compose([
31
+ transforms.Resize((224, 224)),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
34
+ std=[0.229, 0.224, 0.225])
35
+ ])
36
+
37
+ test_dataset = GPSImageDataset(
38
+ hf_dataset=dataset_test,
39
+ transform=inference_transform,
40
+ lat_mean=lat_mean,
41
+ lat_std=lat_std,
42
+ lon_mean=lon_mean,
43
+ lon_std=lon_std
44
+ )
45
+ test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
46
+ ```
47
+
48
+ **Perform inference:**
49
+ ```
50
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
51
+
52
+ # Initialize lists to store predictions and actual values
53
+ all_preds = []
54
+ all_actuals = []
55
+
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ print(f'Using device: {device}')
58
+
59
+ with torch.no_grad():
60
+ for images, gps_coords in test_dataloader:
61
+ images, gps_coords = images.to(device), gps_coords.to(device)
62
+
63
+ outputs = model_test(images)
64
+
65
+ # Denormalize predictions and actual values
66
+ preds = outputs.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
67
+ actuals = gps_coords.cpu() * torch.tensor([lat_std, lon_std]) + torch.tensor([lat_mean, lon_mean])
68
+
69
+ all_preds.append(preds)
70
+ all_actuals.append(actuals)
71
+
72
+ # Concatenate all batches
73
+ all_preds = torch.cat(all_preds).numpy()
74
+ all_actuals = torch.cat(all_actuals).numpy()
75
+
76
+ # Compute error metrics
77
+ mae = mean_absolute_error(all_actuals, all_preds)
78
+ rmse = mean_squared_error(all_actuals, all_preds, squared=False)
79
+
80
+ print(f'Mean Absolute Error: {mae}')
81
+ print(f'Root Mean Squared Error: {rmse}')
82
+ ```