Dimitre committed on
Commit
90de23d
·
1 Parent(s): 4ff87ca

Extraction files

Browse files
Files changed (3) hide show
  1. app.py +232 -0
  2. common.py +55 -0
  3. hint.py +149 -0
app.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Any
4
+
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from countryinfo import CountryInfo
8
+ from dotenv import load_dotenv
9
+
10
+ from common import HintType, configs, get_distance
11
+ from hint import AudioHint, ImageHint, TextHint
12
+
13
+
14
def setup_models(_cache: Any, configs: dict) -> None:
    """Load every selected hint model that has not been loaded yet.

    Args:
        _cache (st.session_state): Streamlit session state holding the
            selected "hint_types" and the per-type "model" slots
        configs (dict): Configurations forwarded to each model loader
    """
    # Maps each hint type label to the function that builds its model.
    loaders = {
        HintType.TEXT.value: setup_text_hint,
        HintType.IMAGE.value: setup_image_hint,
        HintType.AUDIO.value: setup_audio_hint,
    }

    for model_type in _cache["hint_types"]:
        # A non-None slot means this model was already set up.
        if _cache["model"][model_type] is not None:
            continue

        loader = loaders.get(model_type)
        if loader is not None:
            _cache["model"][model_type] = loader(configs)
29
+
30
+
31
@st.cache_resource()
def setup_text_hint(configs: dict) -> TextHint:
    """Build and initialize the text hint model.

    The Hugging Face access token is read from the ``HF_ACCESS_TOKEN``
    environment variable and merged into a private copy of the text model
    configuration.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        TextHint: Initialized hint model

    Raises:
        KeyError: If ``HF_ACCESS_TOKEN`` is not set in the environment
    """
    with st.spinner("Loading text model..."):
        # Work on a copy: writing the token into the shared ``configs`` dict
        # would leak the secret into module-level state and alter the very
        # argument this cached function was called with (which can cause a
        # cache miss on the next call with the same dict).
        model_configs = {
            **configs["local"][HintType.TEXT.value.lower()],
            "hf_access_token": os.environ["HF_ACCESS_TOKEN"],
        }
        text_hint = TextHint(configs=model_configs)
        text_hint.initialize()
    return text_hint
47
+
48
+
49
@st.cache_resource()
def setup_image_hint(configs: dict) -> ImageHint:
    """Build and initialize the image hint model.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        ImageHint: Initialized hint model
    """
    # Section of the "local" configs that belongs to the image model.
    section = HintType.IMAGE.value.lower()

    with st.spinner("Loading image model..."):
        hint_model = ImageHint(configs=configs["local"][section])
        hint_model.initialize()

    return hint_model
64
+
65
+
66
@st.cache_resource()
def setup_audio_hint(configs: dict) -> AudioHint:
    """Build and initialize the audio hint model.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        AudioHint: Initialized hint model
    """
    # Section of the "local" configs that belongs to the audio model.
    section = HintType.AUDIO.value.lower()

    with st.spinner("Loading audio model..."):
        hint_model = AudioHint(configs=configs["local"][section])
        hint_model.initialize()

    return hint_model
81
+
82
+
83
@st.cache_resource()
def get_country_list() -> pd.DataFrame:
    """Build a table of countries and their areas.

    Countries whose area lookup fails are left out of the table.

    Returns:
        pd.DataFrame: One row per country with columns "country" and "area"
    """
    areas = {}
    for country in CountryInfo().all():
        try:
            areas[country] = CountryInfo(country).area()
        except Exception:
            # Best effort: some entries have no area data. Catch Exception
            # rather than using a bare ``except`` so KeyboardInterrupt and
            # SystemExit still propagate.
            continue

    return pd.DataFrame(list(areas.items()), columns=["country", "area"])
102
+
103
+
104
def pick_country(country_df: pd.DataFrame) -> str:
    """Draw one country at random, weighted by the "area" column.

    Args:
        country_df (pd.DataFrame): Table with "country" and "area" columns

    Returns:
        str: Name of the drawn country
    """
    # Single weighted draw; rows with a larger area are more likely.
    drawn_row = country_df.sample(n=1, weights="area")
    return drawn_row["country"].iloc[0]
115
+
116
+
117
def reset_cache() -> None:
    """Put the Streamlit session state back to a fresh, not-started game."""
    countries = get_country_list()
    state = st.session_state

    # New round: refresh the guessable list and draw a new target country.
    state["country_list"] = countries["country"].values.tolist()
    state["country"] = pick_country(countries)

    # Widget defaults and game flag.
    state["hint_types"] = []
    state["n_hints"] = 1
    state["game_started"] = False

    # One empty model slot per hint type; filled when the game starts.
    state["model"] = {
        hint_kind.value: None
        for hint_kind in (HintType.TEXT, HintType.IMAGE, HintType.AUDIO)
    }
130
+
131
+
132
# --- Streamlit page script -------------------------------------------------
# Streamlit re-runs this module top to bottom on every interaction, so the
# order of the widget calls below defines the page layout.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

st.set_page_config(
    page_title="Gen AI GeoGuesser",
    page_icon="🌎",
)

# An empty session state means a brand-new session: load environment
# variables from .env and seed the game state.
if not st.session_state:
    load_dotenv()
    reset_cache()

st.title("Generative AI GeoGuesser 🌎")

st.markdown("### Guess the country based on hints generated by AI")

col1, col2 = st.columns([2, 1])

with col1:
    st.session_state["hint_types"] = st.multiselect(
        "Chose which hint types you want",
        [x.value for x in HintType],
        default=st.session_state["hint_types"],
    )

with col2:
    st.session_state["n_hints"] = st.slider(
        "Number of hints",
        min_value=1,
        max_value=5,
        value=st.session_state["n_hints"],
    )

start_btn = st.button("Start game")

if start_btn:
    if not st.session_state["hint_types"]:
        st.error("Pick at least one hint type")
        reset_cache()
    else:
        # Use the module logger (instead of print) so the message follows
        # the logging configuration set up above.
        logger.info('Chosen country "%s"', st.session_state["country"])

        setup_models(st.session_state, configs)

        # Generate the hints for every selected hint type.
        for hint_type in st.session_state["hint_types"]:
            with st.spinner(f"Generating {hint_type} hint..."):
                st.session_state["model"][hint_type].generate_hint(
                    st.session_state["country"],
                    st.session_state["n_hints"],
                )

        st.session_state["game_started"] = True

if st.session_state["game_started"]:
    game_col1, game_col2, game_col3 = st.columns([2, 1, 1])

    with game_col1:
        # The leading empty option stands for "no guess selected yet".
        guess = st.selectbox("Country guess", ([""] + st.session_state["country_list"]))
    with game_col2:
        guess_btn = st.button("Make a guess")
    with game_col3:
        reset_btn = st.button("Reset game")

    if guess_btn:
        if st.session_state["country"] == guess:
            st.success("Correct guess you won!")
            st.balloons()
        else:
            if guess:
                # Report how far the guess is from the target country.
                country_latlong = CountryInfo(st.session_state["country"]).latlng()
                guess_latlong = CountryInfo(guess).latlng()
                distance = int(get_distance(country_latlong, guess_latlong))
                st.error(
                    f"""
                    Wrong guess, you missed the correct country by {distance} KM.
                    The correct answer was {st.session_state["country"]}.
                    """
                )
            else:
                st.error("Pick a country.")

    if reset_btn:
        reset_cache()

# Checked again on purpose: a reset above clears "game_started", in which
# case the hints must not be rendered.
if st.session_state["game_started"]:
    tabs = st.tabs([f"{x} hint" for x in st.session_state["hint_types"]])

    for tab_idx, tab in enumerate(tabs):
        hint_type = st.session_state["hint_types"][tab_idx]
        with tab:
            if st.session_state["model"][hint_type]:
                for hint_idx, hint in enumerate(
                    st.session_state["model"][hint_type].hints
                ):
                    st.markdown(f"#### Hint #{hint_idx+1}")
                    if hint_type == HintType.TEXT.value:
                        st.write(hint["text"])
                    elif hint_type == HintType.IMAGE.value:
                        st.image(hint["image"])
                    elif hint_type == HintType.AUDIO.value:
                        st.audio(hint["audio"], sample_rate=hint["sample_rate"])
