yunusserhat committed
Commit 894bc0c · verified · 1 Parent(s): 26c1e06

Create APP

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +35 -35
  2. README.md +13 -13
  3. app.py +128 -0
  4. configs/computer/a100.yaml +8 -0
  5. configs/computer/cluster-node-a100.yaml +8 -0
  6. configs/computer/cluster-node-v100.yaml +8 -0
  7. configs/computer/cpu.yaml +8 -0
  8. configs/computer/v100.yaml +8 -0
  9. configs/config.yaml +89 -0
  10. configs/dataset/baselines/im2gps.yaml +16 -0
  11. configs/dataset/baselines/im2gps3k.yaml +16 -0
  12. configs/dataset/baselines/yfcc4k.yaml +16 -0
  13. configs/dataset/osv5m.yaml +46 -0
  14. configs/dataset/osv5m_contrastive.yaml +34 -0
  15. configs/dataset/osv5m_contrastive_best.yaml +37 -0
  16. configs/dataset/osv5m_text_contrastive.yaml +34 -0
  17. configs/dataset/test_transform/center_crop.yaml +12 -0
  18. configs/dataset/test_transform/clip.yaml +2 -0
  19. configs/dataset/test_transform/fast_clip.yaml +12 -0
  20. configs/dataset/test_transform/fast_resnet.yaml +12 -0
  21. configs/dataset/test_transform/none.yaml +6 -0
  22. configs/dataset/train_transform/augmentation.yaml +85 -0
  23. configs/dataset/train_transform/center_crop.yaml +14 -0
  24. configs/dataset/train_transform/clip.yaml +2 -0
  25. configs/dataset/train_transform/fast_clip.yaml +12 -0
  26. configs/dataset/train_transform/fast_resnet.yaml +12 -0
  27. configs/dataset/train_transform/none.yaml +7 -0
  28. configs/exp/DinoV2.yaml +18 -0
  29. configs/exp/ResNet.yaml +21 -0
  30. configs/exp/base_model.yaml +19 -0
  31. configs/exp/best_model.yaml +25 -0
  32. configs/exp/classification_area.yaml +19 -0
  33. configs/exp/classification_cell.yaml +19 -0
  34. configs/exp/classification_cell_hier.yaml +20 -0
  35. configs/exp/classification_city.yaml +19 -0
  36. configs/exp/classification_city_hier.yaml +20 -0
  37. configs/exp/classification_country.yaml +19 -0
  38. configs/exp/classification_region copy.yaml +19 -0
  39. configs/exp/classification_region.yaml +19 -0
  40. configs/exp/clip_L_14_DataComp.yaml +18 -0
  41. configs/exp/clip_L_14_Laion.yaml +18 -0
  42. configs/exp/clip_L_14_OpenAI.yaml +18 -0
  43. configs/exp/clip_bigG_14_Laion.yaml +18 -0
  44. configs/exp/contrastive_area.yaml +20 -0
  45. configs/exp/contrastive_cell.yaml +20 -0
  46. configs/exp/contrastive_city.yaml +20 -0
  47. configs/exp/contrastive_country.yaml +20 -0
  48. configs/exp/contrastive_region.yaml +20 -0
  49. configs/exp/contrastive_text.yaml +22 -0
  50. configs/exp/eval_best_model.yaml +29 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: Location Predictor
- emoji: 📚
- colorFrom: pink
- colorTo: red
- sdk: gradio
- sdk_version: 4.38.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Location Predictor
+ emoji: 🌍
+ colorFrom: red
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.38.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ from geoclip import GeoCLIP
+ from PIL import Image
+ import tempfile
+ from pathlib import Path
+ import gradio as gr
+ import spaces
+ from geopy.geocoders import Nominatim
+ from transformers import CLIPProcessor, CLIPModel
+ from torchvision import transforms
+ import reverse_geocoder as rg
+ from models.huggingface import Geolocalizer
+ import folium
+ import json
+ from geopy.exc import GeocoderTimedOut
+
+ if torch.cuda.is_available():
+     geoclip_model = GeoCLIP().to("cuda")
+ else:
+     geoclip_model = GeoCLIP()
+
+ geolocator = Nominatim(user_agent="predictGeolocforImage")
+
+ streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
+ streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
+ labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
+
+ IMAGE_SIZE = (224, 224)
+ GEOLOC_MODEL_NAME = "osv5m/baseline"
+ geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
+ geoloc_model.eval()
+
+ def transform_image(image):
+     transform = transforms.Compose([
+         transforms.Resize(IMAGE_SIZE),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     return transform(image).unsqueeze(0)
+
+ def create_map(lat, lon):
+     m = folium.Map(location=[lat, lon], zoom_start=4)
+     folium.Marker([lat, lon]).add_to(m)
+     map_html = m._repr_html_()
+     return map_html
+
+ def get_country_coordinates(country_name):
+     try:
+         location = geolocator.geocode(country_name, timeout=10)
+         if location:
+             return location.latitude, location.longitude
+     except GeocoderTimedOut:
+         return None
+     return None
+
+ @spaces.GPU
+ def predict_geoclip(image):
+     with tempfile.TemporaryDirectory() as tmp_dir:
+         tmppath = Path(tmp_dir) / "tmp.jpg"
+         image.save(str(tmppath))
+         top_pred_gps, top_pred_prob = geoclip_model.predict(str(tmppath), top_k=50)
+
+     predictions = []
+     for i in range(1):
+         lat, lon = top_pred_gps[i]
+         probpercent = top_pred_prob[i] * 100
+         location = geolocator.reverse((lat, lon), exactly_one=True)
+         address = location.raw['address']
+         city = address.get('city', '')
+         country = address.get('country', '')
+         prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
+         predictions.append(prediction)
+
+     map_html = create_map(lat, lon)
+     return "\n".join(predictions), map_html
+
+ @spaces.GPU
+ def classify_streetclip(image):
+     inputs = streetclip_processor(text=labels, images=image, return_tensors="pt", padding=True)
+     with torch.no_grad():
+         outputs = streetclip_model(**inputs)
+     logits_per_image = outputs.logits_per_image
+     prediction = logits_per_image.softmax(dim=1)
+     confidences = {labels[i]: float(prediction[0][i].item()) for i in range(len(labels))}
+
+     sorted_confidences = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
+     top_label, top_confidence = sorted_confidences[0]
+     coords = get_country_coordinates(top_label)
+     map_html = create_map(*coords) if coords else "Map not available"
+     return f"Country: {top_label}", map_html
+
+ def infer(image):
+     try:
+         img_tensor = transform_image(image)
+         gps_radians = geoloc_model(img_tensor)
+         gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
+         lat, lon = gps_degrees[0], gps_degrees[1]
+         location_query = rg.search((lat, lon))[0]
+         location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
+         map_html = create_map(lat, lon)
+         return f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {location_query['admin1']} - {location_query['cc']}", map_html
+     except Exception as e:
+         return f"Failed to predict the location: {e}", None
+
+ geoclip_interface = gr.Interface(
+     fn=predict_geoclip,
+     inputs=gr.Image(type="pil", label="Upload Image", elem_id="geoclip_image_input"),
+     outputs=[gr.Textbox(label="Prediction", elem_id="geoclip_output"), gr.HTML(label="Map", elem_id="geoclip_map_output")],
+     title="GeoCLIP"
+ )
+
+ streetclip_interface = gr.Interface(
+     fn=classify_streetclip,
+     inputs=gr.Image(type="pil", label="Upload Image", elem_id="streetclip_image_input"),
+     outputs=[gr.Textbox(label="Prediction", elem_id="streetclip_output"), gr.HTML(label="Map", elem_id="streetclip_map_output")],
+     title="StreetCLIP"
+ )
+
+ osv5m_interface = gr.Interface(
+     fn=infer,
+     inputs=gr.Image(label="Upload Image", type="pil", elem_id="osv5m_image_input"),
+     outputs=[gr.Textbox(label="Prediction", elem_id="result_text"), gr.HTML(label="Map", elem_id="map_output")],
+     title="OSV-5M Baseline"
+ )
+
+ demo = gr.TabbedInterface([geoclip_interface, streetclip_interface, osv5m_interface], tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"])
+
+ demo.launch()
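A quick local smoke test for the OSV-5M baseline path above (a sketch, not part of the commit; it assumes the dependencies imported in app.py are installed, the osv5m/baseline weights are downloadable, and example.jpg is any local photo — the filename is illustrative):

    # hypothetical check: run the baseline geolocalizer on one image
    from PIL import Image

    img = Image.open("example.jpg").convert("RGB")
    text, map_html = infer(img)   # infer() is defined in app.py above
    print(text)                   # e.g. "Latitude: ..., Longitude: ... - Country: ..."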
configs/computer/a100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 1
+ progress_bar_refresh_rate: 2
+ num_workers: 8
+ sync_batchnorm: False
+ accelerator: gpu
+ precision: 32
+ strategy: auto
+ num_nodes: 1
configs/computer/cluster-node-a100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 8
+ num_workers: 8
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: True
+ accelerator: gpu
+ precision: 32
+ strategy: ddp
+ num_nodes: 1
configs/computer/cluster-node-v100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 4
+ num_workers: 10
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: True
+ accelerator: gpu
+ precision: 32
+ strategy: ddp
+ num_nodes: 1
configs/computer/cpu.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: null
+ num_workers: 0
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: False
+ accelerator: cpu
+ precision: 32
+ strategy: auto
+ num_nodes: null
configs/computer/v100.yaml ADDED
@@ -0,0 +1,8 @@
+ devices: 1
+ num_workers: 10
+ progress_bar_refresh_rate: 2
+ sync_batchnorm: False
+ accelerator: gpu
+ precision: 32
+ strategy: auto
+ num_nodes: 1
configs/config.yaml ADDED
@@ -0,0 +1,89 @@
+ defaults:
+   - model: default
+   - computer: v100
+   - dataset: osv5m
+   - _self_
+   - exp: ???
+
+ model:
+   val_metrics:
+     _target_: metrics.distance_based.HaversineMetrics
+     acc_radiuses:
+       - 1
+       - 25
+       - 200
+       - 750
+       - 2500
+     acc_area: []
+     aux_data: ${aux_data}
+   test_metrics:
+     _target_: metrics.distance_based.HaversineMetrics
+     acc_radiuses:
+       - 1
+       - 25
+       - 200
+       - 750
+       - 2500
+     acc_area: ${areas}
+     aux_data: ${aux_data}
+
+ datamodule:
+   _target_: data.datamodule.ImageDataModule
+   train_dataset: ${dataset.train_dataset}
+   val_dataset: ${dataset.val_dataset}
+   test_dataset: ${dataset.test_dataset}
+   global_batch_size: ${dataset.global_batch_size}
+   num_workers: ${computer.num_workers}
+   num_nodes: ${computer.num_nodes}
+   num_devices: ${computer.devices}
+   val_proportion: 0.1
+
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   devices: ${computer.devices}
+   accelerator: ${computer.accelerator}
+   strategy: ${computer.strategy}
+   num_nodes: ${computer.num_nodes}
+   precision: ${computer.precision}
+   max_epochs: ${max_epochs}
+
+ logger:
+   _target_: pytorch_lightning.loggers.WandbLogger
+   save_dir: ${root_dir}
+   name: ${experiment_name}
+   project: plonk
+   log_model: False
+   offline: False
+   entity: imaginelab
+
+ checkpoints:
+   _target_: pytorch_lightning.callbacks.ModelCheckpoint
+   dirpath: ${root_dir}/checkpoints/${experiment_name}
+   filename: 'epoch_{epoch}'
+   monitor: val/loss
+   save_last: True
+   save_top_k: 0
+   every_n_epochs: 1
+
+ progress_bar:
+   _target_: pytorch_lightning.callbacks.TQDMProgressBar
+   refresh_rate: ${computer.progress_bar_refresh_rate}
+
+ aux_data: []
+ max_epochs: 100
+ data_dir: ${root_dir}/datasets
+ root_dir: ${hydra:runtime.cwd}
+ experiment_name: ${dataset.name}__${model.name}
+ mode: train # change that to eval to do the testing
+ num_classes: 0
+ areas: ['country', 'region', 'sub-region', 'city']
+ class_name: null
+ streetclip: False
+ blur: False
+ text_tuning: False
+
+ hydra:
+   run:
+     dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name}
+   job:
+     chdir: true
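As a sketch of how this root config composes (assuming the full configs/ tree from the repository is available locally, including the model/ group referenced above, which falls outside this 50-file view; the config_path and overrides below are illustrative, not part of this commit):

    # hypothetical composition check with Hydra's compose API
    from hydra import initialize, compose

    with initialize(version_base=None, config_path="configs"):
        cfg = compose(config_name="config", overrides=["exp=base_model", "computer=cpu"])
        print(cfg.max_epochs)                  # 30, set by exp/base_model.yaml (overrides the default 100)
        print(cfg.dataset.global_batch_size)   # 2048, also set by the exp config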
configs/dataset/baselines/im2gps.yaml ADDED
@@ -0,0 +1,16 @@
+ dataset:
+   name: im2gps
+   global_batch_size: 512
+   test_dataset:
+     _partial_: true
+     _target_: data.data.Baseline
+     path: ${data_dir}/baselines/im2gps
+     which: 'im2gps'
+     transforms: ${dataset.test_transform}
+ datamodule:
+   _target_: data.datamodule.BaselineDataModule
+   test_dataset: ${dataset.test_dataset}
+   global_batch_size: ${dataset.global_batch_size}
+   num_workers: ${computer.num_workers}
+   num_nodes: ${computer.num_nodes}
+   num_devices: ${computer.devices}
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: im2gps3k
3
+ global_batch_size: 512
4
+ test_dataset:
5
+ _partial_: true
6
+ _target_: data.data.Baseline
7
+ path: ${data_dir}/baselines/im2gps3k
8
+ which: 'im2gps3k'
9
+ transforms: ${dataset.test_transform}
10
+ datamodule:
11
+ _target_: data.datamodule.BaselineDataModule
12
+ test_dataset: ${dataset.test_dataset}
13
+ global_batch_size: ${dataset.global_batch_size}
14
+ num_workers: ${computer.num_workers}
15
+ num_nodes: ${computer.num_nodes}
16
+ num_devices: ${computer.devices}
configs/dataset/baselines/yfcc4k.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: yfcc4k
3
+ global_batch_size: 512
4
+ test_dataset:
5
+ _partial_: true
6
+ _target_: data.data.Baseline
7
+ path: ${data_dir}/baselines/yfcc4k
8
+ which: 'yfcc4k'
9
+ transforms: ${dataset.test_transform}
10
+ datamodule:
11
+ _target_: data.datamodule.BaselineDataModule
12
+ test_dataset: ${dataset.test_dataset}
13
+ global_batch_size: ${dataset.global_batch_size}
14
+ num_workers: ${computer.num_workers}
15
+ num_nodes: ${computer.num_nodes}
16
+ num_devices: ${computer.devices}
configs/dataset/osv5m.yaml ADDED
@@ -0,0 +1,46 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.osv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   aux_data: ${aux_data}
+   is_baseline: ${is_baseline}
+   areas: ${areas}
+   streetclip: ${streetclip}
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.osv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   aux_data: ${aux_data}
+   is_baseline: ${is_baseline}
+   areas: ${areas}
+   streetclip: ${streetclip}
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.osv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   aux_data: ${aux_data}
+   is_baseline: ${is_baseline}
+   areas: ${areas}
+   streetclip: ${streetclip}
+   blur: ${blur}
configs/dataset/osv5m_contrastive.yaml ADDED
@@ -0,0 +1,34 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
configs/dataset/osv5m_contrastive_best.yaml ADDED
@@ -0,0 +1,37 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   class_name2: 'unique_region'
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   class_name2: 'unique_region'
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.Contrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   class_name2: 'unique_region'
+   blur: ${blur}
configs/dataset/osv5m_text_contrastive.yaml ADDED
@@ -0,0 +1,34 @@
+ defaults:
+   - train_transform: fast_clip
+   - test_transform: fast_clip
+   - _self_
+
+ name: osv5m
+ global_batch_size: 256
+
+ train_dataset:
+   _partial_: true
+   _target_: data.data.TextContrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: train
+   class_name: ${class_name}
+   transforms: ${dataset.train_transform}
+   blur: ${blur}
+
+ val_dataset:
+   _partial_: true
+   _target_: data.data.TextContrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: val
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
+
+ test_dataset:
+   _partial_: true
+   _target_: data.data.TextContrastiveosv5m
+   path: ${data_dir}/osv5m/
+   split: test
+   class_name: ${class_name}
+   transforms: ${dataset.test_transform}
+   blur: ${blur}
configs/dataset/test_transform/center_crop.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: utils.image_processing.CenterCrop
+     ratio: "1:1"
+   - _target_: torchvision.transforms.Resize
+     size: ${dataset.img_resolution}
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.Normalize
+     mean: 0.5
+     std: 0.5
configs/dataset/test_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: data.transforms.ClipTransform
+ split: val
configs/dataset/test_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.48145466, 0.4578275, 0.40821073]
+     std: [0.26862954, 0.26130258, 0.27577711]
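For reference, instantiating this config (e.g. via hydra.utils.instantiate) yields the standard CLIP preprocessing; a plain-torchvision sketch of the equivalent pipeline (interpolation 3 corresponds to bicubic):

    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    # equivalent of the fast_clip transform config above
    fast_clip = transforms.Compose([
        transforms.Resize(224, interpolation=InterpolationMode.BICUBIC, antialias=True),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])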
configs/dataset/test_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.485, 0.456, 0.406]
+     std: [0.229, 0.224, 0.225]
configs/dataset/test_transform/none.yaml ADDED
@@ -0,0 +1,6 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: 0.5
+     std: 0.5
configs/dataset/train_transform/augmentation.yaml ADDED
@@ -0,0 +1,85 @@
+ _target_: data.augmentation.ImageAugmentation
+ names: "standard_augmentation,geometric_augmentation,clip_transform"
+
+ # always apply clip_transform at the end
+ clip_transform:
+   _target_: torchvision.transforms.Compose
+   transforms:
+     - _target_: torchvision.transforms.Resize
+       size: 224
+       interpolation: 3
+       antialias: true
+     - _target_: torchvision.transforms.CenterCrop
+       size: 224
+     - _target_: torchvision.transforms.ToTensor
+     - _target_: torchvision.transforms.Normalize
+       mean: [0.48145466, 0.4578275, 0.40821073]
+       std: [0.26862954, 0.26130258, 0.27577711]
+
+ standard_augmentation:
+   _target_: data.augmentation.StandardAugmentation
+   # by default, we use all augmentation methods
+   names: "brightness,contrast,sharpness,color,blur,gaussian_noise"
+
+   # random PIL brightness
+   brightness:
+     _target_: data.augmentation.PillowBrightness
+     p: 0.2
+     factor_interval: [0.5, 1.5]
+
+   # random PIL contrast
+   contrast:
+     _target_: data.augmentation.PillowContrast
+     p: 0.2
+     factor_interval: [0.3, 3]
+
+   # random PIL sharpness
+   sharpness:
+     _target_: data.augmentation.PillowSharpness
+     p: 0.2
+     factor_interval: [0.5, 30.0]
+
+   # random PIL color
+   color:
+     _target_: data.augmentation.PillowColor
+     p: 0.2
+     factor_interval: [0.0, 2.0]
+
+   # random PIL blur
+   blur:
+     _target_: data.augmentation.PillowBlur
+     p: 0.2
+     factor_interval: [1, 2]
+
+   # random numpy gaussian noise
+   gaussian_noise:
+     _target_: data.augmentation.NumpyGaussianNoise
+     p: 0.2
+     factor_interval: [0.1, 0.04]
+
+ geometric_augmentation:
+   _target_: data.augmentation.GeometricAugmentation
+   # by default, we use all augmentation methods
+   names: "random_rotation,random_resized_crop,random_horizontal_flip"
+
+   # random rotation
+   random_rotation:
+     _target_: torchvision.transforms.RandomRotation
+     degrees: [-15, 15]
+
+   # random crop
+   random_resized_crop:
+     _target_: torchvision.transforms.RandomResizedCrop
+     scale: [0.5, 1.0]
+     ratio: [0.9, 1.1]
+     size: 224
+
+   # random horizontal flip
+   random_horizontal_flip:
+     _target_: torchvision.transforms.RandomHorizontalFlip
+     p: 0.5
+
+   # random vertical flip
+   random_vertical_flip:
+     _target_: torchvision.transforms.RandomVerticalFlip
+     p: 0.5
configs/dataset/train_transform/center_crop.yaml ADDED
@@ -0,0 +1,14 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: utils.image_processing.CenterCrop
+     ratio: "1:1"
+   - _target_: torchvision.transforms.Resize
+     size: ${dataset.img_resolution}
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.RandomHorizontalFlip
+     p: 0.5
+   - _target_: torchvision.transforms.Normalize
+     mean: 0.5
+     std: 0.5
configs/dataset/train_transform/clip.yaml ADDED
@@ -0,0 +1,2 @@
+ _target_: data.transforms.ClipTransform
+ split: val
configs/dataset/train_transform/fast_clip.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.48145466, 0.4578275, 0.40821073]
+     std: [0.26862954, 0.26130258, 0.27577711]
configs/dataset/train_transform/fast_resnet.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.CenterCrop
+     size: 224
+   - _target_: torchvision.transforms.ToTensor
+   - _target_: torchvision.transforms.Normalize
+     mean: [0.485, 0.456, 0.406]
+     std: [0.229, 0.224, 0.225]
configs/dataset/train_transform/none.yaml ADDED
@@ -0,0 +1,7 @@
+ _target_: torchvision.transforms.Compose
+ transforms:
+   - _target_: torchvision.transforms.Resize
+     size: 224
+     interpolation: 3
+     antialias: true
+   - _target_: torchvision.transforms.ToTensor
configs/exp/DinoV2.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: dinov2_vitl14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/ResNet.yaml ADDED
@@ -0,0 +1,21 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /dataset/test_transform: fast_resnet
+   - override /dataset/train_transform: fast_resnet
+   - override /model.network.mid: mlp_resnet
+   - override /model/network/backbone: ResNet50
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/base_model.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ model:
+   name: base_model
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/best_model.yaml ADDED
@@ -0,0 +1,25 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive_best
+   - override /model: hybrid
+   - override /model/network: best_backbone
+   - override /model/network/backbone: clip_L_14_DataComp
+   - override /model/network/mid: mlp_hybrid
+   - override /model/loss: best_model
+   - _self_
+
+ class_name: 'quadtree_10_1000'
+ is_baseline: false
+ max_epochs: 30
+
+ model:
+   name: best_model
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_area.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'area'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_cell.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: quadtree_10_1000
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_cell_hier.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: cls_hier_quad
+   - _self_
+
+ class_name: quadtree_10_1000
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_city.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'city'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_city_hier.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: cls_hier
+   - _self_
+
+ class_name: 'city'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_country.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'country'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_region copy.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'region'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/classification_region.yaml ADDED
@@ -0,0 +1,19 @@
+ # @package _global_
+
+ defaults:
+   - override /model: classification
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ class_name: 'region'
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 15
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_L_14_DataComp.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: clip_L_14_DataComp
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_L_14_Laion.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: openclip_L_14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_L_14_OpenAI.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: clip_L_14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/clip_bigG_14_Laion.yaml ADDED
@@ -0,0 +1,18 @@
+ # @package _global_
+
+ defaults:
+   - override /model: regression
+   - override /model/network/backbone: openclip_bigG_14
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ max_epochs: 30
+
+ dataset:
+   global_batch_size: 2048
configs/exp/contrastive_area.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: area
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_cell.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: quadtree_10_1000
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_city.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: city
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_country.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: country
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_region.yaml ADDED
@@ -0,0 +1,20 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive
+   - override /model: regression
+   - override /model/network: contrastive_unfrozen_backbone
+   - override /model/network/backbone: openclip_B_32
+   - override /model/loss: contrastive
+   - _self_
+
+ model:
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+
+ class_name: region
+ is_baseline: false
+ max_epochs: 30
configs/exp/contrastive_text.yaml ADDED
@@ -0,0 +1,22 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_text_contrastive
+   - override /model: text_tuning
+   - override /model/network/backbone: openclip_B_32
+   - _self_
+
+ model:
+   network:
+     backbone:
+       instance:
+         _target_: models.networks.backbones.CLIPText
+   optimizer:
+     optim:
+       lr: 0.0002
+       weight_decay: 0.0001
+
+ is_baseline: false
+ class_name: city
+ text_tuning: True
+ max_epochs: 30
configs/exp/eval_best_model.yaml ADDED
@@ -0,0 +1,29 @@
+ # @package _global_
+
+ defaults:
+   - override /dataset: osv5m_contrastive_best
+   - override /model: hybrid
+   - override /model/network: best_backbone
+   - override /model/network/backbone: clip_L_14_DataComp
+   - override /model/network/mid: mlp_hybrid
+   - _self_
+
+ class_name: 'quadtree_10_1000'
+ is_baseline: false
+ max_epochs: 30
+ mode: 'eval'
+
+ model:
+   name: best_model
+   optimizer:
+     optim:
+       lr: 2e-4
+       weight_decay: 0.0001
+       backbone_lr: 2e-5
+   network:
+     head:
+       instance:
+         quadtree_path: ${root_dir}/quadtree_10_1000.csv
+
+ dataset:
+   global_batch_size: 2048