Fine-tuned Longformer for Summarization of Machine Learning Articles

Model Details

  • GitHub: https://github.com/Bakhitovd/led-base-7168-ml
  • Model name: bakhitovd/led-base-7168-ml
  • Model type: Longformer Encoder-Decoder (LED), fine-tuned from allenai/led-base-16384
  • Model description: This Longformer model has been fine-tuned on a focused subset of the arXiv portion of the 'Scientific papers' dataset, specifically targeting articles about machine learning. It aims to generate accurate and consistent summaries of machine learning research papers.

Intended Use

This model is intended for text summarization, specifically of machine learning research papers.

How to Use

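import torch
from transformers import LEDTokenizer, LEDForConditionalGeneration

# Load the fine-tuned tokenizer and model from the Hugging Face Hub
tokenizer = LEDTokenizer.from_pretrained("bakhitovd/led-base-7168-ml")
model = LEDForConditionalGeneration.from_pretrained("bakhitovd/led-base-7168-ml")

# Run on a GPU when one is available; inputs and model must be on the same device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

Use the model for summarization

article = "... long document ..."
# Calling the tokenizer returns a dict with input_ids and attention_mask (encode() returns only the ids)
inputs_dict = tokenizer(article, padding="max_length", max_length=16384, return_tensors="pt", truncation=True)
input_ids = inputs_dict.input_ids.to(device)
attention_mask = inputs_dict.attention_mask.to(device)
# Give the first token (<s>) global attention; all other tokens use local windowed attention
global_attention_mask = torch.zeros_like(attention_mask)
global_attention_mask[:, 0] = 1
predicted_abstract_ids = model.generate(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, max_length=512)
# generate() returns a batch of sequences; decode the first (and only) one
summary = tokenizer.decode(predicted_abstract_ids[0], skip_special_tokens=True)
print(summary)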
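The tokenizer call and the global attention mask in the usage example are hard to picture for a 16,384-token input. The plain-Python sketch below reproduces the three arrays (token ids, attention mask, global attention mask) on a toy sequence. The helper name `prepare_led_inputs` and the token ids are hypothetical; pad id 1 matches the BART-style vocabulary LED uses, but treat it as an assumption. Unlike the real `LEDTokenizer`, this simplified version does not re-append the end-of-sequence token when it truncates.

```python
def prepare_led_inputs(token_ids, max_length, pad_id):
    """Hypothetical helper: truncate/pad token ids and build the masks LED expects.

    Assumes max_length >= 1. Simplification: truncation just cuts the list and
    does not re-append </s> the way the real tokenizer does.
    """
    ids = token_ids[:max_length]                    # truncation=True
    attention_mask = [1] * len(ids)                 # 1 = real token
    n_pad = max_length - len(ids)
    ids = ids + [pad_id] * n_pad                    # padding="max_length"
    attention_mask = attention_mask + [0] * n_pad   # 0 = padding, ignored by attention
    global_attention_mask = [0] * max_length
    global_attention_mask[0] = 1                    # only the first token attends globally
    return ids, attention_mask, global_attention_mask


# Toy input with illustrative ids (0 = <s>, 2 = </s>, 1 = <pad>)
ids, attn, glob = prepare_led_inputs([0, 713, 16, 10, 1296, 2], max_length=8, pad_id=1)
print(ids)   # [0, 713, 16, 10, 1296, 2, 1, 1]
print(attn)  # [1, 1, 1, 1, 1, 1, 0, 0]
print(glob)  # [1, 0, 0, 0, 0, 0, 0, 0]
```

With real inputs the same pattern holds at full scale: padding positions are masked out, and the single global token lets information from the whole document reach every position through it.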

Training Data

Dataset name: bakhitovd/data_science_arxiv
This dataset is a subset of the arXiv portion of the 'Scientific papers' dataset, containing the articles that are semantically and structurally closest to articles describing machine learning. The subset was obtained by K-means clustering of embeddings generated by SciBERT.
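The card does not specify the number of clusters or how the machine learning cluster was chosen, so the following is only a minimal sketch of the idea under stated assumptions. Toy 2-D vectors stand in for SciBERT document embeddings, plain Lloyd's K-means groups them, and the cluster whose centroid has the highest cosine similarity to a hypothetical "machine learning" reference vector is kept. The function names and the selection rule are illustrative, not the author's pipeline.

```python
import math


def kmeans(points, k, iters=20):
    """Plain Lloyd's K-means; centroids start at the first k points for determinism."""
    centroids = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each point joins its nearest centroid (squared Euclidean distance)
        labels = [
            min(range(k), key=lambda c: sum((a - b) ** 2 for a, b in zip(p, centroids[c])))
            for p in points
        ]
        # Update step: move each centroid to the mean of its assigned points
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return labels, centroids


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def select_closest_cluster(embeddings, reference, k):
    """Hypothetical selection rule: keep the cluster whose centroid is most similar to `reference`."""
    labels, centroids = kmeans(embeddings, k)
    best = max(range(k), key=lambda c: cosine(centroids[c], reference))
    return [i for i, lab in enumerate(labels) if lab == best]


# Toy "embeddings": even indices lean towards the ML direction (1, 0), odd ones do not
docs = [(0.9, 0.1), (0.1, 0.9), (1.0, 0.2), (0.2, 1.0), (0.95, 0.05), (0.05, 0.95)]
ml_subset = select_closest_cluster(docs, reference=[1.0, 0.0], k=2)
print(ml_subset)  # [0, 2, 4]
```

Real SciBERT embeddings are much higher-dimensional, and a production pipeline would more likely use a library implementation with multiple random restarts; the deterministic initialisation here only keeps the toy example reproducible.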

Evaluation Results

The model was evaluated with ROUGE metrics and outperformed the baseline models.
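The card does not say which ROUGE variants or implementation were used. As a hedged illustration of what these metrics measure, the sketch below computes ROUGE-N (clipped n-gram overlap) and ROUGE-L (longest common subsequence) F1 scores from scratch. It uses lowercased whitespace tokenization and no stemming, so its numbers will differ slightly from reference implementations such as the `rouge_score` package.

```python
from collections import Counter


def _ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def _f1(overlap, cand_total, ref_total):
    if overlap == 0:
        return 0.0
    precision = overlap / cand_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def rouge_n(candidate, reference, n):
    """ROUGE-N F1: n-gram overlap, with each n-gram counted at most as often as in the other text."""
    cand = _ngrams(candidate.lower().split(), n)
    ref = _ngrams(reference.lower().split(), n)
    overlap = sum((cand & ref).values())  # Counter '&' keeps the minimum count per n-gram
    return _f1(overlap, sum(cand.values()), sum(ref.values()))


def rouge_l(candidate, reference):
    """ROUGE-L F1 based on the longest common subsequence of tokens."""
    c, r = candidate.lower().split(), reference.lower().split()
    # dp[i][j] = LCS length of c[:i] and r[:j]
    dp = [[0] * (len(r) + 1) for _ in range(len(c) + 1)]
    for i in range(1, len(c) + 1):
        for j in range(1, len(r) + 1):
            if c[i - 1] == r[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return _f1(dp[-1][-1], len(c), len(r))


generated = "the model summarizes machine learning papers"
gold = "the model summarizes long machine learning articles"
print(round(rouge_n(generated, gold, 1), 3))  # 0.769
print(round(rouge_n(generated, gold, 2), 3))  # 0.545
print(round(rouge_l(generated, gold), 3))     # 0.769
```

ROUGE-1 and ROUGE-L coincide here because the five shared words also appear in the same order in both texts; ROUGE-2 is lower because the inserted word "long" breaks the bigram "summarizes machine".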

