import gradio as gr
import huggingface_hub
import os
import subprocess
import sys
import threading
import shutil
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from moviepy.editor import VideoFileClip, AudioFileClip

# Input videos longer than this (in seconds) are trimmed before inference.
MAX_VIDEO_SECONDS = 10
# Folder that is emptied before each run and then searched for generated .wav files.
OUTPUTS_DIR = './outputs/tmp'

# download model
huggingface_hub.snapshot_download(
    repo_id='ariesssxu/vta-ldm-clip4clip-v-large',
    local_dir='./ckpt/vta-ldm-clip4clip-v-large'
)


def stream_output(pipe):
    """Echo every line read from the text-mode `pipe` to stdout until EOF."""
    for line in iter(pipe.readline, ''):
        print(line, end='')


def _tree_depth(root, path):
    """Return how many directory levels `root` sits below `path` (0 for `path` itself)."""
    rel = os.path.relpath(root, path)
    return 0 if rel == os.curdir else rel.count(os.sep) + 1


def print_directory_contents(path):
    """Print an indented tree of every directory and file under `path`."""
    for root, dirs, files in os.walk(path):
        level = _tree_depth(root, path)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{subindent}{f}")


def get_wav_files(path):
    """Print the tree under `path` and return the paths of all `.wav` files in it.

    The extension match is case-insensitive. Paths are returned in `os.walk`
    order; `.wav` entries are printed with their full path, others by name only.
    """
    wav_files = []
    for root, dirs, files in os.walk(path):
        level = _tree_depth(root, path)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            file_path = os.path.join(root, f)
            if f.lower().endswith('.wav'):
                wav_files.append(file_path)
                print(f"{subindent}{file_path}")
            else:
                print(f"{subindent}{f}")
    return wav_files


def check_outputs_folder(folder_path):
    """Delete everything inside `folder_path`, keeping the folder itself.

    Deletion is best effort: an entry that cannot be removed is reported and
    skipped. A missing folder is only reported, not created.
    """
    if not os.path.isdir(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # Remove file or link
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # Remove directory
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')


def plot_spectrogram(wav_file, output_image):
    """Render a bare grayscale spectrogram of `wav_file` and save it to `output_image`.

    Multi-channel audio (a 2-D array from `wavfile.read`) is averaged to mono.
    Axes ticks and grid are removed for a clean strip-like image.
    """
    sample_rate, audio_data = wavfile.read(wav_file)
    if audio_data.ndim == 2:
        audio_data = audio_data.mean(axis=1)

    fig = plt.figure(figsize=(10, 2))
    try:
        plt.specgram(audio_data, Fs=sample_rate, NFFT=1024, noverlap=512, cmap='gray', aspect='auto')
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        fig.savefig(output_image, bbox_inches='tight', pad_inches=0, dpi=300)
    finally:
        # pyplot keeps every open figure alive; close it explicitly or each
        # request leaks one figure.
        plt.close(fig)


def merge_audio_to_video(input_vid, input_aud, output_path="output_video.mp4"):
    """Write a copy of `input_vid` whose audio track is replaced by `input_aud`.

    Args:
        input_vid: path of the source video.
        input_aud: path of the audio file to attach.
        output_path: destination file (defaults to "output_video.mp4" in the CWD).

    Returns:
        `output_path`.
    """
    video = VideoFileClip(input_vid)
    try:
        new_audio = AudioFileClip(input_aud)
        try:
            video_with_new_audio = video.set_audio(new_audio)
            video_with_new_audio.write_videofile(output_path, codec='libx264', audio_codec='aac')
        finally:
            new_audio.close()
    finally:
        video.close()
    return output_path


def _cut_video(video, input_video_path, max_seconds):
    """Save the first `max_seconds` of `video` as a new .mp4 next to `input_video_path`.

    Returns the path of the new file.
    """
    dir_name, base_name = os.path.split(input_video_path)
    stem, _ext = os.path.splitext(base_name)
    # Always derive a distinct .mp4 name. A plain `.replace(".mp4", ...)` leaves
    # other extensions (".webm", ".MP4", ...) unchanged, so the cut would
    # overwrite its own source, which is then deleted.
    output_video_path = os.path.join(dir_name, f"{stem}_{max_seconds}sec_cut.mp4")
    cut_video = video.subclip(0, max_seconds)
    cut_video.write_videofile(output_video_path, codec='libx264', audio_codec='aac')
    print(f"Cut video saved as: {output_video_path}")
    return output_video_path


def _run_inference(folder_path, max_duration):
    """Run inference_from_video.py with `--data_path folder_path`, streaming its output.

    Returns the subprocess return code.
    """
    command = [
        # Use the current interpreter so the child sees the same environment.
        sys.executable, 'inference_from_video.py',
        '--original_args', 'ckpt/vta-ldm-clip4clip-v-large/summary.jsonl',
        '--model', 'ckpt/vta-ldm-clip4clip-v-large/pytorch_model_2.bin',
        '--data_path', folder_path,
        '--max_duration', f"{max_duration}",
    ]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)

    # Drain stdout and stderr concurrently so neither pipe fills up and blocks the child.
    readers = [
        threading.Thread(target=stream_output, args=(pipe,))
        for pipe in (process.stdout, process.stderr)
    ]
    for reader in readers:
        reader.start()
    process.wait()
    for reader in readers:
        reader.join()
    return process.returncode


