wavesoumen committed on
Commit
a80511b
1 Parent(s): 7c7cb02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -26
app.py CHANGED
@@ -1,17 +1,40 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
 
3
  import nltk
4
  from youtube_transcript_api import YouTubeTranscriptApi
5
 
6
  # Download NLTK data
7
  nltk.download('punkt')
8
 
9
- # Initialize the image captioning pipeline
10
- captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
11
 
12
- # Load the tokenizer and model for tag generation
13
- tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
14
- model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Function to fetch YouTube transcript
17
  def fetch_transcript(url):
@@ -34,22 +57,21 @@ with tab1:
34
  st.header("Image Captioning")
35
 
36
  # Input for image URL
37
- image_url = st.text_input("Enter the URL of the image:")
38
 
39
  # If an image URL is provided
40
- if image_url:
41
- try:
42
- # Display the image
43
- st.image(image_url, caption="Provided Image", use_column_width=True)
44
-
45
- # Generate the caption
46
- caption = captioner(image_url)
47
-
48
- # Display the caption
49
- st.write("**Generated Caption:**")
50
- st.write(caption[0]['generated_text'])
51
- except Exception as e:
52
- st.error(f"An error occurred: {e}")
53
 
54
  # Text Tag Generation Tab
55
  with tab2:
@@ -59,17 +81,17 @@ with tab2:
59
  text = st.text_area("Enter the text for tag extraction:", height=200)
60
 
61
  # Button to generate tags
62
- if st.button("Generate Tags"):
63
  if text:
64
  try:
65
  # Tokenize and encode the input text
66
- inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
67
 
68
  # Generate tags
69
- output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
70
 
71
  # Decode the output
72
- decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
73
 
74
  # Extract unique tags
75
  tags = list(set(decoded_output.strip().split(", ")))
@@ -90,7 +112,7 @@ with tab3:
90
  youtube_url = st.text_input("Enter YouTube URL:")
91
 
92
  # Button to get transcript
93
- if st.button("Get Transcript"):
94
  if youtube_url:
95
  transcript = fetch_transcript(youtube_url)
96
  if "error" not in transcript.lower():
@@ -100,4 +122,3 @@ with tab3:
100
  st.error(f"An error occurred: {transcript}")
101
  else:
102
  st.warning("Please enter a URL.")
103
-
 
1
  import streamlit as st
2
+ import requests
3
+ from PIL import Image
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
5
  import nltk
6
  from youtube_transcript_api import YouTubeTranscriptApi
7
 
8
  # Download NLTK data
9
  nltk.download('punkt')
10
 
11
+ # Initialize the image captioning processor and model
12
+ caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
13
+ caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
14
 
15
+ # Initialize the tokenizer and model for tag generation
16
+ tag_tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
17
+ tag_model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
18
+
19
# Function to generate captions for an image
def generate_caption(img_url, text="a photography of"):
    """Generate BLIP captions for the image at ``img_url``.

    Parameters
    ----------
    img_url : str
        URL of the image to caption.
    text : str, optional
        Prompt that seeds the conditional caption
        (default: "a photography of").

    Returns
    -------
    tuple
        ``(caption_conditional, caption_unconditional)`` on success,
        or ``(None, None)`` if the image could not be fetched or
        captioning failed (an error is shown via ``st.error``).
    """
    try:
        # timeout prevents the Streamlit script from hanging forever on a
        # dead host; raise_for_status turns HTTP 4xx/5xx into a clear error
        # instead of a confusing PIL decode failure on an HTML error page.
        response = requests.get(img_url, stream=True, timeout=10)
        response.raise_for_status()
        raw_image = Image.open(response.raw).convert('RGB')
    except Exception as e:
        st.error(f"Error loading image: {e}")
        return None, None

    try:
        # Conditional image captioning: the prompt *text* seeds the caption.
        inputs_conditional = caption_processor(raw_image, text, return_tensors="pt")
        out_conditional = caption_model.generate(**inputs_conditional)
        caption_conditional = caption_processor.decode(out_conditional[0], skip_special_tokens=True)

        # Unconditional image captioning: no prompt, model captions freely.
        inputs_unconditional = caption_processor(raw_image, return_tensors="pt")
        out_unconditional = caption_model.generate(**inputs_unconditional)
        caption_unconditional = caption_processor.decode(out_unconditional[0], skip_special_tokens=True)
    except Exception as e:
        # Previously uncaught: a generation failure crashed the whole app
        # instead of surfacing an error like the load path does.
        st.error(f"Error generating captions: {e}")
        return None, None

    return caption_conditional, caption_unconditional
38
 
39
  # Function to fetch YouTube transcript
40
  def fetch_transcript(url):
 
57
  st.header("Image Captioning")
58
 
59
  # Input for image URL
60
+ img_url = st.text_input("Enter Image URL:")
61
 
62
  # If an image URL is provided
63
+ if st.button("Generate Captions", key='caption_button'):
64
+ if img_url:
65
+ caption_conditional, caption_unconditional = generate_caption(img_url)
66
+ if caption_conditional and caption_unconditional:
67
+ st.success("Captions successfully generated!")
68
+ st.image(img_url, caption="Input Image", use_column_width=True)
69
+ st.write("### Conditional Caption")
70
+ st.write(caption_conditional)
71
+ st.write("### Unconditional Caption")
72
+ st.write(caption_unconditional)
73
+ else:
74
+ st.warning("Please enter an image URL.")
 
75
 
76
  # Text Tag Generation Tab
77
  with tab2:
 
81
  text = st.text_area("Enter the text for tag extraction:", height=200)
82
 
83
  # Button to generate tags
84
+ if st.button("Generate Tags", key='tag_button'):
85
  if text:
86
  try:
87
  # Tokenize and encode the input text
88
+ inputs = tag_tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
89
 
90
  # Generate tags
91
+ output = tag_model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
92
 
93
  # Decode the output
94
+ decoded_output = tag_tokenizer.batch_decode(output, skip_special_tokens=True)[0]
95
 
96
  # Extract unique tags
97
  tags = list(set(decoded_output.strip().split(", ")))
 
112
  youtube_url = st.text_input("Enter YouTube URL:")
113
 
114
  # Button to get transcript
115
+ if st.button("Get Transcript", key='transcript_button'):
116
  if youtube_url:
117
  transcript = fetch_transcript(youtube_url)
118
  if "error" not in transcript.lower():
 
122
  st.error(f"An error occurred: {transcript}")
123
  else:
124
  st.warning("Please enter a URL.")