kiyer commited on
Commit
a53a055
·
verified ·
1 Parent(s): f2f9cd1

update embedding plot

Browse files
Files changed (1) hide show
  1. app_gradio.py +6 -6
app_gradio.py CHANGED
@@ -471,18 +471,18 @@ def make_embedding_plot(papers_df, top_k, consensus_answer, arxiv_corpus=arxiv_c
471
  alphas[outlier_flag] = 0.5
472
 
473
  fig = plt.figure(figsize=(9*1.8,12*1.8))
474
- plt.scatter(xax,yax, s=1, alpha=0.01, c='k')
475
 
476
  clkws = np.load('kw_tags.npz')
477
  all_x, all_y, all_topics, repeat_flag = clkws['all_x'], clkws['all_y'], clkws['all_topics'], clkws['repeat_flag']
478
- for i in range(len(all_topics)):
479
- if repeat_flag[i] == False:
480
- plt.text(all_x[i], all_y[i], all_topics[i],fontsize=9,ha="center", va="center",
481
- bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.3',alpha=0.81))
482
  plt.scatter(xax[plt_indices], yax[plt_indices], s=300*alphas**2, alpha=alphas, c='w',zorder=1000)
483
  plt.scatter(xax[plt_indices], yax[plt_indices], s=100*alphas**2, alpha=alphas, c='dodgerblue',zorder=1001)
484
  # plt.scatter(xax[plt_indices][outlier_flag], yax[plt_indices][outlier_flag], s=100, alpha=1., c='firebrick')
485
- plt.axis([0,20,-4.2,18])
486
  plt.axis('off')
487
  return fig
488
 
 
471
  alphas[outlier_flag] = 0.5
472
 
473
  fig = plt.figure(figsize=(9*1.8,12*1.8))
474
+ plt.scatter(xax,yax, s=1, alpha=0.1, c='k')
475
 
476
  clkws = np.load('kw_tags.npz')
477
  all_x, all_y, all_topics, repeat_flag = clkws['all_x'], clkws['all_y'], clkws['all_topics'], clkws['repeat_flag']
478
+ # for i in range(len(all_topics)):
479
+ # if repeat_flag[i] == False:
480
+ # plt.text(all_x[i], all_y[i], all_topics[i],fontsize=9,ha="center", va="center",
481
+ # bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.3',alpha=0.81))
482
  plt.scatter(xax[plt_indices], yax[plt_indices], s=300*alphas**2, alpha=alphas, c='w',zorder=1000)
483
  plt.scatter(xax[plt_indices], yax[plt_indices], s=100*alphas**2, alpha=alphas, c='dodgerblue',zorder=1001)
484
  # plt.scatter(xax[plt_indices][outlier_flag], yax[plt_indices][outlier_flag], s=100, alpha=1., c='firebrick')
485
+ plt.axis([2,15,2,8])
486
  plt.axis('off')
487
  return fig
488