"""Gradio app for reviewing and correcting emotion labels on game speech clips.

Clips are streamed from the Hugging Face dataset named by ``HF_PATH_TO_DATASET``
and shown one at a time. The annotator either confirms the existing label or
picks another one; each decision is sent to the ``litagin/ser_record`` Space
through its ``/put_data`` endpoint.
"""

import json
import os
import warnings

import gradio as gr
import librosa
import numpy as np
from datasets import IterableDatasetDict, load_dataset
from gradio_client import Client
from loguru import logger

warnings.filterwarnings("ignore")

# Number of tar shards listed in ``data_files`` when loading the dataset.
NUM_TAR_FILES = 115
HF_PATH_TO_DATASET = "litagin/Galgame_Speech_SER_16kHz"

hf_token = os.getenv("HF_TOKEN")
# Client for the Space that stores the submitted labels.
client = Client("litagin/ser_record", hf_token=hf_token)

id2label = {
    0: "Angry",
    1: "Disgusted",
    2: "Embarrassed",
    3: "Fearful",
    4: "Happy",
    5: "Sad",
    6: "Surprised",
    7: "Neutral",
    8: "Sexual1",
    9: "Sexual2",
}

# Button captions / displayed labels; the number in parentheses is the
# keyboard shortcut handled by ``shortcut_js``.
id2rich_label = {
    0: "😠 æ€’ă‚Š (0)",
    1: "😒 ć«Œæ‚Ș (1)",
    2: "😳 æ„ăšă‹ă—ă•ăƒ»æˆžæƒ‘ă„ (2)",
    3: "😹 恐怖 (3)",
    4: "😊 ćčžă› (4)",
    5: "😱 æ‚Čしみ (5)",
    6: "đŸ˜Č 驚き (6)",
    7: "😐 äž­ç«‹ (7)",
    8: "đŸ„° NSFW1 (8)",
    9: "🍭 NSFW2 (9)",
}

# The item currently on screen. NOTE(review): this is a module-level global,
# so it is shared by every session served by this process.
current_item: dict | None = None


def _load_dataset(
    *,
    streaming: bool = True,
    use_local_dataset: bool = False,
    local_dataset_path: str | None = None,
    data_dir: str = "data",
) -> IterableDatasetDict:
    """Load the speech dataset and normalize its column names.

    Args:
        streaming: Passed through to ``load_dataset``.
        use_local_dataset: Load from ``local_dataset_path`` instead of
            ``HF_PATH_TO_DATASET``.
        local_dataset_path: Dataset location; required when
            ``use_local_dataset`` is true.
        data_dir: Passed through to ``load_dataset``.

    Returns:
        The dataset with the ``__url__`` column removed and the ``ogg``
        column renamed to ``audio``.

    Raises:
        ValueError: If ``use_local_dataset`` is true but no path was given.
    """
    data_files = {
        "train": [
            f"galgame-speech-ser-16kHz-train-000{index:03d}.tar"
            for index in range(NUM_TAR_FILES)
        ],
    }

    if use_local_dataset:
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if local_dataset_path is None:
            raise ValueError(
                "local_dataset_path must be provided when use_local_dataset=True"
            )
        path = local_dataset_path
    else:
        path = HF_PATH_TO_DATASET

    dataset: IterableDatasetDict = load_dataset(
        path=path, data_dir=data_dir, data_files=data_files, streaming=streaming
    )  # type: ignore
    dataset = dataset.remove_columns(["__url__"])
    dataset = dataset.rename_column("ogg", "audio")
    return dataset


logger.info("Start loading dataset")
ds = _load_dataset(streaming=True, use_local_dataset=False)
logger.info("Dataset loaded")

# NOTE: the stream is read in dataset order. A seeded shuffle
# (``ds["train"].shuffle(seed=...)``) was previously sketched here but is
# disabled; enabling it would also require importing ``random``.
ds_iter = iter(ds["train"])

# Keyboard shortcuts advertised in the UI: digits 0-9 click the matching
# ``btn_<digit>`` button and Enter clicks ``btn_skip``. Injected through
# ``gr.Blocks(head=...)``.
shortcut_js = """
<script>
function shortcuts(e) {
    // Leave typing in form fields alone; ignore modified or auto-repeated keys.
    const tag = (e.target.tagName || "").toLowerCase();
    if (tag === "input" || tag === "textarea") return;
    if (e.ctrlKey || e.altKey || e.metaKey || e.repeat) return;

    let targetId = null;
    if (e.key === "Enter") {
        targetId = "btn_skip";
    } else if (e.key.length === 1 && e.key >= "0" && e.key <= "9") {
        targetId = "btn_" + e.key;
    }
    if (targetId === null) return;

    const button = document.getElementById(targetId);
    if (button === null) return;
    // Keep Enter from also activating whichever button currently has focus.
    e.preventDefault();
    button.click();
}
document.addEventListener("keydown", shortcuts, false);
</script>
"""


