import torch
import torch.nn as nn
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

class TextRefinementModel(nn.Module):
    def __init__(self, model_name='tirthadagr8/custom-mbart-large-50', max_length=64):
        super(TextRefinementModel, self).__init__()
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        self.mbart = MBartForConditionalGeneration.from_pretrained(model_name)
        self.mbart.config.max_length = max_length  # keep generation length in sync with the constructor argument
        self.max_length = max_length
        
        # Optionally set the tokenizer's source language code:
        # self.tokenizer.src_lang = 'ja_XX'  # for Japanese
        # self.tokenizer.src_lang = 'zh_CN'  # for Chinese

    def forward(self, input_texts):
        # Tokenize the noisy text inputs and move them to the model's device
        input_ids = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                   truncation=True, max_length=self.max_length)['input_ids'].to(self.mbart.device)

        # With no labels provided, mBART builds its decoder inputs by shifting input_ids,
        # so this returns per-token logits over the vocabulary
        output_logits = self.mbart(input_ids=input_ids).logits

        return output_logits

    def generate_corrected_text(self, input_texts, temperature=0.7):
        # Tokenize the noisy input text and move it to the model's device
        input_ids = self.tokenizer(input_texts, return_tensors='pt', padding=True,
                                   truncation=True, max_length=self.max_length)['input_ids'].to(self.mbart.device)

        # Generate corrected text; sampling must be enabled for temperature to have an effect
        mbart_outputs = self.mbart.generate(input_ids, max_length=self.max_length,
                                            do_sample=True, temperature=temperature,
                                            num_return_sequences=1)

        # Decode the generated token IDs back into strings
        corrected_texts = [self.tokenizer.decode(g, skip_special_tokens=True) for g in mbart_outputs]
        return corrected_texts

# Example usage
model = TextRefinementModel()

noisy_text = ["ใ“ใ‚Œใฏ้–“้•ใฃใŸใƒ†ใ‚ญใ‚นใƒˆใฎไพ‹ใงใ™ใ€‚", "่ฟ™ๆ˜ฏ้”™่ฏฏ็š„ๆ–‡ๆœฌ็คบไพ‹ใ€‚"]  # Japanese and Chinese examples
corrected_text = model.generate_corrected_text(noisy_text)

print(f"Corrected Text: {corrected_text}")

For training, fine-tune the wrapped mBART on pairs of noisy and corrected text:

from torch.optim import AdamW  # transformers.AdamW is deprecated; use the PyTorch optimizer
from tqdm import tqdm
from torch.utils.data import DataLoader
import numpy as np

# Initialize the mBART model and optimizer
model = TextRefinementModel().cuda()
optimizer = AdamW(model.parameters(), lr=5e-5)

batch_size = 16

# Create a custom dataset class
class TextCorrectionDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        noisy_text, correct_text = self.data[idx]
        inputs = self.tokenizer(noisy_text, return_tensors='pt', padding='max_length',
                                truncation=True, max_length=self.max_length)
        targets = self.tokenizer(correct_text, return_tensors='pt', padding='max_length',
                                 truncation=True, max_length=self.max_length)

        input_ids = inputs['input_ids'].squeeze(0)   # drop the extra batch dimension
        labels = targets['input_ids'].squeeze(0)
        # Mask padding positions so they are ignored by the loss
        labels[labels == self.tokenizer.pad_token_id] = -100
        return input_ids, labels
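# NOTE (assumption): train_data is not defined in this card. The dataset class above
# expects it to be a list of (noisy_text, correct_text) string pairs, for example:
# train_data = [("<noisy OCR output>", "<manually corrected text>"), ...]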

# Create DataLoader with batching
train_dataset = TextCorrectionDataset(train_data, model.tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define training loop with batches
def train_epoch(model, train_loader, optimizer):
    model.train()
    total_loss = []
    step_iter = 0
    for input_ids, labels in tqdm(train_loader):
        # Move tensors to model's device
        input_ids = input_ids.to(model.mbart.device)
        labels = labels.to(model.mbart.device)
        
        # Forward pass
        outputs = model.mbart(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss.append(loss.item())
        
        if step_iter % 100 == 0:
            print('Loss:', np.mean(total_loss))

        step_iter += 1
    return np.mean(total_loss)

# Example training loop
for epoch in range(5):  # Train for 5 epochs (or as needed)
    loss = train_epoch(model, train_loader, optimizer)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")