oscared commited on
Commit
be9798a
1 Parent(s): 7404e99

subiendo utils y app

Browse files
Files changed (2) hide show
  1. app.py +45 -0
  2. utils.py +15 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from utils import carga_model, genera
4
+
5
+ ##pagina principla
6
+ st.title('Generador de mariposas')
7
+ st.write('este es un model light gan entrenado y utilizado con platzi')
8
+
9
+ ## barra lateral
10
+ st.sidebar.subheader('!Esta mariposa no existe, puedes creerlo?')
11
+ st.sidebar.image('assets/logo.png', width=200)
12
+ st.sidebar.caption('Demo creado en vivo.')
13
+
14
+
15
+ ## cargamos el model
16
+ repo_id = 'ceyda/butterfly_cropped_uniq1K_512'
17
+ modelo_gan = carga_model(repo_id)
18
+
19
+
20
+ ## genera 4 mariposas
21
+
22
+ n_mariposas = 4
23
+
24
+ def corre():
25
+ with st.spinner('Generando, espera sentado...'):
26
+ ims = genera(modelo_gan, n_mariposas)
27
+ st.session_state['ims'] = ims
28
+
29
+ if 'ims' not in st.session_state:
30
+ st.session_state['ims'] = None
31
+ corre()
32
+
33
+ ims = st.session_stat['ims']
34
+
35
+ corre_boton = st.button(
36
+ 'Genera mariposas',
37
+ on_click= corre,
38
+ help='Estamos en vuelo, abre la imaginacion'
39
+ )
40
+
41
+ if ims is not None:
42
+ cols = st.columns(n_mariposas)
43
+ for j, im in enumerate(ims):
44
+ i = j % n_mariposa
45
+ cols[i].image(im, use_column_width=True)
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from hugga.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
4
+
5
+ def carga_modelo(model_name='ceyda/butterfly_croppe_uniq1K_512',model_version=None):
6
+ gan = LightweightGAN.from_pretrained(model_name, vesion=model_version)
7
+ gan.eval()
8
+ return gan
9
+
10
+ def genera(gan, batch_size=1):
11
+ with torch.no_grad():
12
+ ims = gan.G(torch.randn(batch_size, gan.latent_dim).clamp_(0.0,1.0)*255)
13
+ ims = ims.permute(0,2,3,1).deatch().cpu().numpy().asttype(np.uint8)
14
+ return ims
15
+