def modify_speed(
    data: tuple[int, np.ndarray], speed: float = 1.0
) -> tuple[int, np.ndarray]:
    """Return ``data`` time-stretched by ``speed``.

    Args:
        data: ``(sampling_rate, samples)`` pair.
        speed: Stretch rate handed to ``librosa.effects.time_stretch``.
            A value of exactly ``1.0`` returns ``data`` untouched.

    Returns:
        ``(sampling_rate, stretched_samples)``; the sampling rate is unchanged.
    """
    if speed == 1.0:
        return data
    sr, array = data
    return sr, librosa.effects.time_stretch(array, rate=speed)


def parse_item(item, speed: float = 1.0) -> dict:
    """Flatten one raw dataset row into the fields the UI displays.

    Args:
        item: Dataset row providing ``cls``, ``audio`` (with ``sampling_rate``
            and ``array``), ``txt`` and ``__key__``.
        speed: Unused; kept so existing callers keep working.

    Returns:
        Dict with ``key``, ``audio`` as ``(sampling_rate, array)``, ``text``,
        ``label`` (display string) and ``label_id``.
    """
    label_id = item["cls"]
    sampling_rate = item["audio"]["sampling_rate"]
    array = item["audio"]["array"]
    return {
        "key": item["__key__"],
        "audio": (sampling_rate, array),
        "text": item["txt"],
        "label": id2rich_label[label_id],
        "label_id": label_id,
    }


def get_next_parsed_item(speed: float = 1.0) -> dict:
    """Pull the next row from the dataset stream and parse it.

    Args:
        speed: Forwarded to ``parse_item`` (which does not use it).

    Returns:
        The parsed item, as produced by ``parse_item``.

    Raises:
        gr.Error: If the dataset stream has no more items.
    """
    logger.info("Getting next item")
    try:
        next_item = next(ds_iter)
    except StopIteration:
        # A bare StopIteration must not leak out of an event handler; turn
        # it into a message the UI can show ("reached the end of the dataset").
        logger.warning("Dataset iterator exhausted")
        raise gr.Error("ăƒ‡ăƒŒă‚żă‚»ăƒƒăƒˆăźæœ«ć°Ÿă«é”ă—ăŸă—ăŸă€‚") from None
    parsed = parse_item(next_item, speed=speed)
    logger.info(
        f"Next item:\nkey={parsed['key']}\ntext={parsed['text']}\nlabel={parsed['label']}"
    )
    return parsed


md = """
# èȘŹæ˜Ž

- こぼケプăƒȘは、ă‚ČăƒŒăƒ ăźă‚»ăƒȘăƒ•ă‚’æ„Ÿæƒ…ăƒ©ăƒ™ăƒ«ä»˜ă‘ă—ăŠă€ć€§èŠæšĄăȘæ„Ÿæƒ…éŸłćŁ°ăƒ‡ăƒŒă‚żă‚»ăƒƒăƒˆă‚’äœœæˆă™ă‚‹ăŸă‚ăźă‚‚ăźă§ă™
- **性的ăȘéŸłćŁ°ăŒć«ăŸă‚Œă‚‹ăŸă‚ă€18æ­łæœȘæș€ăźæ–čăŻă”ćˆ©ç”šă‚’ăŠæŽ§ăˆăă ă•ă„**
- æ—ąć­˜ăźăƒ©ăƒ™ăƒ«ăŒé©ćˆ‡ă§ă‚ă‚Œă°ă€ăăźăŸăŸă€ŒçŸćœšăźæ„Ÿæƒ…ăƒ©ăƒ™ăƒ«ă§é©ćˆ‡ă€ăƒœă‚żăƒłă‚’æŠŒă—ăŠăă ă•ă„
- ăƒ©ăƒ™ăƒ«ă‚’äżźæ­Łă™ă‚‹ć ŽćˆăŻă€é©ćˆ‡ăȘăƒœă‚żăƒłă‚’æŠŒă—ăŠăă ă•ă„
- ă‚·ăƒ§ăƒŒăƒˆă‚«ăƒƒăƒˆă‚­ăƒŒïŒˆă‚«ăƒƒă‚łć†…ïŒ‰ă‚’äœżă†ă“ăšă‚‚ă§ăăŸă™

# èŁœè¶ł

- `đŸ„° NSFW1` ăŻć„łæ€§ăźæ€§çš„èĄŒç‚șäž­ăźéŸłćŁ°ïŒˆć–˜ăŽćŁ°ç­‰ïŒ‰
- `🍭 NSFW2` はキă‚čă‚·ăƒŒăƒłă§ăźăƒȘăƒƒăƒ—éŸłă‚„ăƒ•ă‚§ăƒ©ă‚·ăƒŒăƒłă§ăźă—ă‚ƒă¶ă‚‹éŸłïŒˆăƒăƒ„ăƒ‘éŸłïŒ‰ă‚’èĄšă—ăŸă™
- æ„Ÿæƒ…ăŒéŸłćŁ°ă‹ă‚‰ăŻç‰čにèȘ­ăżć–ă‚ŒăȘい栮搈は `😐 äž­ç«‹` ă‚’éžæŠžă—ăŠăă ă•ă„
"""

