# Hugging Face Hub page metadata captured with this file (not part of the script):
# uploader: medmac01 — commit 3bd5293 "Added multilingual_clip module"
# (raw / history / blame view, 2.97 kB) — kept as a comment so the file parses.
import os
import pandas as pd
import torch
import argparse
import shutil
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorWithPadding
from torch.utils.data import Dataset, DataLoader
# Command-line interface: -f/--filename is the feather file to translate,
# and an optional positional folder that contains it (default: data_multi).
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--filename", type=str)
parser.add_argument(
    "data_folder",
    nargs="?",
    type=str,
    default="data_multi",
)
args = parser.parse_args()

# Use the GPU when available; seq2seq generation is slow on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs("data_translated", exist_ok=True)

if args.filename == "015_ccs_synthetic_en.feather":
    # No need to translate English -> English: pass the file through unchanged.
    shutil.copy2(os.path.join(args.data_folder, "015_ccs_synthetic_en.feather"), "data_translated")
    # Fix: raise SystemExit instead of os._exit(0) — os._exit terminates the
    # process immediately, skipping buffer flushes and atexit cleanup.
    raise SystemExit(0)

df = pd.read_feather(os.path.join(args.data_folder, args.filename))
# Strip the hub prefix so the column holds a bare model id for from_pretrained.
# regex=False: the pattern contains "." and "/" and must match literally
# (older pandas defaulted to regex=True, where "." matched any character).
df["opus_mt_url"] = df["opus_mt_url"].str.replace("https://huggingface.co./", "", regex=False)
print(f"Starting translation of English to {df['multi_language_name'][0]}")
class CaptionDataset(Dataset):
    """Torch dataset that tokenizes one caption per item.

    Wraps a dataframe with a "caption" column and the tokenizer matching the
    translation model, yielding tokenizer output dicts of shape-(1, seq) tensors.
    """

    def __init__(self, df, tokenizer_name):
        # Keep a reference to the caption frame and build the model's tokenizer.
        self.df = df
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # Bug fix: read from self.df rather than the module-level `df` global,
        # so the dataset uses whatever frame it was constructed with.
        sentence = self.df.loc[index, "caption"]
        tokens = self.tokenizer(sentence, return_tensors="pt")
        return tokens
# Instantiate the OPUS-MT checkpoint named in the dataframe and prepare it
# for inference on the selected device.
opus_model_id = df["opus_mt_url"][0]
tokenizer = AutoTokenizer.from_pretrained(opus_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(opus_model_id)
model.to(device)
model.eval()
def custom_collate_fn(data):
    """Collate single-sequence tokenizer outputs into one right-padded batch.

    Each sample holds (1, seq)-shaped "input_ids" and "attention_mask" tensors;
    the batch pads every sequence to the longest one with zeros.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    ids = pad([sample["input_ids"][0] for sample in data], batch_first=True)
    masks = pad([sample["attention_mask"][0] for sample in data], batch_first=True)
    return {"input_ids": ids, "attention_mask": masks}
if df["multi_target"][0] == 1:
    # Multilingual checkpoints need the target-language token ">>code<<"
    # prepended to every source sentence so the model emits the right language.
    # (Comment fixed: the code builds ">>code<<", not ">> CODE >>".)
    df["caption"] = ">>" + df["target_code"] + "<<" + df["caption"]

test_data = CaptionDataset(df, df["opus_mt_url"][0])
test_dataloader = DataLoader(
    test_data,
    batch_size=50,
    shuffle=False,  # keep row order so translations align with the dataframe
    num_workers=4,
    collate_fn=custom_collate_fn,
)

# Translate batch by batch; no gradients are needed for pure inference.
with torch.no_grad():
    decoded_tokens = []
    # Fix: dropped the unused enumerate() index from the loop.
    for batch in tqdm(test_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        output_tokens = model.generate(**batch)
        # Move ids back to CPU before decoding and strip pad/eos tokens.
        decoded_tokens += tokenizer.batch_decode(output_tokens.to("cpu"), skip_special_tokens=True)

df["caption_multi"] = decoded_tokens
df.to_feather(os.path.join("data_translated", args.filename))
print(f"Finished translating English to {df['multi_language_name'][0]}")