def infer(video_in):
    """Gradio handler: generate audio for the uploaded video.

    Trims videos longer than MAX_VIDEO_SECONDS (the original upload is deleted
    and replaced by the trimmed copy), runs the inference script on the upload's
    directory, then plots the first generated .wav and merges it into the video.

    Returns:
        (generated wav path, spectrogram image path, merged video path).

    Raises:
        gr.Error: if the inference script wrote no .wav file.
    """
    # Clear results from any previous run.
    check_outputs_folder(OUTPUTS_DIR)

    print(f"VIDEO IN PATH: {video_in}")
    # The inference script is given the upload's directory, not the file itself.
    folder_path = os.path.dirname(video_in)
    input_video_path = video_in

    was_cut = False
    video = VideoFileClip(input_video_path)
    try:
        video_duration = int(video.duration)
        print(f"Video duration: {video_duration} seconds")
        # Compare the exact duration; truncating first let 10-11 s clips through uncut.
        if video.duration > MAX_VIDEO_SECONDS:
            video_in = _cut_video(video, input_video_path, MAX_VIDEO_SECONDS)
            video_duration = MAX_VIDEO_SECONDS
            was_cut = True
        else:
            print(f"Video is {MAX_VIDEO_SECONDS} seconds or shorter; no cutting needed.")
    finally:
        # Release the reader before the source file is deleted below.
        video.close()

    if was_cut:
        # Remove the original so only the trimmed copy remains in folder_path.
        os.remove(input_video_path)
        print(f"Original video file {input_video_path} deleted.")

    returncode = _run_inference(folder_path, video_duration)
    print("Inference script finished with return code:", returncode)

    print_directory_contents(OUTPUTS_DIR)
    wave_files = get_wav_files(OUTPUTS_DIR)
    print(wave_files)
    if not wave_files:
        raise gr.Error(
            f"Audio generation failed: no .wav file was written "
            f"(inference return code {returncode})."
        )

    generated_wav = wave_files[0]
    plot_spectrogram(generated_wav, 'spectrogram.png')
    final_merged_out = merge_audio_to_video(video_in, generated_wav)
    return generated_wav, 'spectrogram.png', final_merged_out


css = """
#col-container{
    max-width: 800px;
    margin: 0 auto;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Video-to-Audio Generation with Hidden Alignment")
        gr.HTML("""
        """)
        with gr.Column():
            video_in = gr.Video(label='Video IN')
            submit_btn = gr.Button("Submit")
            output_sound = gr.Audio(label="Audio OUT")
            output_spectrogram = gr.Image(label='Spectrogram')
            merged_out = gr.Video(label="Merged video + generated audio")

    submit_btn.click(
        fn=infer,
        inputs=[video_in],
        outputs=[output_sound, output_spectrogram, merged_out],
        show_api=False,
    )

if __name__ == "__main__":
    demo.launch(show_api=False, show_error=True)