Jhp commited on
Commit
825f4db
1 Parent(s): 81a1370
Files changed (3) hide show
  1. app.py +1 -1
  2. gradio.ipynb +13 -5
  3. visualization.py +5 -3
app.py CHANGED
@@ -3,7 +3,7 @@ from visualization import visualization
3
  # pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
4
  # pipeline = pipeline(task="image-classification", model="jhp/hoi")
5
 
6
- def predict(image,threshold,topk):
7
  vis_img = visualization(image,threshold,topk)
8
  return vis_img
9
 
 
3
  # pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
4
  # pipeline = pipeline(task="image-classification", model="jhp/hoi")
5
 
6
+ def predict(image,threshold,topk,device=''):
7
  vis_img = visualization(image,threshold,topk)
8
  return vis_img
9
 
gradio.ipynb CHANGED
@@ -6,12 +6,20 @@
6
  "id": "531487e5-d72d-41be-b4ae-ccd9f8dc844e",
7
  "metadata": {},
8
  "outputs": [
 
 
 
 
 
 
 
 
9
  {
10
  "name": "stdout",
11
  "output_type": "stream",
12
  "text": [
13
  "Running on local URL: http://127.0.0.1:7860\n",
14
- "Running on public URL: https://fc8effa414b728bb78.gradio.live\n",
15
  "\n",
16
  "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
17
  ]
@@ -19,7 +27,7 @@
19
  {
20
  "data": {
21
  "text/html": [
22
- "<div><iframe src=\"https://fc8effa414b728bb78.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
23
  ],
24
  "text/plain": [
25
  "<IPython.core.display.HTML object>"
@@ -33,7 +41,7 @@
33
  "output_type": "stream",
34
  "text": [
35
  "loading annotations into memory...\n",
36
- "Done (t=1.67s)\n",
37
  "creating index...\n",
38
  "index created!\n",
39
  "\n",
@@ -77,8 +85,8 @@
77
  " predict,\n",
78
  " inputs=[gr.Image(type='pil',label=\"input image\"),\n",
79
  " gr.Slider(0, 1, value=0.4, label=\"Threshold\", info=\"Set detection score threshold between 0~1\"),\n",
80
- " gr.Number(value=5,label='Topk',info='Topk prediction')],\n",
81
- " outputs= gr.Image(type=\"pil\", label=\"hoi detection results\"),\n",
82
  " title=\"HOI detection\",\n",
83
  ").launch(share=True,debug=True)"
84
  ]
 
6
  "id": "531487e5-d72d-41be-b4ae-ccd9f8dc844e",
7
  "metadata": {},
8
  "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/tmp/ipykernel_4031598/48305459.py:16: GradioDeprecationWarning: Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components\n",
14
+ " outputs= gr.outputs.Image(type=\"pil\", label=\"hoi detection results\"),\n"
15
+ ]
16
+ },
17
  {
18
  "name": "stdout",
19
  "output_type": "stream",
20
  "text": [
21
  "Running on local URL: http://127.0.0.1:7860\n",
22
+ "Running on public URL: https://fd9d0145926e3bdb6d.gradio.live\n",
23
  "\n",
24
  "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
25
  ]
 
27
  {
28
  "data": {
29
  "text/html": [
30
+ "<div><iframe src=\"https://fd9d0145926e3bdb6d.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
31
  ],
32
  "text/plain": [
33
  "<IPython.core.display.HTML object>"
 
41
  "output_type": "stream",
42
  "text": [
43
  "loading annotations into memory...\n",
44
+ "Done (t=1.56s)\n",
45
  "creating index...\n",
46
  "index created!\n",
47
  "\n",
 
85
  " predict,\n",
86
  " inputs=[gr.Image(type='pil',label=\"input image\"),\n",
87
  " gr.Slider(0, 1, value=0.4, label=\"Threshold\", info=\"Set detection score threshold between 0~1\"),\n",
88
+ " gr.Number(value=5,info='Topk prediction')],\n",
89
+ " outputs= gr.outputs.Image(type=\"pil\", label=\"hoi detection results\"),\n",
90
  " title=\"HOI detection\",\n",
91
  ").launch(share=True,debug=True)"
92
  ]
visualization.py CHANGED
@@ -144,9 +144,11 @@ def vis(args,input_img=None,id=294,return_img=False):
144
 
145
  vis_img=draw_img_vcoco(image,output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES)
146
  plt.imshow(cv2.cvtColor(vis_img,cv2.COLOR_BGR2RGB))
147
- # import pdb;pdb.set_trace()
148
  if return_img:
149
- return Image.fromarray(vis_img)
 
 
150
  else:
151
  cv2.imwrite('./vis_res/vis1.jpg',vis_img)
152
 
@@ -235,7 +237,7 @@ def visualization(input_img,threshold,topk):
235
  # args.topk = topk
236
  if args.output_dir:
237
  Path(args.output_dir).mkdir(parents=True, exist_ok=True)
238
- vis(args,input_img=input_img,return_img=True)
239
 
240
  if __name__ == '__main__':
241
  parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
 
144
 
145
  vis_img=draw_img_vcoco(image,output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES)
146
  plt.imshow(cv2.cvtColor(vis_img,cv2.COLOR_BGR2RGB))
147
+
148
  if return_img:
149
+ vis_img = Image.fromarray(vis_img[:,:,::-1])
150
+ # import pdb;pdb.set_trace()
151
+ return vis_img
152
  else:
153
  cv2.imwrite('./vis_res/vis1.jpg',vis_img)
154
 
 
237
  # args.topk = topk
238
  if args.output_dir:
239
  Path(args.output_dir).mkdir(parents=True, exist_ok=True)
240
+ return vis(args,input_img=input_img,return_img=True)
241
 
242
  if __name__ == '__main__':
243
  parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])