common.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pprint
3
+ from enum import Enum
4
+ from math import acos, cos, radians, sin
5
+
6
+ import yaml
7
+
8
+
9
def parse_configs(configs_path: str) -> dict:
    """Parse configs from the YAML file.

    Args:
        configs_path (str): Path to the YAML file

    Returns:
        dict: Parsed configs

    Raises:
        OSError: If the file cannot be opened
    """
    # Use a context manager so the file handle is closed right after parsing
    # instead of being left to the garbage collector.
    with open(configs_path, "r", encoding="utf-8") as configs_file:
        configs = yaml.safe_load(configs_file)

    logger.info("Configs: %s", pprint.pformat(configs))
    return configs
21
+
22
+
23
def get_distance(source_country: list[float], target_country: list[float]) -> float:
    """Calculate the great-circle distance between two coordinates.

    Uses the spherical law of cosines with an Earth radius of 6371.01 km.

    Args:
        source_country (list[float]): Source [latitude, longitude] in degrees
        target_country (list[float]): Target [latitude, longitude] in degrees

    Returns:
        float: Distance in KM
    """
    source_lat = radians(source_country[0])
    source_long = radians(source_country[1])
    target_lat = radians(target_country[0])
    target_long = radians(target_country[1])

    cos_angle = sin(source_lat) * sin(target_lat) + cos(source_lat) * cos(
        target_lat
    ) * cos(source_long - target_long)

    # Floating-point rounding can push the value marginally outside [-1, 1]
    # for identical or antipodal points, which would make acos() raise
    # "math domain error". Clamp it back into the valid domain.
    cos_angle = max(-1.0, min(1.0, cos_angle))

    return 6371.01 * acos(cos_angle)
42
+
43
+
44
class HintType(Enum):
    """Kinds of hints the game can generate.

    Each member's value is a capitalized label string.
    """

    AUDIO = "Audio"
    TEXT = "Text"
    IMAGE = "Image"
48
+
49
+
50
# Path of the YAML configuration file (relative to the working directory).
CONFIGS_PATH = "configs.yaml"

logging.basicConfig(level=logging.INFO)
# Name the logger after the module (``__name__``), matching the other modules
# of this project, instead of after the file path.
logger = logging.getLogger(__name__)

