hysts HF staff commited on
Commit
4569f2f
·
1 Parent(s): 5664eba
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +23 -21
  3. style.css +3 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -14,12 +14,7 @@ import torch
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
- TITLE = 'Age Estimation'
18
- DESCRIPTION = 'This is an unofficial demo for https://github.com/yu4u/age-estimation-pytorch.'
19
-
20
- HF_TOKEN = os.getenv('HF_TOKEN')
21
- MODEL_REPO = 'hysts/yu4u-age-estimation-pytorch'
22
- MODEL_FILENAME = 'pretrained.pth'
23
 
24
 
25
  def get_model(model_name='se_resnext50_32x4d',
@@ -34,9 +29,8 @@ def get_model(model_name='se_resnext50_32x4d',
34
 
35
  def load_model(device):
36
  model = get_model(model_name='se_resnext50_32x4d', pretrained=None)
37
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
38
- MODEL_FILENAME,
39
- use_auth_token=HF_TOKEN)
40
  model.load_state_dict(torch.load(path))
41
  model = model.to(device)
42
  model.eval()
@@ -111,19 +105,27 @@ def predict(image, model, face_detector, device, margin=0.4, input_size=224):
111
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
112
  model = load_model(device)
113
  face_detector = dlib.get_frontal_face_detector()
114
- func = functools.partial(predict,
115
- model=model,
116
- face_detector=face_detector,
117
- device=device)
118
 
119
  image_dir = pathlib.Path('sample_images')
120
  examples = [path.as_posix() for path in sorted(image_dir.glob('*.jpg'))]
121
 
122
- gr.Interface(
123
- fn=func,
124
- inputs=gr.Image(label='Input', type='filepath'),
125
- outputs=gr.Image(label='Output'),
126
- examples=examples,
127
- title=TITLE,
128
- description=DESCRIPTION,
129
- ).launch(show_api=False)
 
 
 
 
 
 
 
 
 
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
+ DESCRIPTION = '# [Age Estimation](https://github.com/yu4u/age-estimation-pytorch)'
 
 
 
 
 
18
 
19
 
20
  def get_model(model_name='se_resnext50_32x4d',
 
29
 
30
  def load_model(device):
31
  model = get_model(model_name='se_resnext50_32x4d', pretrained=None)
32
+ path = huggingface_hub.hf_hub_download(
33
+ 'public-data/yu4u-age-estimation-pytorch', 'pretrained.pth')
 
34
  model.load_state_dict(torch.load(path))
35
  model = model.to(device)
36
  model.eval()
 
105
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
106
  model = load_model(device)
107
  face_detector = dlib.get_frontal_face_detector()
108
+ fn = functools.partial(predict,
109
+ model=model,
110
+ face_detector=face_detector,
111
+ device=device)
112
 
113
  image_dir = pathlib.Path('sample_images')
114
  examples = [path.as_posix() for path in sorted(image_dir.glob('*.jpg'))]
115
 
116
+ with gr.Blocks(css='style.css') as demo:
117
+ gr.Markdown(DESCRIPTION)
118
+ with gr.Row():
119
+ with gr.Column():
120
+ image = gr.Image(label='Input', type='filepath')
121
+ run_button = gr.Button('Run')
122
+ with gr.Column():
123
+ result = gr.Image(label='Result')
124
+
125
+ gr.Examples(examples=examples,
126
+ inputs=image,
127
+ outputs=result,
128
+ fn=fn,
129
+ cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
130
+ run_button.click(fn=fn, inputs=image, outputs=result, api_name='predict')
131
+ demo.queue(max_size=15).launch()
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }