tdurbor commited on
Commit
bb26f6a
·
1 Parent(s): dc5f6ab

Fixed voting + changed db for testing

Browse files
Files changed (3) hide show
  1. app.py +16 -26
  2. data/newvotes.db +0 -0
  3. db.py +1 -1
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import numpy as np
7
  from PIL import Image
8
  import random
9
- from db import compute_elo_scores, get_all_votes
10
  import json
11
  from pathlib import Path
12
  from uuid import uuid4
@@ -100,8 +100,7 @@ def select_new_image():
100
  model_b_output_image = segmented_images[model_b_index]
101
  model_a_name = segmented_sources[model_a_index]
102
  model_b_name = segmented_sources[model_b_index]
103
- return (sample['original_image'], input_image, model_a_output_image, model_a_output_image,
104
- model_b_output_image, model_b_output_image, model_a_name, model_b_name)
105
  except Exception as e:
106
  logging.error("Error processing images: %s. Resampling another image.", str(e))
107
  last_image_index = random_index
@@ -146,13 +145,10 @@ def gradio_interface():
146
  with gr.Tab("⚔️ Arena (battle)", id=0):
147
  notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown")
148
 
149
- (fpath_input, input_image, fpath_a, segmented_a, fpath_b, segmented_b,
150
- a_name, b_name) = select_new_image()
151
  model_a_name = gr.State(a_name)
152
  model_b_name = gr.State(b_name)
153
- fpath_input = gr.State(fpath_input)
154
- fpath_a = gr.State(fpath_a)
155
- fpath_b = gr.State(fpath_b)
156
 
157
  # Compute the absolute difference between the masks
158
  mask_difference = compute_mask_difference(segmented_a, segmented_b)
@@ -186,60 +182,54 @@ def gradio_interface():
186
 
187
 
188
  vote_a_btn.click(
189
- fn=lambda: vote_for_model("model_a", (fpath_input, fpath_a, fpath_b), model_a_name, model_b_name),
190
  outputs=[
191
- fpath_input, input_image_display, fpath_a, image_a_display, fpath_b, image_b_display, model_a_name, model_b_name, notice_markdown
192
  ]
193
  )
194
  vote_b_btn.click(
195
- fn=lambda: vote_for_model("model_b", (fpath_input, fpath_a, fpath_b), model_a_name, model_b_name),
196
  outputs=[
197
- fpath_input, input_image_display, fpath_a, image_a_display, fpath_b, image_b_display, model_a_name, model_b_name, notice_markdown
198
  ]
199
  )
200
  vote_tie.click(
201
- fn=lambda: vote_for_model("tie", (fpath_input, fpath_a, fpath_b), model_a_name, model_b_name),
202
  outputs=[
203
- fpath_input, input_image_display, fpath_a, image_a_display, fpath_b, image_b_display, model_a_name, model_b_name, notice_markdown
204
  ]
205
  )
206
 
207
- def vote_for_model(choice, fpaths, model_a_name, model_b_name):
208
  """Submit a vote for a model and return updated images and model names."""
209
  logging.info("Voting for model: %s", choice)
210
-
211
  vote_data = {
212
- "image_id": fpaths[0].value,
213
  "model_a": model_a_name.value,
214
  "model_b": model_b_name.value,
215
  "winner": choice,
216
- "fpath_a": fpaths[1].value,
217
- "fpath_b": fpaths[2].value,
218
  }
219
 
220
  try:
221
  logging.debug("Adding vote data to the database: %s", vote_data)
222
- from db import add_vote
223
  result = add_vote(vote_data)
224
  logging.info("Vote successfully recorded in the database with ID: %s", result["id"])
225
  except Exception as e:
226
  logging.error("Error recording vote: %s", str(e))
227
 
228
- (new_fpath_input, new_input_image, new_fpath_a, new_segmented_a,
229
- new_fpath_b, new_segmented_b, new_a_name, new_b_name) = select_new_image()
230
  model_a_name.value = new_a_name
231
  model_b_name.value = new_b_name
232
  fpath_input.value = new_fpath_input
233
- fpath_a.value = new_fpath_a
234
- fpath_b.value = new_fpath_b
235
 
236
  mask_difference = compute_mask_difference(new_segmented_a, new_segmented_b)
237
 
238
  # Update the notice markdown with the new vote count
239
  new_notice_markdown = get_notice_markdown()
240
 
241
- return (fpath_input.value, (new_input_image, [(mask_difference, "Mask")]), fpath_a.value, new_segmented_a,
242
- fpath_b.value, new_segmented_b, model_a_name.value, model_b_name.value, new_notice_markdown)
243
 
