Spaces:
Sleeping
Sleeping
Extraction files
Browse files
app.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import streamlit as st
|
7 |
+
from countryinfo import CountryInfo
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
|
10 |
+
from common import HintType, configs, get_distance
|
11 |
+
from hint import AudioHint, ImageHint, TextHint
|
12 |
+
|
13 |
+
|
14 |
+
def setup_models(_cache: Any, configs: dict) -> None:
    """Set up the model of every selected hint type that has none yet.

    Args:
        _cache (st.session_state): Streamlit cache object; its "hint_types"
            entry lists the selected types and its "model" entry maps each
            hint type to its model (or None)
        configs (dict): Configurations forwarded to the model setup functions
    """
    for selected_type in _cache["hint_types"]:
        model_slots = _cache["model"]

        # A model that is already set up is left untouched.
        if model_slots[selected_type] is not None:
            continue

        if selected_type == HintType.TEXT.value:
            model_slots[selected_type] = setup_text_hint(configs)
        elif selected_type == HintType.IMAGE.value:
            model_slots[selected_type] = setup_image_hint(configs)
        elif selected_type == HintType.AUDIO.value:
            model_slots[selected_type] = setup_audio_hint(configs)
|
29 |
+
|
30 |
+
|
31 |
+
@st.cache_resource()
def setup_text_hint(configs: dict) -> TextHint:
    """Setups the text hint model.

    Args:
        configs (dict): Configurations used by the model

    Returns:
        TextHint: Hint model

    Raises:
        KeyError: If the "HF_ACCESS_TOKEN" environment variable is not set
    """
    with st.spinner("Loading text model..."):
        # Build a new dict instead of writing the token into `configs`:
        # the argument is shared with the other setup functions and is also
        # what the resource cache keys on, so mutating it would change the
        # cache key for the next call and keep the secret in shared state.
        model_configs = {
            **configs["local"][HintType.TEXT.value.lower()],
            "hf_access_token": os.environ["HF_ACCESS_TOKEN"],
        }
        text_hint = TextHint(configs=model_configs)
        text_hint.initialize()
    return text_hint
|
47 |
+
|
48 |
+
|
49 |
+
@st.cache_resource()
def setup_image_hint(configs: dict) -> ImageHint:
    """Create and initialize the image hint model.

    Args:
        configs (dict): Configurations used by the model; the image section
            under "local" is the part handed to the model

    Returns:
        ImageHint: Initialized hint model
    """
    image_section = HintType.IMAGE.value.lower()
    with st.spinner("Loading image model..."):
        image_hint = ImageHint(configs=configs["local"][image_section])
        image_hint.initialize()
    return image_hint
|
64 |
+
|
65 |
+
|
66 |
+
@st.cache_resource()
def setup_audio_hint(configs: dict) -> AudioHint:
    """Create and initialize the audio hint model.

    Args:
        configs (dict): Configurations used by the model; the audio section
            under "local" is the part handed to the model

    Returns:
        AudioHint: Initialized hint model
    """
    audio_section = HintType.AUDIO.value.lower()
    with st.spinner("Loading audio model..."):
        audio_hint = AudioHint(configs=configs["local"][audio_section])
        audio_hint.initialize()
    return audio_hint
|
81 |
+
|
82 |
+
|
83 |
+
@st.cache_resource()
def get_country_list() -> pd.DataFrame:
    """Builds a database of countries and metadata.

    Countries whose area cannot be read are left out of the database.

    Returns:
        pd.DataFrame: Country database with "country" and "area" columns
    """
    areas = {}
    for country in CountryInfo().all():
        try:
            areas[country] = CountryInfo(country).area()
        except Exception:
            # Best effort: a country without usable area data is skipped, but
            # the skip is logged and interpreter-level exceptions such as
            # KeyboardInterrupt are no longer swallowed.
            logging.getLogger(__name__).warning(
                "Skipping country %r: area not available", country
            )

    return pd.DataFrame(list(areas.items()), columns=["country", "area"])
|
102 |
+
|
103 |
+
|
104 |
+
def pick_country(country_df: pd.DataFrame) -> str:
    """Draw one country at random, weighting each row by its "area" value.

    Args:
        country_df (pd.DataFrame): Database with "country" and "area" columns

    Returns:
        str: The selected country
    """
    sampled_row = country_df.sample(n=1, weights="area")
    return sampled_row["country"].iloc[0]
