"""Gradio app that classifies the text quality of a Hugging Face Hub dataset.

The first rows of a dataset's auto-converted parquet export are run through
the ``nvidia/quality-classifier-deberta`` model and the distribution of the
predicted labels is shown as a bar plot.
"""

import gradio as gr
import pandas as pd
import polars as pl
import spaces
import torch
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


class QualityModel(nn.Module, PyTorchModelHubMixin):
    """Transformer encoder followed by dropout and a linear classification head."""

    def __init__(self, config):
        """Build the encoder and head from a mapping-style ``config``.

        Args:
            config: Mapping that provides ``"base_model"`` (checkpoint passed
                to ``AutoModel.from_pretrained``), ``"fc_dropout"`` (dropout
                probability) and ``"id2label"`` (its length sets the number
                of output classes).
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(config["base_model"])
        self.dropout = nn.Dropout(config["fc_dropout"])
        self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))

    def forward(self, input_ids, attention_mask):
        """Return per-class probabilities for each input sequence.

        The head is applied to every position of the encoder's last hidden
        state; only the first sequence position is kept and soft-maxed over
        the class dimension.
        """
        features = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        dropped = self.dropout(features)
        outputs = self.fc(dropped)
        return torch.softmax(outputs[:, 0, :], dim=1)


# Model, tokenizer and label mapping are loaded once at import time and shared
# by every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
model.eval()


@spaces.GPU
def predict(texts: list[str]) -> list[str]:
    """Classify ``texts`` and return one quality label per input text.

    Args:
        texts: Raw texts to classify; all of them are tokenized and run
            through the model as a single batch.

    Returns:
        The ``config.id2label`` entry of the highest-probability class for
        each text, in input order. An empty input yields an empty list.
    """
    if not texts:
        # Nothing to tokenize; avoid handing an empty batch to the tokenizer.
        return []

    inputs = tokenizer(
        texts, return_tensors="pt", padding="longest", truncation=True
    ).to(device)
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        probabilities = model(inputs["input_ids"], inputs["attention_mask"])
    predicted_classes = torch.argmax(probabilities, dim=1)
    return [config.id2label[class_idx] for class_idx in predicted_classes.tolist()]


def run_quality_check(dataset, column, n_samples):
    """Classify the first rows of a Hub dataset and plot the label counts.

    Args:
        dataset: Hub dataset id, e.g. ``"user/name"``.
        column: Name of the text column to read from the parquet file.
        n_samples: Number of leading rows to classify. ``None`` means every
            row of the file; floats are truncated to ``int``.

    Returns:
        A ``gr.BarPlot`` with one bar per predicted quality label.

    Raises:
        gr.Error: If ``dataset`` or ``column`` is empty, or ``n_samples`` is
            smaller than 1.
    """
    if not dataset or not column:
        raise gr.Error("Please provide both a dataset ID and a text column name.")

    if n_samples is None:
        n_rows = None
    else:
        # gr.Number may deliver a float, which cannot be used as a row count.
        n_rows = int(n_samples)
        if n_rows < 1:
            raise gr.Error("The number of samples must be at least 1.")

    # Named ``config_name`` so it does not shadow the module-level model config.
    config_name = "default"
    # Only the first shard of the train split is read, limited to ``n_rows``
    # rows so the rest of the shard is not needed.
    data = pl.read_parquet(
        f"hf://datasets/{dataset}@~parquet/{config_name}/train/0000.parquet",
        columns=[column],
        n_rows=n_rows,
    )
    texts = data[column].to_list()
    predictions = predict(texts)

    # Name both columns explicitly so the result does not depend on the
    # pandas version's default naming for ``value_counts``.
    counts = (
        pd.Series(predictions, name="quality")
        .value_counts()
        .rename_axis("quality")
        .reset_index(name="count")
    )
    return gr.BarPlot(counts, x="quality", y="count")


with gr.Blocks() as demo:
    gr.Markdown("# 💫 Dataset Quality Checker 💫")
    dataset_name = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Search for dataset id on Huggingface",
        search_type="dataset",
        value="fka/awesome-chatgpt-prompts",
    )
    # TODO: make the dataset config a user input; run_quality_check currently
    # hard-codes "default". ("HuggingFaceFW/fineweb" was an alternative
    # default dataset value.)

    @gr.render(inputs=dataset_name)
    def embed(name):
        """Re-render an HTML component whenever the selected dataset changes."""
        # NOTE(review): the markup is whitespace-only and ``name`` is unused;
        # presumably an embed snippet for the dataset belongs here - confirm.
        html_code = f""" """
        return gr.HTML(value=html_code)

    text_column = gr.Textbox(
        placeholder="text",
        label="Text column name to check (data must be non-nested, raw texts!)",
    )
    # precision=0 makes Gradio round the value and pass it on as an int.
    n_samples = gr.Number(label="Num first samples to run check", precision=0)
    gr_check_btn = gr.Button("Check Dataset")
    plot = gr.BarPlot()
    gr_check_btn.click(
        run_quality_check,
        inputs=[dataset_name, text_column, n_samples],
        outputs=[plot],
    )

demo.launch()