"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import math
import random
import re
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from nltk import edit_distance
from pytorch_lightning.utilities import rank_zero_only
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from donut import DonutConfig, DonutModel

class DonutModelPLModule(pl.LightningModule):
    """LightningModule that trains a DonutModel and scores it by normalized edit distance.

    ``config`` is read both by attribute (``config.input_size``) and with
    ``config.get(key, default)``; presumably a Sconf/OmegaConf-style object — TODO confirm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        if self.config.get("pretrained_model_name_or_path", False):
            # ignore_mismatched_sizes=True asks from_pretrained to tolerate weight-shape
            # differences from the requested input_size/max_length (HF semantics —
            # confirm for DonutModel).
            self.model = DonutModel.from_pretrained(
                self.config.pretrained_model_name_or_path,
                input_size=self.config.input_size,
                max_length=self.config.max_length,
                align_long_axis=self.config.align_long_axis,
                ignore_mismatched_sizes=True,
            )
        else:
            self.model = DonutModel(
                config=DonutConfig(
                    input_size=self.config.input_size,
                    max_length=self.config.max_length,
                    align_long_axis=self.config.align_long_axis,
                    # with DonutConfig, the architecture customization is available, e.g.,
                    # encoder_layer=[2,2,14,2], decoder_layer=4, ...
                )
            )

        # Parse the full major version ("1.9.5" -> 1) rather than only its first character.
        self.pytorch_lightning_version_is_1 = int(pl.__version__.split(".")[0]) < 2
        # One validation-output bucket per configured dataset.
        self.num_of_loaders = len(self.config.dataset_name_or_paths)

    def training_step(self, batch, batch_idx):
        """Concatenate the per-loader batches and return the model's training loss.

        ``batch`` is iterated as one entry per training loader, each indexable as
        (images, input_ids, labels) — inferred from the indexing below.
        """
        image_tensors, decoder_input_ids, decoder_labels = [], [], []
        for batch_data in batch:
            image_tensors.append(batch_data[0])
            # Teacher forcing: inputs drop the last token, labels drop the first.
            decoder_input_ids.append(batch_data[1][:, :-1])
            decoder_labels.append(batch_data[2][:, 1:])
        image_tensors = torch.cat(image_tensors)
        decoder_input_ids = torch.cat(decoder_input_ids)
        decoder_labels = torch.cat(decoder_labels)

        loss = self.model(image_tensors, decoder_input_ids, decoder_labels)[0]
        self.log_dict({"train_loss": loss}, sync_dist=True)
        if not self.pytorch_lightning_version_is_1:
            self.log("loss", loss, prog_bar=True)
        return loss

    def on_validation_epoch_start(self) -> None:
        """Reset the per-loader buckets that validation_step appends scores to."""
        super().on_validation_epoch_start()
        self.validation_step_outputs = [[] for _ in range(self.num_of_loaders)]
        return

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run prompted inference and return per-sample normalized edit distances.

        Each score is ``edit_distance(pred, answer) / max(len(pred), len(answer))``;
        a pair where both strings are empty scores 0.0 (exact match) instead of
        dividing by zero.
        """
        image_tensors, decoder_input_ids, prompt_end_idxs, answers = batch
        # Keep only each sample's prompt (tokens up to and including prompt_end_idx).
        decoder_prompts = pad_sequence(
            [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
            batch_first=True,
        )

        preds = self.model.inference(
            image_tensors=image_tensors,
            prompt_tensors=decoder_prompts,
            return_json=False,
            return_attentions=False,
        )["predictions"]

        scores = []
        for pred, answer in zip(preds, answers):
            # Drop a single space directly after ">" or directly before "</s_".
            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
            # Strip the first <...> tag and the EOS token from the reference.
            answer = re.sub(r"<.*?>", "", answer, count=1)
            answer = answer.replace(self.model.decoder.tokenizer.eos_token, "")

            longest = max(len(pred), len(answer))
            scores.append(edit_distance(pred, answer) / longest if longest else 0.0)

            if self.config.get("verbose", False) and len(scores) == 1:
                self.print(f"Prediction: {pred}")
                self.print(f" Answer: {answer}")
                self.print(f" Normed ED: {scores[0]}")

        self.validation_step_outputs[dataloader_idx].append(scores)
        return scores

    def on_validation_epoch_end(self):
        """Log the mean normalized edit distance per loader and across all loaders.

        Loaders that produced no samples are skipped rather than dividing 0 by 0.
        NOTE(review): with sync_dist=True every rank must log the same keys; this
        holds as long as ranks agree on which loaders are empty.
        """
        assert len(self.validation_step_outputs) == self.num_of_loaders
        cnt = [0] * self.num_of_loaders
        total_metric = [0] * self.num_of_loaders

        for i, results in enumerate(self.validation_step_outputs):
            for scores in results:
                cnt[i] += len(scores)
                total_metric[i] += np.sum(scores)
            if cnt[i] == 0:
                continue
            val_metric_name = f"val_metric_{i}th_dataset"
            self.log_dict({val_metric_name: total_metric[i] / cnt[i]}, sync_dist=True)

        if sum(cnt) > 0:
            self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)

    def configure_optimizers(self):
        """Build Adam plus a per-step warmup+cosine schedule sized from max_epochs/max_steps.

        When both are set, the smaller resulting step count wins.
        """
        max_iter = None

        if int(self.config.get("max_epochs", -1)) > 0:
            assert len(self.config.train_batch_sizes) == 1, "Set max_epochs only if the number of datasets is 1"
            # Count at least one device so CPU-only hosts (device_count() == 0) don't divide by zero.
            num_devices = max(1, torch.cuda.device_count())
            max_iter = (self.config.max_epochs * self.config.num_training_samples_per_epoch) / (
                self.config.train_batch_sizes[0] * num_devices * self.config.get("num_nodes", 1)
            )

        if int(self.config.get("max_steps", -1)) > 0:
            max_iter = min(self.config.max_steps, max_iter) if max_iter is not None else self.config.max_steps

        assert max_iter is not None
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr)
        scheduler = {
            "scheduler": self.cosine_scheduler(optimizer, max_iter, self.config.warmup_steps),
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    @staticmethod
    def cosine_scheduler(optimizer, training_steps, warmup_steps):
        """Return a LambdaLR doing linear warmup, then cosine decay toward 0 at training_steps.

        A static method so that ``self.cosine_scheduler(optimizer, ...)`` does not
        bind ``self`` into the ``optimizer`` slot.
        NOTE(review): past training_steps the cosine factor rises again; the
        max(0.0, ...) floor does not prevent that.
        """

        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return current_step / max(1, warmup_steps)
            progress = current_step - warmup_steps
            progress /= max(1, training_steps - warmup_steps)
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return LambdaLR(optimizer, lr_lambda)

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        """Export model and tokenizer via save_pretrained to result_path/exp_name/exp_version.

        Runs on rank 0 only, so distributed ranks do not write the same directory concurrently.
        """
        save_path = Path(self.config.result_path) / self.config.exp_name / self.config.exp_version
        self.model.save_pretrained(save_path)
        self.model.decoder.tokenizer.save_pretrained(save_path)


class DonutDataPLModule(pl.LightningDataModule):
    """LightningDataModule that builds one DataLoader per train/val dataset.

    ``train_datasets`` and ``val_datasets`` start empty here; presumably the caller
    appends datasets before the loaders are requested — TODO confirm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.train_batch_sizes = self.config.train_batch_sizes
        self.val_batch_sizes = self.config.val_batch_sizes
        self.train_datasets = []
        self.val_datasets = []
        # Seeded generator shared by all train loaders so shuffling order is reproducible.
        self.g = torch.Generator()
        self.g.manual_seed(self.config.seed)

    def train_dataloader(self):
        """Return one shuffled DataLoader per (train dataset, batch size) pair.

        NOTE(review): zip() silently drops extras if the two lists differ in length.
        """
        loaders = []
        for train_dataset, batch_size in zip(self.train_datasets, self.train_batch_sizes):
            loaders.append(
                DataLoader(
                    train_dataset,
                    batch_size=batch_size,
                    num_workers=self.config.num_workers,
                    pin_memory=True,
                    worker_init_fn=self.seed_worker,
                    generator=self.g,
                    shuffle=True,
                )
            )
        return loaders

    def val_dataloader(self):
        """Return one unshuffled DataLoader per (val dataset, batch size) pair."""
        loaders = []
        for val_dataset, batch_size in zip(self.val_datasets, self.val_batch_sizes):
            loaders.append(
                DataLoader(
                    val_dataset,
                    batch_size=batch_size,
                    pin_memory=True,
                    shuffle=False,
                )
            )
        return loaders

    @staticmethod
    def seed_worker(wordker_id):
        """Seed numpy and ``random`` in a DataLoader worker from torch's per-worker seed.

        Must be a static method: it is passed as ``worker_init_fn=self.seed_worker``
        and the DataLoader calls it with the worker id only. The argument is unused;
        its misspelled name is kept for interface compatibility.
        """
        # numpy seeds must fit in 32 bits.
        worker_seed = torch.initial_seed() % 2 ** 32
        np.random.seed(worker_seed)
        random.seed(worker_seed)