|
115 |
+
|
116 |
+
|
117 |
+
def reset_cache() -> None:
    """Put the Streamlit session state back to a fresh, not-yet-started game.

    A new country is picked, the hint selection and hint count go back to
    their defaults and every hint model slot is cleared.
    """
    countries = get_country_list()
    state = st.session_state

    state["country_list"] = countries["country"].values.tolist()
    state["country"] = pick_country(countries)
    state["hint_types"] = []
    state["n_hints"] = 1
    state["game_started"] = False
    # One empty slot per hint type; models are set up when a game starts.
    state["model"] = dict.fromkeys(
        (HintType.TEXT.value, HintType.IMAGE.value, HintType.AUDIO.value)
    )
|
130 |
+
|
131 |
+
|
132 |
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

st.set_page_config(page_title="Gen AI GeoGuesser", page_icon="🌎")

# An empty session state means this is the first run of the session:
# load the environment file and initialize the game state.
if not st.session_state:
    load_dotenv()
    reset_cache()

st.title("Generative AI GeoGuesser 🌎")

st.markdown("### Guess the country based on hints generated by AI")

hint_types_col, n_hints_col = st.columns([2, 1])

# The widget values are written back to the session state on every run.
with hint_types_col:
    st.session_state["hint_types"] = st.multiselect(
        label="Chose which hint types you want",
        options=[hint_type.value for hint_type in HintType],
        default=st.session_state["hint_types"],
    )

with n_hints_col:
    st.session_state["n_hints"] = st.slider(
        label="Number of hints",
        min_value=1,
        max_value=5,
        value=st.session_state["n_hints"],
    )

start_btn = st.button("Start game")
|
166 |
+
|
167 |
+
# Start a game: validate the selection, load the needed models and generate
# the hints for the country stored in the session state.
if start_btn:
    if not st.session_state["hint_types"]:
        st.error("Pick at least one hint type")
        reset_cache()
    else:
        # Go through the module logger (instead of print) so the message
        # follows the logging configuration used by the rest of the app.
        logger.info('Chosen country "%s"', st.session_state["country"])

        setup_models(st.session_state, configs)

        for hint_type in st.session_state["hint_types"]:
            with st.spinner(f"Generating {hint_type} hint..."):
                st.session_state["model"][hint_type].generate_hint(
                    st.session_state["country"],
                    st.session_state["n_hints"],
                )

        st.session_state["game_started"] = True
|
184 |
+
|
185 |
+
# Guess controls, shown only while a game is running.
if st.session_state["game_started"]:
    guess_col, guess_btn_col, reset_btn_col = st.columns([2, 1, 1])

    with guess_col:
        guess = st.selectbox("Country guess", [""] + st.session_state["country_list"])
    with guess_btn_col:
        guess_btn = st.button("Make a guess")
    with reset_btn_col:
        reset_btn = st.button("Reset game")

    if guess_btn:
        answer = st.session_state["country"]
        if answer == guess:
            st.success("Correct guess you won!")
            st.balloons()
        elif guess:
            # Wrong guess: report how far the guessed country is from the answer.
            answer_latlong = CountryInfo(answer).latlng()
            guess_latlong = CountryInfo(guess).latlng()
            distance = int(get_distance(answer_latlong, guess_latlong))
            wrong_guess_msg = (
                f"\nWrong guess, you missed the correct country by {distance} KM.\n"
                f"The correct answer was {answer}.\n"
            )
            st.error(wrong_guess_msg)
        else:
            # The empty first option of the select box is still selected.
            st.error("Pick a country.")

    if reset_btn:
        reset_cache()
|
215 |
+
|
216 |
+
# Show the generated hints, one tab per selected hint type.
# The selection is checked as well: it can be emptied after the game started,
# and st.tabs rejects an empty list of labels.
if st.session_state["game_started"] and st.session_state["hint_types"]:
    selected_types = st.session_state["hint_types"]
    tabs = st.tabs([f"{x} hint" for x in selected_types])

    for hint_type, tab in zip(selected_types, tabs):
        hint_model = st.session_state["model"][hint_type]
        with tab:
            # A type selected after the game started has no model yet,
            # so its tab stays empty.
            if hint_model:
                for hint_number, hint in enumerate(hint_model.hints, start=1):
                    st.markdown(f"#### Hint #{hint_number}")
                    if hint_type == HintType.TEXT.value:
                        st.write(hint["text"])
                    elif hint_type == HintType.IMAGE.value:
                        st.image(hint["image"])
                    elif hint_type == HintType.AUDIO.value:
                        st.audio(hint["audio"], sample_rate=hint["sample_rate"])