# Parsed once at import time and shared by importers of this module.
configs = parse_configs(CONFIGS_PATH)
hint.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import logging
3
+ import re
4
+ from typing import Any
5
+
6
+ import torch
7
+ from diffusers import AudioLDM2Pipeline, AutoPipelineForText2Image
8
+ from pydantic import BaseModel
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
10
+
11
# Configure logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Sample rate stored alongside every generated audio hint.
# NOTE(review): presumably the output rate of the audio model — confirm.
SAMPLE_RATE = 16000
16
+
17
+
18
class BaseHint(BaseModel, abc.ABC):
    """Abstract base for hint generators.

    Subclasses implement ``initialize`` to load their model and
    ``generate_hint`` to fill ``hints``.
    """

    # Model settings; subclasses read keys such as "model_id" and "device".
    configs: dict
    # Generated hints, each stored as a dict.
    hints: list = []
    # Underlying model/pipeline object; None until initialize() is called.
    model: Any = None

    @abc.abstractmethod
    def initialize(self):
        """Initialize the hint model."""
        pass

    @abc.abstractmethod
    def generate_hint(self, country: str, n_hints: int):
        """Generate hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        pass
37
+
38
+
39
class TextHint(BaseHint):
    """Text hints produced by a causal language model.

    Each hint is stored in ``hints`` as ``{"text": <description>}`` with
    occurrences of the country name masked as ``***``.
    """

    # Tokenizer paired with the model; None until initialize() is called.
    tokenizer: Any = None

    def initialize(self):
        """Load the tokenizer and the language model named in the configs."""
        logger.info(
            f"""Initializing text hint with model '{self.configs["model_id"]}'"""
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.configs["model_id"],
            token=self.configs["hf_access_token"],
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.configs["model_id"],
            torch_dtype=torch.float16,
            token=self.configs["hf_access_token"],
        ).to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate text hints, replacing any previously stored hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' text hints")

        generation_config = GenerationConfig(
            do_sample=True,
            max_new_tokens=self.configs["max_output_tokens"],
            top_k=self.configs["top_k"],
            top_p=self.configs["top_p"],
            temperature=self.configs["temperature"],
        )

        # The same prompt is repeated so one batched call yields n_hints
        # independently sampled descriptions.
        prompt = [
            f'Describe the country "{country}" without mentioning its name\n'
            for _ in range(n_hints)
        ]
        input_ids = self.tokenizer(prompt, return_tensors="pt")
        text_hints = self.model.generate(
            **input_ids.to(self.configs["device"]),
            generation_config=generation_config,
        )

        new_hints = []
        for idx, text_hint in enumerate(text_hints):
            # The decoded sequence echoes the prompt; remove it. The prompt is
            # stripped first so the match still succeeds when the decoded text
            # has lost the prompt's trailing newline (e.g. empty generation).
            text_hint = (
                self.tokenizer.decode(text_hint, skip_special_tokens=True)
                .strip()
                .replace(prompt[idx].strip(), "")
                .strip()
            )
            # Mask the answer in case the model mentions it anyway.
            text_hint = re.sub(
                re.escape(country), "***", text_hint, flags=re.IGNORECASE
            )

            new_hints.append({"text": text_hint})

        # Replace rather than append: a reused instance must not keep showing
        # hints generated for an earlier call (possibly another country).
        self.hints = new_hints

        logger.info(f"Text hints '{n_hints}' successfully generated")
92
+
93
+
94
class ImageHint(BaseHint):
    """Image hints produced by a text-to-image pipeline.

    Each hint is stored in ``hints`` as ``{"image": <generated image>}``.
    """

    def initialize(self):
        """Load the text-to-image pipeline named in the configs."""
        model_id = self.configs["model_id"]
        logger.info(f"""Initializing image hint with model '{model_id}'""")

        # The fp16 weight variant is requested; torch_dtype is left unset.
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_id,
            variant="fp16",
        )
        self.model = pipeline.to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate image hints, replacing any previously stored hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' image hints")

        # One identical prompt per requested image.
        prompts = [f"An image related to the country {country}"] * n_hints
        output = self.model(
            prompt=prompts,
            num_inference_steps=self.configs["num_inference_steps"],
            guidance_scale=self.configs["guidance_scale"],
        )

        self.hints = [{"image": image} for image in output.images]
        logger.info(f"Image hints '{n_hints}' successfully generated")
116
+
117
+
118
class AudioHint(BaseHint):
    """Audio hints produced by an AudioLDM2 pipeline.

    Each hint is stored in ``hints`` as
    ``{"audio": <waveform>, "sample_rate": SAMPLE_RATE}``.
    """

    def initialize(self):
        """Load the audio pipeline named in the configs."""
        logger.info(
            f"""Initializing audio hint with model '{self.configs["model_id"]}'"""
        )
        self.model = AudioLDM2Pipeline.from_pretrained(
            self.configs["model_id"],
            # torch_dtype=torch.float16,  # Not working with MacOS
        ).to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate audio hints, replacing any previously stored hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' audio hints")
        prompt = f"A sound that resembles the country of {country}"
        negative_prompt = "Low quality"

        # A single prompt is used; n_hints waveforms are requested for it.
        audio_hints = self.model(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=self.configs["num_inference_steps"],
            audio_length_in_s=self.configs["audio_length_in_s"],
            num_waveforms_per_prompt=n_hints,
        ).audios

        # Replace rather than append: a reused instance must not keep serving
        # hints generated for an earlier call (possibly another country).
        self.hints = [
            {
                "audio": audio_hint,
                "sample_rate": SAMPLE_RATE,
            }
            for audio_hint in audio_hints
        ]
        logger.info(f"Audio hints '{n_hints}' successfully generated")