244
  with gr.Tab("🏆 Leaderboard", id=1) as leaderboard_tab:
245
  rankings_table = gr.Dataframe(
 
6
  import numpy as np
7
  from PIL import Image
8
  import random
9
+ from db import compute_elo_scores, get_all_votes, add_vote
10
  import json
11
  from pathlib import Path
12
  from uuid import uuid4
 
100
  model_b_output_image = segmented_images[model_b_index]
101
  model_a_name = segmented_sources[model_a_index]
102
  model_b_name = segmented_sources[model_b_index]
103
+ return sample['original_filename'], input_image, model_a_output_image, model_b_output_image, model_a_name, model_b_name
 
104
  except Exception as e:
105
  logging.error("Error processing images: %s. Resampling another image.", str(e))
106
  last_image_index = random_index
 
145
  with gr.Tab("⚔️ Arena (battle)", id=0):
146
  notice_markdown = gr.Markdown(get_notice_markdown(), elem_id="notice_markdown")
147
 
148
+ filname, input_image, segmented_a, segmented_b, a_name, b_name = select_new_image()
 
149
  model_a_name = gr.State(a_name)
150
  model_b_name = gr.State(b_name)
151
+ fpath_input = gr.State(filname)
 
 
152
 
153
  # Compute the absolute difference between the masks
154
  mask_difference = compute_mask_difference(segmented_a, segmented_b)
 
182
 
183
 
184
  vote_a_btn.click(
185
+ fn=lambda: vote_for_model("model_a", fpath_input, model_a_name, model_b_name),
186
  outputs=[
187
+ fpath_input, input_image_display, image_a_display, image_b_display, model_a_name, model_b_name, notice_markdown
188
  ]
189
  )
190
  vote_b_btn.click(
191
+ fn=lambda: vote_for_model("model_b",fpath_input, model_a_name, model_b_name),
192
  outputs=[
193
+ fpath_input, input_image_display, image_a_display, image_b_display, model_a_name, model_b_name, notice_markdown
194
  ]
195
  )
196
  vote_tie.click(
197
+ fn=lambda: vote_for_model("tie", fpath_input, model_a_name, model_b_name),
198
  outputs=[
199
+ fpath_input, input_image_display, image_a_display, image_b_display, model_a_name, model_b_name, notice_markdown
200
  ]
201
  )
202
 
203
+ def vote_for_model(choice, original_filename, model_a_name, model_b_name):
204
  """Submit a vote for a model and return updated images and model names."""
205
  logging.info("Voting for model: %s", choice)
 
206
  vote_data = {
207
+ "image_id": original_filename.value,
208
  "model_a": model_a_name.value,
209
  "model_b": model_b_name.value,
210
  "winner": choice,
 
 
211
  }
212
 
213
  try:
214
  logging.debug("Adding vote data to the database: %s", vote_data)
215
+
216
  result = add_vote(vote_data)
217
  logging.info("Vote successfully recorded in the database with ID: %s", result["id"])
218
  except Exception as e:
219
  logging.error("Error recording vote: %s", str(e))
220
 
221
+ new_fpath_input, new_input_image, new_segmented_a, new_segmented_b, new_a_name, new_b_name = select_new_image()
 
222
  model_a_name.value = new_a_name
223
  model_b_name.value = new_b_name
224
  fpath_input.value = new_fpath_input
 
 
225
 
226
  mask_difference = compute_mask_difference(new_segmented_a, new_segmented_b)
227
 
228
  # Update the notice markdown with the new vote count
229
  new_notice_markdown = get_notice_markdown()
230
 
231
+ return (fpath_input.value, (new_input_image, [(mask_difference, "Mask")]), new_segmented_a,
232
+ new_segmented_b, model_a_name.value, model_b_name.value, new_notice_markdown)
233
 
234
  with gr.Tab("🏆 Leaderboard", id=1) as leaderboard_tab:
235
  rankings_table = gr.Dataframe(
data/newvotes.db ADDED
Binary file (86 kB). View file
 
db.py CHANGED
@@ -7,7 +7,7 @@ import uuid
7
  from rating_systems import compute_elo
8
 
9
 
10
- DATABASE_URL = "sqlite:///./data/votes.db" # Example with SQLite, replace with PostgreSQL for production
11
  engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
12
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
13
  Base = declarative_base()
 
7
  from rating_systems import compute_elo
8
 
9
 
10
+ DATABASE_URL = "sqlite:///./data/newvotes.db" # Example with SQLite, replace with PostgreSQL for production
11
  engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
12
  SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
13
  Base = declarative_base()