with gr.Blocks(head=shortcut_js) as app:
    gr.Markdown(md)
    with gr.Row():
        with gr.Column():
            btn_init = gr.Button("ćˆæœŸćŒ–ăƒ»ć†èȘ­ăżèŸŒăż")
            speed = gr.Slider(
                minimum=0.5, maximum=5.0, step=0.1, value=1.0, label="憍生速ćșŠ"
            )
        with gr.Column(variant="panel"):
            key = gr.Textbox(label="Key")
            audio = gr.Audio()
            text = gr.Textbox(label="Text")
            label = gr.Textbox(label="æ„Ÿæƒ…ăƒ©ăƒ™ăƒ«")
            # Hidden field carrying the numeric label of the displayed item.
            label_id = gr.Textbox(visible=False)
            btn_skip = gr.Button("çŸćœšăźæ„Ÿæƒ…ăƒ©ăƒ™ăƒ«ă§é©ćˆ‡ (Enter)", elem_id="btn_skip")
        with gr.Column():
            gr.Markdown("# æ„Ÿæƒ…ăƒ©ăƒ™ăƒ«ă‚’äżźæ­Łă™ă‚‹ć Žćˆ")
            # One button per label; ``elem_id`` is what ``shortcut_js`` clicks.
            btn_list = [
                gr.Button(rich_label, elem_id=f"btn_{_id}")
                for _id, rich_label in id2rich_label.items()
            ]

    # Components refreshed by every handler below.
    item_outputs = [key, audio, text, label, label_id]

    def update_current_item(data: dict) -> dict:
        """Render the current item, fetching the first one if none is loaded.

        Args:
            data: Component-keyed input dict; must contain ``speed``.

        Returns:
            Component-keyed updates for ``key``, ``audio``, ``text``,
            ``label`` and ``label_id``.
        """
        global current_item
        speed_value = data[speed]
        if current_item is None:
            current_item = get_next_parsed_item(speed=speed_value)
        modified_audio = modify_speed(current_item["audio"], speed=speed_value)
        return {
            key: current_item["key"],
            audio: gr.Audio(modified_audio, autoplay=True),
            text: current_item["text"],
            label: current_item["label"],
            label_id: current_item["label_id"],
        }

    def set_next_item(data: dict) -> dict:
        """Advance to the next dataset item and render it."""
        global current_item
        current_item = get_next_parsed_item(speed=data[speed])
        return update_current_item(data)

    def _require_key(data: dict) -> str:
        """Return the displayed item key, or raise if nothing is loaded yet."""
        current_key = data[key]
        if not current_key:
            # Without this guard a click (or shortcut key) before the first
            # load would submit an empty key to the recording service.
            # Message: "Press the initialize/reload button first."
            raise gr.Error("ć…ˆă«ă€ŒćˆæœŸćŒ–ăƒ»ć†èȘ­ăżèŸŒăżă€ăƒœă‚żăƒłă‚’æŠŒă—ăŠăă ă•ă„ă€‚")
        return current_key

    def _send_label(item_key: str, cls: int) -> None:
        """Send one ``{"key", "cls"}`` record to the ``/put_data`` endpoint."""
        client.predict(
            new_data=json.dumps(
                {
                    "key": item_key,
                    "cls": cls,
                }
            ),
            api_name="/put_data",
        )

    def put_unmodified(data: dict) -> dict:
        """Submit the displayed item with its existing label, then advance.

        Args:
            data: Component-keyed input dict with ``key``, ``label_id`` and
                ``speed``.

        Raises:
            gr.Error: If no item has been loaded yet.
        """
        logger.info("Putting unmodified")
        current_key = _require_key(data)
        # ``label_id`` comes back from a Textbox, hence the int() conversion.
        _send_label(current_key, int(data[label_id]))
        logger.info("Unmodified sent")
        return set_next_item(data)

    def _make_put_label(_id: int):
        """Build the click handler that relabels the displayed item as ``_id``."""

        def put_label(data: dict) -> dict:
            logger.info(f"Putting label: {id2rich_label[_id]}")
            current_key = _require_key(data)
            _send_label(current_key, _id)
            logger.info("Modified sent")
            return set_next_item(data)

        return put_label

    btn_init.click(update_current_item, inputs={speed}, outputs=item_outputs)
    btn_skip.click(
        put_unmodified,
        inputs={key, label_id, speed},
        outputs=item_outputs,
    )

    # One handler per label, in the same order as ``btn_list``.
    functions_list = [_make_put_label(_id) for _id in id2rich_label]
    for button, handler in zip(btn_list, functions_list):
        button.click(
            handler,
            inputs={key, speed},
            outputs=item_outputs,
        )

app.launch()