|
common.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import pprint
|
3 |
+
from enum import Enum
|
4 |
+
from math import acos, cos, radians, sin
|
5 |
+
|
6 |
+
import yaml
|
7 |
+
|
8 |
+
|
9 |
+
def parse_configs(configs_path: str) -> dict:
    """Parse configs from the YAML file.

    Args:
        configs_path (str): Path to the YAML file

    Returns:
        dict: Parsed configs
    """
    # The context manager closes the file handle once parsing is done.
    with open(configs_path, "r", encoding="utf-8") as configs_file:
        configs = yaml.safe_load(configs_file)
    logger.info("Configs: %s", pprint.pformat(configs))
    return configs
|
21 |
+
|
22 |
+
|
23 |
+
def get_distance(source_country: list[float], target_country: list[float]) -> float:
    """Calculate the distance between two countries.

    The distance is the great-circle distance given by the spherical law of
    cosines on a sphere of radius 6371.01 KM.

    Args:
        source_country (list[float]): Source country coordinates, as
            [latitude, longitude] in degrees
        target_country (list[float]): Target country coordinates, as
            [latitude, longitude] in degrees

    Returns:
        float: Distance in KM
    """
    source_lat = radians(source_country[0])
    source_long = radians(source_country[1])
    target_lat = radians(target_country[0])
    target_long = radians(target_country[1])

    cos_angle = sin(source_lat) * sin(target_lat) + cos(source_lat) * cos(
        target_lat
    ) * cos(source_long - target_long)
    # Floating-point rounding can push the value slightly outside [-1, 1]
    # (e.g. for identical points), which would make acos raise ValueError.
    cos_angle = max(-1.0, min(1.0, cos_angle))

    return 6371.01 * acos(cos_angle)
|
42 |
+
|
43 |
+
|
44 |
+
class HintType(Enum):
    """Kinds of hints the game can generate.

    Each value is the display name of the hint type.
    """

    AUDIO = "Audio"
    TEXT = "Text"
    IMAGE = "Image"
|
48 |
+
|
49 |
+
|
50 |
+
# Path of the YAML configuration file, as passed to parse_configs below.
CONFIGS_PATH = "configs.yaml"

logging.basicConfig(level=logging.INFO)
# Name the logger after the module rather than after the file path.
logger = logging.getLogger(__name__)

# The configs are parsed once, when this module is imported.
configs = parse_configs(CONFIGS_PATH)
|
hint.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import AudioLDM2Pipeline, AutoPipelineForText2Image
|
8 |
+
from pydantic import BaseModel
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
10 |
+
|
11 |
+
# Sample rate stored alongside every generated audio hint.
SAMPLE_RATE = 16000

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class BaseHint(BaseModel, abc.ABC):
    """Common interface of the hint models.

    A hint model is first initialized and then asked to generate hints,
    which it keeps in ``hints``.
    """

    # Configurations of the model
    configs: dict
    # Hints generated so far
    hints: list = []
    # Underlying model object, assigned by ``initialize``
    model: Any = None

    @abc.abstractmethod
    def initialize(self):
        """Initialize the hint model."""

    @abc.abstractmethod
    def generate_hint(self, country: str, n_hints: int):
        """Generate hints.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
|
37 |
+
|
38 |
+
|
39 |
+
class TextHint(BaseHint):
    """Hint model that describes a country in text with a causal language model."""

    # Tokenizer of the language model, assigned by ``initialize``
    tokenizer: Any = None

    def initialize(self):
        """Load the tokenizer and the language model named by ``configs["model_id"]``.

        Both are fetched with ``configs["hf_access_token"]``; the model is
        loaded in float16 and moved to ``configs["device"]``.
        """
        logger.info(
            f"""Initializing text hint with model '{self.configs["model_id"]}'"""
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.configs["model_id"],
            token=self.configs["hf_access_token"],
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.configs["model_id"],
            torch_dtype=torch.float16,
            token=self.configs["hf_access_token"],
        ).to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate text hints and store them in ``hints``.

        The hints of a previous call are replaced, so ``hints`` only ever
        holds the hints of the latest call.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' text hints")

        generation_config = GenerationConfig(
            do_sample=True,
            max_new_tokens=self.configs["max_output_tokens"],
            top_k=self.configs["top_k"],
            top_p=self.configs["top_p"],
            temperature=self.configs["temperature"],
        )

        # The same prompt is repeated once per hint; sampling makes the
        # generated texts differ.
        prompt = [
            f'Describe the country "{country}" without mentioning its name\n'
            for _ in range(n_hints)
        ]
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(
            **inputs.to(self.configs["device"]),
            generation_config=generation_config,
        )

        text_hints = []
        for idx, output_ids in enumerate(outputs):
            # Drop the prompt from the decoded output.
            text_hint = (
                self.tokenizer.decode(output_ids, skip_special_tokens=True)
                .strip()
                .replace(prompt[idx], "")
                .strip()
            )
            # Mask the country name in case the model mentioned it anyway.
            text_hint = re.sub(
                re.escape(country), "***", text_hint, flags=re.IGNORECASE
            )
            text_hints.append({"text": text_hint})

        # Replace instead of extending: the instance can be reused for several
        # games, and appending would keep the hints of earlier countries.
        self.hints = text_hints

        logger.info(f"Text hints '{n_hints}' successfully generated")
|
92 |
+
|
93 |
+
|
94 |
+
class ImageHint(BaseHint):
    """Hint model that produces images with a text-to-image pipeline."""

    def initialize(self):
        """Load the pipeline named by ``configs["model_id"]`` onto ``configs["device"]``."""
        model_id = self.configs["model_id"]
        logger.info(f"""Initializing image hint with model '{model_id}'""")
        # NOTE: torch_dtype is not passed; only the "fp16" variant is requested.
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_id,
            variant="fp16",
        )
        self.model = pipeline.to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate image hints and store them in ``hints``, replacing earlier ones.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' image hints")
        # One copy of the prompt per requested image.
        prompt = [f"An image related to the country {country}"] * n_hints
        pipeline_output = self.model(
            prompt=prompt,
            num_inference_steps=self.configs["num_inference_steps"],
            guidance_scale=self.configs["guidance_scale"],
        )
        self.hints = [{"image": image} for image in pipeline_output.images]
        logger.info(f"Image hints '{n_hints}' successfully generated")
|
116 |
+
|
117 |
+
|
118 |
+
class AudioHint(BaseHint):
    """Hint model that produces audio clips with an AudioLDM2 pipeline."""

    def initialize(self):
        """Load the pipeline named by ``configs["model_id"]`` onto ``configs["device"]``."""
        logger.info(
            f"""Initializing audio hint with model '{self.configs["model_id"]}'"""
        )
        self.model = AudioLDM2Pipeline.from_pretrained(
            self.configs["model_id"],
            # torch_dtype=torch.float16, # Not working with MacOS
        ).to(self.configs["device"])
        logger.info("Initialization finisehd")

    def generate_hint(self, country: str, n_hints: int):
        """Generate audio hints and store them in ``hints``.

        The hints of a previous call are replaced, so ``hints`` only ever
        holds the hints of the latest call.

        Args:
            country (str): Country name used to base the hint
            n_hints (int): Number of hints that will be generated
        """
        logger.info(f"Generating '{n_hints}' audio hints")
        prompt = f"A sound that resembles the country of {country}"
        negative_prompt = "Low quality"

        audio_hints = self.model(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=self.configs["num_inference_steps"],
            audio_length_in_s=self.configs["audio_length_in_s"],
            num_waveforms_per_prompt=n_hints,
        ).audios

        # Replace instead of extending: the instance can be reused for several
        # games, and appending would keep the hints of earlier countries.
        self.hints = [
            {"audio": audio_hint, "sample_rate": SAMPLE_RATE}
            for audio_hint in audio_hints
        ]
        logger.info(f"Audio hints '{n_hints}' successfully generated")
|