Jhp committed
Commit 0bc27e7
1 Parent(s): 529aa2f
Files changed (4)
  1. .gitignore +2 -1
  2. app.py +14 -4
  3. gradio.ipynb +87 -37
  4. visualization.py +26 -21
.gitignore CHANGED
@@ -139,4 +139,5 @@ hico_20160224_det
 v-coco
 
 # *.ipynb
-vis_res
+vis_res
+flagged
app.py CHANGED
@@ -1,7 +1,17 @@
 import gradio as gr
+from visualization import visualization
+# pipeline = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")
+# pipeline = pipeline(task="image-classification", model="jhp/hoi")
 
-def greet(name):
-    return "Hello " + name + "!!"
+def predict(image,threshold,topk):
+    vis_img = visualization(image,threshold,topk)
+    return vis_img
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+gr.Interface(
+    predict,
+    inputs=[gr.Image(type='pil',label="input image"),
+            gr.Slider(0, 1, value=0.4, label="Threshold", info="Set detection score threshold between 0~1"),
+            gr.Number(value=5,label='Topk',info='Topk prediction')],
+    outputs= gr.Image(type="pil", label="hoi detection results"),
+    title="HOI detection",
+).launch()
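
Note: a minimal way to exercise the new predict wiring outside Gradio might look like the sketch below. It assumes the HOTR checkpoint and v-coco files that visualization.py loads are already in place, that sample.jpg is any local test image (a hypothetical file, not part of this commit), and that visualization() hands back the rendered PIL image that predict expects.

# Hypothetical smoke test for the new app.py wiring; the paths are assumptions.
from PIL import Image
from visualization import visualization   # same helper app.py now imports

img = Image.open("sample.jpg").convert("RGB")        # any local test image
result = visualization(img, threshold=0.4, topk=5)   # same defaults as the Gradio widgets
if result is not None:                               # guard in case nothing is returned
    result.save("hoi_result.jpg")                    # rendered HOI detections
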
gradio.ipynb CHANGED
@@ -2,38 +2,63 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 4,
6
  "id": "531487e5-d72d-41be-b4ae-ccd9f8dc844e",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
- "ename": "OSError",
11
- "evalue": "jhp/hoi does not appear to have a file named config.json. Checkout 'https://huggingface.co/jhp/hoi/main' for available files.",
12
- "output_type": "error",
13
- "traceback": [
14
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
15
- "\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)",
16
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/huggingface_hub/utils/_errors.py:261\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 261\u001b[0m \u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
17
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/requests/models.py:1021\u001b[0m, in \u001b[0;36mResponse.raise_for_status\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m http_error_msg:\n\u001b[0;32m-> 1021\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(http_error_msg, response\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m)\n",
18
- "\u001b[0;31mHTTPError\u001b[0m: 404 Client Error: Not Found for url: https://huggingface.co/Jhp/hoi/resolve/main/config.json",
19
- "\nThe above exception was the direct cause of the following exception:\n",
20
- "\u001b[0;31mEntryNotFoundError\u001b[0m Traceback (most recent call last)",
21
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/transformers/utils/hub.py:417\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, use_auth_token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash)\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 416\u001b[0m \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 417\u001b[0m resolved_file \u001b[38;5;241m=\u001b[39m \u001b[43mhf_hub_download\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 418\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath_or_repo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 419\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 420\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msubfolder\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 421\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 422\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 423\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 424\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 425\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 426\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 427\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 428\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_auth_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_auth_token\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 429\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 430\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m RepositoryNotFoundError:\n",
22
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 116\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
23
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/huggingface_hub/file_download.py:1195\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout)\u001b[0m\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1195\u001b[0m metadata \u001b[38;5;241m=\u001b[39m \u001b[43mget_hf_file_metadata\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1196\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1198\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1199\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43metag_timeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1200\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1201\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m EntryNotFoundError \u001b[38;5;28;01mas\u001b[39;00m http_error:\n\u001b[1;32m 1202\u001b[0m \u001b[38;5;66;03m# Cache the non-existence of the file and raise\u001b[39;00m\n",
24
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 116\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
25
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/huggingface_hub/file_download.py:1541\u001b[0m, in \u001b[0;36mget_hf_file_metadata\u001b[0;34m(url, token, proxies, timeout)\u001b[0m\n\u001b[1;32m 1532\u001b[0m r \u001b[38;5;241m=\u001b[39m _request_wrapper(\n\u001b[1;32m 1533\u001b[0m method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHEAD\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1534\u001b[0m url\u001b[38;5;241m=\u001b[39murl,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1539\u001b[0m timeout\u001b[38;5;241m=\u001b[39mtimeout,\n\u001b[1;32m 1540\u001b[0m )\n\u001b[0;32m-> 1541\u001b[0m \u001b[43mhf_raise_for_status\u001b[49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;66;03m# Return\u001b[39;00m\n",
26
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/huggingface_hub/utils/_errors.py:271\u001b[0m, in \u001b[0;36mhf_raise_for_status\u001b[0;34m(response, endpoint_name)\u001b[0m\n\u001b[1;32m 270\u001b[0m message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;241m.\u001b[39mstatus_code\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m Client Error.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEntry Not Found for url: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;241m.\u001b[39murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 271\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m EntryNotFoundError(message, response) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m error_code \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGatedRepo\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
27
- "\u001b[0;31mEntryNotFoundError\u001b[0m: 404 Client Error. (Request ID: Root=1-64c1e562-51194bc57d2e3df45df60f34;36aabc8d-19da-4540-9897-3081a028b579)\n\nEntry Not Found for url: https://huggingface.co/Jhp/hoi/resolve/main/config.json.",
28
- "\nDuring handling of the above exception, another exception occurred:\n",
29
- "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
30
- "Cell \u001b[0;32mIn[4], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m pipeline\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# pipeline = pipeline(task=\"image-classification\", model=\"julien-c/hotdog-not-hotdog\")\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m pipeline \u001b[38;5;241m=\u001b[39m \u001b[43mpipeline\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mimage-classification\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mjhp/hoi\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpredict\u001b[39m(image):\n\u001b[1;32m 8\u001b[0m predictions \u001b[38;5;241m=\u001b[39m pipeline(image)\n",
31
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/transformers/pipelines/__init__.py:705\u001b[0m, in \u001b[0;36mpipeline\u001b[0;34m(task, model, config, tokenizer, feature_extractor, image_processor, framework, revision, use_fast, use_auth_token, device, device_map, torch_dtype, trust_remote_code, model_kwargs, pipeline_class, **kwargs)\u001b[0m\n\u001b[1;32m 703\u001b[0m hub_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39m_commit_hash\n\u001b[1;32m 704\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 705\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[43mAutoConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 706\u001b[0m hub_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m config\u001b[38;5;241m.\u001b[39m_commit_hash\n\u001b[1;32m 708\u001b[0m custom_tasks \u001b[38;5;241m=\u001b[39m {}\n",
32
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/transformers/models/auto/configuration_auto.py:983\u001b[0m, in \u001b[0;36mAutoConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 981\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname_or_path\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m pretrained_model_name_or_path\n\u001b[1;32m 982\u001b[0m trust_remote_code \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrust_remote_code\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 983\u001b[0m config_dict, unused_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mPretrainedConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_config_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 984\u001b[0m has_remote_code \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto_map\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAutoConfig\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto_map\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 985\u001b[0m has_local_code \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict \u001b[38;5;129;01mand\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;129;01min\u001b[39;00m CONFIG_MAPPING\n",
33
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/transformers/configuration_utils.py:617\u001b[0m, in \u001b[0;36mPretrainedConfig.get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m original_kwargs \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(kwargs)\n\u001b[1;32m 616\u001b[0m \u001b[38;5;66;03m# Get config dict associated with the base config file\u001b[39;00m\n\u001b[0;32m--> 617\u001b[0m config_dict, kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_config_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 618\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m config_dict:\n\u001b[1;32m 619\u001b[0m original_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
34
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/transformers/configuration_utils.py:672\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 668\u001b[0m configuration_file \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_configuration_file\u001b[39m\u001b[38;5;124m\"\u001b[39m, CONFIG_NAME)\n\u001b[1;32m 670\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 671\u001b[0m \u001b[38;5;66;03m# Load from local folder or from cache or download from model Hub and cache\u001b[39;00m\n\u001b[0;32m--> 672\u001b[0m resolved_config_file \u001b[38;5;241m=\u001b[39m \u001b[43mcached_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 673\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 674\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfiguration_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 675\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 676\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 677\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 678\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 679\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 680\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_auth_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_auth_token\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 681\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 682\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 683\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 684\u001b[0m \u001b[43m \u001b[49m\u001b[43m_commit_hash\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcommit_hash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 685\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 686\u001b[0m commit_hash \u001b[38;5;241m=\u001b[39m extract_commit_hash(resolved_config_file, commit_hash)\n\u001b[1;32m 687\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m:\n\u001b[1;32m 688\u001b[0m \u001b[38;5;66;03m# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to\u001b[39;00m\n\u001b[1;32m 689\u001b[0m \u001b[38;5;66;03m# the original exception.\u001b[39;00m\n",
35
- "File \u001b[0;32m~/.conda/envs/HOTR_CPC/lib/python3.9/site-packages/transformers/utils/hub.py:463\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, use_auth_token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash)\u001b[0m\n\u001b[1;32m 461\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m revision \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 462\u001b[0m revision \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 463\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 464\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not appear to have a file named \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfull_filename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. Checkout \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 465\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttps://huggingface.co/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrevision\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m for available files.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 466\u001b[0m )\n\u001b[1;32m 467\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m HTTPError \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m 468\u001b[0m \u001b[38;5;66;03m# First we try to see if we have a cached version (not up to date):\u001b[39;00m\n\u001b[1;32m 469\u001b[0m resolved_file \u001b[38;5;241m=\u001b[39m try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir\u001b[38;5;241m=\u001b[39mcache_dir, revision\u001b[38;5;241m=\u001b[39mrevision)\n",
36
- "\u001b[0;31mOSError\u001b[0m: jhp/hoi does not appear to have a file named config.json. Checkout 'https://huggingface.co/jhp/hoi/main' for available files."
37
  ]
38
  }
39
  ],
@@ -42,25 +67,50 @@
42
  "from transformers import pipeline\n",
43
  "from visualization import visualization\n",
44
  "# pipeline = pipeline(task=\"image-classification\", model=\"julien-c/hotdog-not-hotdog\")\n",
45
- "pipeline = pipeline(task=\"image-classification\", model=\"jhp/hoi\")\n",
46
  "\n",
47
- "def predict(image):\n",
48
- " image = pipeline(image)\n",
49
- " return {p[\"label\"]: p[\"score\"] for p in predictions}\n",
50
  "\n",
51
  "gr.Interface(\n",
52
  " predict,\n",
53
- " inputs=gr.inputs.Image(label=\"Upload hot dog candidate\", type=\"filepath\"),\n",
54
- " outputs=gr.outputs.Label(num_top_classes=2),\n",
55
- " title=\"Hot Dog? Or Not?\",\n",
56
- ").launch(share=True)"
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": null,
62
  "id": "439a75e9-77e6-4932-9b9b-35e2d0b7a76b",
63
  "metadata": {},
64
  "outputs": [],
65
  "source": []
66
  }
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "531487e5-d72d-41be-b4ae-ccd9f8dc844e",
7
  "metadata": {},
8
  "outputs": [
9
  {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Running on local URL: http://127.0.0.1:7860\n",
14
+ "Running on public URL: https://fc8effa414b728bb78.gradio.live\n",
15
+ "\n",
16
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
17
+ ]
18
+ },
19
+ {
20
+ "data": {
21
+ "text/html": [
22
+ "<div><iframe src=\"https://fc8effa414b728bb78.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
23
+ ],
24
+ "text/plain": [
25
+ "<IPython.core.display.HTML object>"
26
+ ]
27
+ },
28
+ "metadata": {},
29
+ "output_type": "display_data"
30
+ },
31
+ {
32
+ "name": "stdout",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "loading annotations into memory...\n",
36
+ "Done (t=1.67s)\n",
37
+ "creating index...\n",
38
+ "index created!\n",
39
+ "\n",
40
+ "[Logger] DETR Arguments:\n",
41
+ "\tlr: 0.0001\n",
42
+ "\tlr_backbone: 1e-05\n",
43
+ "\tlr_drop: 80\n",
44
+ "\tfrozen_weights: None\n",
45
+ "\tbackbone: resnet50\n",
46
+ "\tdilation: False\n",
47
+ "\tposition_embedding: sine\n",
48
+ "\tenc_layers: 6\n",
49
+ "\tdec_layers: 6\n",
50
+ "\tnum_queries: 100\n",
51
+ "\tdataset_file: vcoco\n",
52
+ "\n",
53
+ "[Logger] Number of params: 52413912\n"
54
+ ]
55
+ },
56
+ {
57
+ "name": "stderr",
58
+ "output_type": "stream",
59
+ "text": [
60
+ "/home/jihwan/CPC_HOTR/hotr/models/position_encoding.py:41: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
61
+ " dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)\n"
62
  ]
63
  }
64
  ],
 
67
  "from transformers import pipeline\n",
68
  "from visualization import visualization\n",
69
  "# pipeline = pipeline(task=\"image-classification\", model=\"julien-c/hotdog-not-hotdog\")\n",
70
+ "# pipeline = pipeline(task=\"image-classification\", model=\"jhp/hoi\")\n",
71
  "\n",
72
+ "def predict(image,threshold,topk):\n",
73
+ " vis_img = visualization(image,threshold,topk)\n",
74
+ " return vis_img\n",
75
  "\n",
76
  "gr.Interface(\n",
77
  " predict,\n",
78
+ " inputs=[gr.Image(type='pil',label=\"input image\"),\n",
79
+ " gr.Slider(0, 1, value=0.4, label=\"Threshold\", info=\"Set detection score threshold between 0~1\"),\n",
80
+ " gr.Number(value=5,label='Topk',info='Topk prediction')],\n",
81
+ " outputs= gr.Image(type=\"pil\", label=\"hoi detection results\"),\n",
82
+ " title=\"HOI detection\",\n",
83
+ ").launch(share=True,debug=True)"
84
  ]
85
  },
86
  {
87
  "cell_type": "code",
88
+ "execution_count": 1,
89
  "id": "439a75e9-77e6-4932-9b9b-35e2d0b7a76b",
90
  "metadata": {},
91
+ "outputs": [
92
+ {
93
+ "ename": "TypeError",
94
+ "evalue": "string indices must be integers",
95
+ "output_type": "error",
96
+ "traceback": [
97
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
98
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
99
+ "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m a\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msdsd\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m----> 2\u001b[0m \u001b[43ma\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\n",
100
+ "\u001b[0;31mTypeError\u001b[0m: string indices must be integers"
101
+ ]
102
+ }
103
+ ],
104
+ "source": [
105
+ "a='sdsd'\n",
106
+ "a[:,:]\n"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "96fc750d-1869-4c83-87ad-d4ef909bbddb",
113
+ "metadata": {},
114
  "outputs": [],
115
  "source": []
116
  }
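
Note: the UserWarning captured in the notebook's stderr above comes from hotr/models/position_encoding.py:41, and the message itself points at the fix: replace the deprecated tensor `//` with torch.div and an explicit rounding mode. A standalone sketch of that change is below; num_pos_feats=64 and temperature=10000 are the usual DETR-style defaults, assumed here rather than taken from this commit.

import torch

# dim_t is non-negative, so rounding_mode='floor' matches the old __floordiv__ result.
num_pos_feats, temperature = 64, 10000               # assumed defaults, illustrative only
dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
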
visualization.py CHANGED
@@ -51,7 +51,7 @@ def change_format(results,valid_ids):
             output_i['hoi_prediction'].append({'subject_id':hum,'object_id':k,'category_id':i+2,'score':verb[j][k]})
 
     return output_i
-def vis(args,id=294,return_img=False):
+def vis(args,input_img=None,id=294,return_img=False):
 
     if args.frozen_weights is not None:
         print("Freeze weights for detector")
@@ -116,8 +116,13 @@ def vis(args,id=294,return_img=False):
     # if not args.video_vis:
     # url='http://images.cocodataset.org/val2014/COCO_val2014_{}.jpg'.format(str(id).zfill(12))
     # req = requests.get(url, stream=True, timeout=1, verify=False).raw
-    req = args.image_dir
-    img = Image.open(req).convert('RGB')
+
+    if input_img is None:
+        req = args.image_dir
+        img = Image.open(req).convert('RGB')
+    else:
+        # import pdb;pdb.set_trace()
+        img = input_img
 
     w,h=img.size
     orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)
@@ -138,8 +143,9 @@ def vis(args,id=294,return_img=False):
 
     vis_img=draw_img_vcoco(image,output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES)
     plt.imshow(cv2.cvtColor(vis_img,cv2.COLOR_BGR2RGB))
+    # import pdb;pdb.set_trace()
     if return_img:
-        return vis_img
+        return Image.fromarray(vis_img)
     else:
         cv2.imwrite('./vis_res/vis1.jpg',vis_img)
 
@@ -203,33 +209,32 @@ def vis(args,id=294,return_img=False):
 # vis(args,id)
 
 # 230727 for huggingface
-def visualization(return_img=False):
+def visualization(input_img,threshold,topk):
 
     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
-    parser.add_argument('--threshold',help='score threshold for visualization', default=0.4, type=float)
-    # parser.add_argument('--path_id',help='index of inference path', default=1, type=int)
-    parser.add_argument('--topk',help='topk prediction', default=5, type=int)
-    parser.add_argument('--video_vis', action='store_true')
-    parser.add_argument('--image_dir', default='', type=str)
-    args = parser.parse_args()
+    args = parser.parse_args(args=[])
+    args.threshold = threshold
+    args.topk = int(topk)
+
     # checkpoint_dir= './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth'
     args.resume= './checkpoints/vcoco/checkpoint.pth'
-    with open('./v-coco/data/splits/vcoco_test.ids') as file:
-        test_idxs = [line.rstrip('\n') for line in file]
-    # if not video_vis:
-    id = test_idxs[309]
+    # with open('./v-coco/data/splits/vcoco_test.ids') as file:
+    #     test_idxs = [line.rstrip('\n') for line in file]
+    # # if not video_vis:
+    # id = test_idxs[309]
     # args = parser.parse_args()
-    # args.dataset_file = 'vcoco'
-    # args.data_path = 'v-coco'
+    args.dataset_file = 'vcoco'
+    args.data_path = 'v-coco'
    # args.resume = checkpoint_dir
-    # args.num_hoi_queries = 16
-    # args.temperature = 0.05
+    args.num_hoi_queries = 16
+    args.temperature = 0.05
     args.augpath_name = ['p2','p3','p4']
     # args.path_id = 1
-
+    # args.threshold = threshold
+    # args.topk = topk
     if args.output_dir:
         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
-    vis(args,return_img=return_img)
+    vis(args,input_img=input_img,return_img=True)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
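
Note: the switch to parser.parse_args(args=[]) matters because visualization() now runs inside a Gradio/Jupyter process whose sys.argv belongs to the host, not to this script; passing an empty list makes argparse return only the declared defaults, which the caller then overrides field by field. A minimal standalone sketch of that pattern (the toy parser below is illustrative, not this repo's get_args_parser()):

import argparse

# Toy parser standing in for get_args_parser(); only the pattern is the point.
parser = argparse.ArgumentParser('demo')
parser.add_argument('--threshold', default=0.4, type=float)
parser.add_argument('--topk', default=5, type=int)

args = parser.parse_args(args=[])   # ignore sys.argv: safe inside Gradio/Jupyter
args.threshold = 0.7                # override programmatically, as visualization() does
args.topk = 10
print(args)                         # -> Namespace(threshold=0.7, topk=10)
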