diff --git "a/Speech_to_speech_translation.ipynb" "b/Speech_to_speech_translation.ipynb"
new file mode 100644--- /dev/null
+++ "b/Speech_to_speech_translation.ipynb"
@@ -0,0 +1,913 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "a8ede6GAyTtl"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "n_5dfJzHyjeY"
+ },
+ "source": [
+ "Dengfeng's hands on lab expertice , STS test, Foucs on building a Speech to speech Translation demo"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3ZICYUc1zkhK"
+ },
+ "source": [
+ "Step0 - Install all the prerequisites"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "aisYQvPBzd9V",
+ "outputId": "2f3d8eab-2452-40f1-914a-073afedd41f1"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.2.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n",
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n",
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.67.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.10)\n",
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.27.0)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.4)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.2)\n",
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.3.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.1)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.18.3)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.12.14)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "pip install datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kAH8S7npytQB"
+ },
+ "source": [
+ "Step1 - Load Whisper Base(74M)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "w9tvrqQoyb3m",
+ "outputId": "49279cef-6595-4d5c-d410-4808a9df8034"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co./settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+ "You will be able to reuse this secret in all of your notebooks.\n",
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+ " warnings.warn(\n",
+ "Device set to use cpu\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
+ "pipe = pipeline(\n",
+ " \"automatic-speech-recognition\", model=\"openai/whisper-base\", device=device\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gELxGFq_zJB2"
+ },
+ "source": [
+ "Step2 - Load Audio Sample in a non-english language, German"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fUIwqY5LzWV4"
+ },
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "dataset = load_dataset(\"facebook/voxpopuli\", \"de\", split=\"validation\", streaming=True)\n",
+ "sample = next(iter(dataset))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "6RhxOdHF1HjT",
+ "outputId": "e3294f1c-a4e5-4928-e612-22cc8d2f71fe"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "Audio(sample[\"audio\"][\"array\"], rate=sample[\"audio\"][\"sampling_rate\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TnKEV0ty1rb0"
+ },
+ "source": [
+ "Step3 - Define a translate function ( from audio to text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 163
+ },
+ "id": "bn47OSrE11Vo",
+ "outputId": "13923076-7199-4e89-e46c-117a455674bf"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py:312: FutureWarning: `max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.\n",
+ " warnings.warn(\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py:512: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.\n",
+ " warnings.warn(\n",
+ "You have passed task=translate, but also have set `forced_decoder_ids` to [[1, None], [2, 50359]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=translate.\n",
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "' The second month is a new president in the field.'"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def translate(audio):\n",
+ " outputs = pipe(audio, max_new_tokens=256, generate_kwargs={\"task\": \"translate\"})\n",
+ " return outputs[\"text\"]\n",
+ "\n",
+ "translate(sample[\"audio\"].copy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g4zlq3aK3TYO"
+ },
+ "source": [
+ "Compare it with the source text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "id": "Vmid2fkR3PCS",
+ "outputId": "1db4e956-19bf-488d-83b3-fb0f114b8521"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "'Und seit zwei Monaten ist ein neuer Präsident im Amt.'"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sample[\"raw_text\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "H4hCOUES3Z1g"
+ },
+ "source": [
+ "Step4 - Text to Speech, TTS, Get the SpeechT5 TTS model to help"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "QFerwTkn3nN_"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan\n",
+ "\n",
+ "processor = SpeechT5Processor.from_pretrained(\"microsoft/speecht5_tts\")\n",
+ "\n",
+ "model = SpeechT5ForTextToSpeech.from_pretrained(\"microsoft/speecht5_tts\")\n",
+ "vocoder = SpeechT5HifiGan.from_pretrained(\"microsoft/speecht5_hifigan\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "X6yjka_G3xZk",
+ "outputId": "d6e73cb5-7bf9-463d-ebb3-75d77ffe3f69"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "SpeechT5HifiGan(\n",
+ " (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " (upsampler): ModuleList(\n",
+ " (0): ConvTranspose1d(512, 256, kernel_size=(8,), stride=(4,), padding=(2,))\n",
+ " (1): ConvTranspose1d(256, 128, kernel_size=(8,), stride=(4,), padding=(2,))\n",
+ " (2): ConvTranspose1d(128, 64, kernel_size=(8,), stride=(4,), padding=(2,))\n",
+ " (3): ConvTranspose1d(64, 32, kernel_size=(8,), stride=(4,), padding=(2,))\n",
+ " )\n",
+ " (resblocks): ModuleList(\n",
+ " (0): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))\n",
+ " (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " )\n",
+ " )\n",
+ " (1): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))\n",
+ " (2): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " )\n",
+ " )\n",
+ " (2): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " (1): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))\n",
+ " (2): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " )\n",
+ " )\n",
+ " (3): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))\n",
+ " (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " )\n",
+ " )\n",
+ " (4): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " (1): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))\n",
+ " (2): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " )\n",
+ " )\n",
+ " (5): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " (1): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))\n",
+ " (2): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " )\n",
+ " )\n",
+ " (6): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))\n",
+ " (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " )\n",
+ " )\n",
+ " (7): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))\n",
+ " (2): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " )\n",
+ " )\n",
+ " (8): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(64, 64, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " (1): Conv1d(64, 64, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))\n",
+ " (2): Conv1d(64, 64, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(64, 64, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " )\n",
+ " )\n",
+ " (9): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " (1): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))\n",
+ " (2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))\n",
+ " )\n",
+ " )\n",
+ " (10): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " (1): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))\n",
+ " (2): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ " )\n",
+ " )\n",
+ " (11): HifiGanResidualBlock(\n",
+ " (convs1): ModuleList(\n",
+ " (0): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " (1): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))\n",
+ " (2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))\n",
+ " )\n",
+ " (convs2): ModuleList(\n",
+ " (0-2): 3 x Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,))\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (conv_post): Conv1d(32, 1, kernel_size=(7,), stride=(1,), padding=(3,))\n",
+ ")"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.to(device)\n",
+ "vocoder.to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SDYLMEor31PJ"
+ },
+ "outputs": [],
+ "source": [
+ "embeddings_dataset = load_dataset(\"Matthijs/cmu-arctic-xvectors\", split=\"validation\")\n",
+ "speaker_embeddings = torch.tensor(embeddings_dataset[7306][\"xvector\"]).unsqueeze(0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_GRUnpFR36Tl"
+ },
+ "source": [
+ "Step4.1 - Take the text prompt as input, pre-processing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KOzGIeAg4FCD"
+ },
+ "outputs": [],
+ "source": [
+ "def synthesise(text):\n",
+ " inputs = processor(text=text, return_tensors=\"pt\")\n",
+ " speech = model.generate_speech(\n",
+ " inputs[\"input_ids\"].to(device), speaker_embeddings.to(device), vocoder=vocoder\n",
+ " )\n",
+ " return speech.cpu()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tjEd8vJc4Sn2"
+ },
+ "source": [
+ "Get some test of a dummy text input."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "W2hOO_654I6n",
+ "outputId": "5234c7ef-359e-4626-ead7-3c4f3a6f2d6e"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "speech = synthesise(\"Hey there! This is a test!\")\n",
+ "\n",
+ "Audio(speech, rate=16000)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nHLMf8va4dAO"
+ },
+ "source": [
+ "Step5 - Go to Demo, test the audio function firstly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "SXhEozPs4frD"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "target_dtype = np.int16\n",
+ "max_range = np.iinfo(target_dtype).max\n",
+ "\n",
+ "\n",
+ "def speech_to_speech_translation(audio):\n",
+ " translated_text = translate(audio)\n",
+ " synthesised_speech = synthesise(translated_text)\n",
+ " synthesised_speech = (synthesised_speech.numpy() * max_range).astype(np.int16)\n",
+ " return 16000, synthesised_speech"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "JRCwJIc14i4D",
+ "outputId": "8000e100-f85c-47fc-df15-fee2aab91b9e"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sampling_rate, synthesised_speech = speech_to_speech_translation(sample[\"audio\"])\n",
+ "\n",
+ "Audio(synthesised_speech, rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tqtwS7ml41av"
+ },
+ "source": [
+ "Step6 - Gradio demo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ZkKjUwlk5FJA",
+ "outputId": "b789b0b0-f685-4958-9b0b-18bec1fd4fb3"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: gradio in /usr/local/lib/python3.10/dist-packages (5.10.0)\n",
+ "Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (23.2.1)\n",
+ "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.7.1)\n",
+ "Requirement already satisfied: fastapi<1.0,>=0.115.2 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.115.6)\n",
+ "Requirement already satisfied: ffmpy in /usr/local/lib/python3.10/dist-packages (from gradio) (0.5.0)\n",
+ "Requirement already satisfied: gradio-client==1.5.3 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.5.3)\n",
+ "Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.28.1)\n",
+ "Requirement already satisfied: huggingface-hub>=0.25.1 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.27.0)\n",
+ "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.1.4)\n",
+ "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.5)\n",
+ "Requirement already satisfied: numpy<3.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.26.4)\n",
+ "Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.10.12)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from gradio) (24.2)\n",
+ "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.2.2)\n",
+ "Requirement already satisfied: pillow<12.0,>=8.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (11.0.0)\n",
+ "Requirement already satisfied: pydantic>=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.10.3)\n",
+ "Requirement already satisfied: pydub in /usr/local/lib/python3.10/dist-packages (from gradio) (0.25.1)\n",
+ "Requirement already satisfied: python-multipart>=0.0.18 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.0.20)\n",
+ "Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.2)\n",
+ "Requirement already satisfied: ruff>=0.2.2 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.8.6)\n",
+ "Requirement already satisfied: safehttpx<0.2.0,>=0.1.6 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.1.6)\n",
+ "Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.10.0)\n",
+ "Requirement already satisfied: starlette<1.0,>=0.40.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.41.3)\n",
+ "Requirement already satisfied: tomlkit<0.14.0,>=0.12.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.13.2)\n",
+ "Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.15.1)\n",
+ "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.12.2)\n",
+ "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.34.0)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from gradio-client==1.5.3->gradio) (2024.9.0)\n",
+ "Requirement already satisfied: websockets<15.0,>=10.0 in /usr/local/lib/python3.10/dist-packages (from gradio-client==1.5.3->gradio) (14.1)\n",
+ "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->gradio) (3.10)\n",
+ "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)\n",
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5.0,>=3.0->gradio) (1.2.2)\n",
+ "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx>=0.24.1->gradio) (2024.12.14)\n",
+ "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx>=0.24.1->gradio) (1.0.7)\n",
+ "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.14.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.1->gradio) (3.16.1)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.1->gradio) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.25.1->gradio) (4.67.1)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2024.2)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2024.2)\n",
+ "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=2.0->gradio) (0.7.0)\n",
+ "Requirement already satisfied: pydantic-core==2.27.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=2.0->gradio) (2.27.1)\n",
+ "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from typer<1.0,>=0.12->gradio) (8.1.7)\n",
+ "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from typer<1.0,>=0.12->gradio) (1.5.4)\n",
+ "Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.10/dist-packages (from typer<1.0,>=0.12->gradio) (13.9.4)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.17.0)\n",
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.18.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.25.1->gradio) (3.4.0)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.25.1->gradio) (2.2.3)\n",
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)\n"
+ ]
+ }
+ ],
+ "source": [
+ "pip install gradio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "bigLX2Tz44Hb",
+ "outputId": "2312e18c-a787-4c79-a839-602f283ca4ce"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
+ "\n",
+ "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
+ "* Running on public URL: https://a03d3b6dae250fdf76.gradio.live\n",
+ "\n",
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co./spaces)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py:312: FutureWarning: `max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.\n",
+ " warnings.warn(\n",
+ "Traceback (most recent call last):\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/queueing.py\", line 625, in process_events\n",
+ " response = await route_utils.call_process_api(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
+ " output = await app.get_blocks().process_api(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 2045, in process_api\n",
+ " result = await self.call_function(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1592, in call_function\n",
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n",
+ " return await get_asynclib().run_sync_in_worker_thread(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n",
+ " return await future\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n",
+ " result = context.run(func, *args)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 870, in wrapper\n",
+ " response = f(*args, **kwargs)\n",
+ " File \"\", line 8, in speech_to_speech_translation\n",
+ " translated_text = translate(audio)\n",
+ " File \"\", line 2, in translate\n",
+ " outputs = pipe(audio, max_new_tokens=256, generate_kwargs={\"task\": \"translate\"})\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py\", line 283, in __call__\n",
+ " return super().__call__(inputs, **kwargs)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py\", line 1293, in __call__\n",
+ " return next(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/pt_utils.py\", line 124, in __next__\n",
+ " item = next(self.iterator)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/pt_utils.py\", line 269, in __next__\n",
+ " processed = self.infer(next(self.iterator), **self.params)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\", line 701, in __next__\n",
+ " data = self._next_data()\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\", line 757, in _next_data\n",
+ " data = self._dataset_fetcher.fetch(index) # may raise StopIteration\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py\", line 33, in fetch\n",
+ " data.append(next(self.dataset_iter))\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/pt_utils.py\", line 186, in __next__\n",
+ " processed = next(self.subiterator)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py\", line 412, in preprocess\n",
+ " raise TypeError(f\"We expect a numpy ndarray as input, got `{type(inputs)}`\")\n",
+ "TypeError: We expect a numpy ndarray as input, got ``\n",
+ "Traceback (most recent call last):\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/queueing.py\", line 625, in process_events\n",
+ " response = await route_utils.call_process_api(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/route_utils.py\", line 322, in call_process_api\n",
+ " output = await app.get_blocks().process_api(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 2045, in process_api\n",
+ " result = await self.call_function(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/blocks.py\", line 1592, in call_function\n",
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/anyio/to_thread.py\", line 33, in run_sync\n",
+ " return await get_asynclib().run_sync_in_worker_thread(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 877, in run_sync_in_worker_thread\n",
+ " return await future\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/anyio/_backends/_asyncio.py\", line 807, in run\n",
+ " result = context.run(func, *args)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/gradio/utils.py\", line 870, in wrapper\n",
+ " response = f(*args, **kwargs)\n",
+ " File \"\", line 8, in speech_to_speech_translation\n",
+ " translated_text = translate(audio)\n",
+ " File \"\", line 2, in translate\n",
+ " outputs = pipe(audio, max_new_tokens=256, generate_kwargs={\"task\": \"translate\"})\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py\", line 283, in __call__\n",
+ " return super().__call__(inputs, **kwargs)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py\", line 1293, in __call__\n",
+ " return next(\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/pt_utils.py\", line 124, in __next__\n",
+ " item = next(self.iterator)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/pt_utils.py\", line 269, in __next__\n",
+ " processed = self.infer(next(self.iterator), **self.params)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\", line 701, in __next__\n",
+ " data = self._next_data()\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py\", line 757, in _next_data\n",
+ " data = self._dataset_fetcher.fetch(index) # may raise StopIteration\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py\", line 33, in fetch\n",
+ " data.append(next(self.dataset_iter))\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/pt_utils.py\", line 186, in __next__\n",
+ " processed = next(self.subiterator)\n",
+ " File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py\", line 412, in preprocess\n",
+ " raise TypeError(f\"We expect a numpy ndarray as input, got `{type(inputs)}`\")\n",
+ "TypeError: We expect a numpy ndarray as input, got ``\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py:512: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.\n",
+ " warnings.warn(\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/automatic_speech_recognition.py:312: FutureWarning: `max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.\n",
+ " warnings.warn(\n",
+ "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/generation_whisper.py:512: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Keyboard interruption in main thread... closing server.\n",
+ "Killing tunnel 127.0.0.1:7860 <> https://a03d3b6dae250fdf76.gradio.live\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import gradio as gr\n",
+ "\n",
+ "demo = gr.Blocks()\n",
+ "\n",
+ "mic_translate = gr.Interface(\n",
+ " fn=speech_to_speech_translation,\n",
+ " inputs=gr.Audio(sources=\"microphone\", type=\"filepath\"),\n",
+ " outputs=gr.Audio(label=\"Generated Speech\", type=\"numpy\"),\n",
+ ")\n",
+ "\n",
+ "file_translate = gr.Interface(\n",
+ " fn=speech_to_speech_translation,\n",
+ " inputs=gr.Audio(sources=\"upload\", type=\"filepath\"),\n",
+ " outputs=gr.Audio(label=\"Generated Speech\", type=\"numpy\"),\n",
+ ")\n",
+ "\n",
+ "with demo:\n",
+ " gr.TabbedInterface([mic_translate, file_translate], [\"Microphone\", \"Audio File\"])\n",
+ "\n",
+ "demo.launch(debug=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CuoE14ma8RFc"
+ },
+ "source": [
+ "Final test we use a german audio input (Und seit zwei Monaten ist ein neuer Präsident im Amt.) , and get a english audio output (And a new president has been in office for two months.), it worked!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JNIwhWQ28dMO"
+ },
+ "source": [
+ "Summary:\n",
+ "\n",
+ "Build speech to speech translation function as following:\n",
+ "- tanslate audio to text (translated_text = translate(audio))\n",
+ "- translate text to target audio (synthesised_speech = synthesise(translated_text))\n",
+ "\n",
+ "to build synthesise:\n",
+ "- using model to drive\n",
+ "\n",
+ "then at last import gradio to visualize\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}