# import gc
# import os
# from math import exp
# from typing import List, Union
# import torch
# import transformers
# os.environ["OMP_NUM_THREADS"] = "1"
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# class PerplexityCalculator:
#     """
#     Calculates perplexity of text using a pre-trained language model.
#     Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
#     Parameters
#     ----------
#     model_path : str
#         Path to the pre-trained language model
#     load_in_8bit : bool, default=False
#         Use 8-bit quantization for the model. Requires CUDA.
#     device_map : str, default="auto"
#         Device mapping for the model.
#     """
#     def __init__(
#         self,
#         model_path: str,
#         load_in_8bit: bool = False,
#         device_map: str = "auto",
#         dtype: torch.dtype = torch.float16,
#     ):
#         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
#             model_path, padding_side="right"
#         )
#         # Configure model loading based on quantization setting and device availability
#         if load_in_8bit:
#             if DEVICE.type != "cuda":
#                 raise ValueError("8-bit quantization requires CUDA device")
#             quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
#             self.model = transformers.AutoModelForCausalLM.from_pretrained(
#                 model_path,
#                 quantization_config=quantization_config,
#                 device_map=device_map,
#             )
#         else:
#             self.model = transformers.AutoModelForCausalLM.from_pretrained(
#                 model_path,
#                 torch_dtype=dtype,
#                 device_map=device_map,
#             )
#         self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
#         self.model.eval()
#     def get_perplexity(
#         self, input_texts: Union[str, List[str]], batch_size: int = 1
#     ) -> Union[float, List[float]]:
#         single_input = isinstance(input_texts, str)
#         input_texts = [input_texts] if single_input else input_texts
#         loss_list = []
#         batches = len(input_texts) // batch_size + (len(input_texts) % batch_size != 0)
#         for j in range(batches):
#             a = j * batch_size
#             b = (j + 1) * batch_size
#             input_batch = input_texts[a:b]
#             with torch.no_grad():
#                 text_with_special = [
#                     f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
#                     for text in input_batch
#                 ]
#                 model_inputs = self.tokenizer(
#                     text_with_special,
#                     return_tensors="pt",
#                     add_special_tokens=False,
#                     padding=True,
#                 )
#                 if "token_type_ids" in model_inputs:
#                     model_inputs.pop("token_type_ids")
#                 model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
#                 output = self.model(**model_inputs, use_cache=False)
#                 logits = output["logits"]
#                 label = model_inputs["input_ids"]
#                 label[label == self.tokenizer.pad_token_id] = PAD_TOKEN_LABEL_ID
#                 shift_logits = logits[..., :-1, :].contiguous()
#                 shift_labels = label[..., 1:].contiguous()
#                 loss = self.loss_fct(
#                     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
#                 )
#                 loss = loss.view(len(logits), -1)
#                 valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
#                 loss = torch.sum(loss, -1) / valid_length
#                 loss_list += loss.cpu().tolist()
#         ppl = [exp(i) for i in loss_list]
#         return ppl[0] if single_input else ppl
#     def clear_gpu_memory(self) -> None:
#         """Clears GPU memory by deleting references and emptying caches."""
#         if not torch.cuda.is_available():
#             return
#         # Delete model and tokenizer if they exist
#         if hasattr(self, "model"):
#             del self.model
#         if hasattr(self, "tokenizer"):
#             del self.tokenizer
#         # Run garbage collection
#         gc.collect()
#         # Clear CUDA cache and reset memory stats
#         with DEVICE:
#             torch.cuda.empty_cache()
#             torch.cuda.ipc_collect()
#             torch.cuda.reset_peak_memory_stats()
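# Example usage of the commented-out variant above (a sketch only; the model path is a
# placeholder, and load_in_8bit additionally requires the bitsandbytes package):
#
#     calc = PerplexityCalculator("some/causal-lm-checkpoint", load_in_8bit=True)
#     ppl = calc.get_perplexity("a sample sentence", batch_size=1)
#     calc.clear_gpu_memory()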
import gc
import os
from math import exp
from typing import List, Union

import pandas as pd
import torch
import transformers
from tqdm import tqdm
from collections import OrderedDict

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
class LRUCache:
    def __init__(self, capacity=10**11):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def set(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

    def __len__(self):
        return len(self.cache)
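# Quick sanity check of the LRU eviction behaviour (illustrative only; with the default
# capacity of 10**11 the cache effectively never evicts in practice):
#
#     cache = LRUCache(capacity=2)
#     cache.set("a", 1.0)
#     cache.set("b", 2.0)
#     cache.get("a")        # marks "a" as most recently used
#     cache.set("c", 3.0)   # evicts "b", the least recently used entry
#     assert cache.get("b") is None and cache.get("a") == 1.0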
class PerplexityCalculator:
    model_kwargs = {
        # "attn_implementation": "sdpa",  # if this is not commented out, the score changes and it runs somewhat slower
        "device_map": "auto",
        "torch_dtype": torch.float16,
    }
    device = torch.device('cuda')
    def __init__(self, model_path: str, capacity=10**11):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="right")
        self.model = transformers.AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        self.pad_token_label_id = self.loss_fct.ignore_index
        self.model.eval()
        self.cache = LRUCache(capacity=capacity)
    def get_perplexity(self, input_texts, batch_size=128, use_cache=True, verbose=False):
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts
        results = [None] * len(input_texts)
        # Serve cached perplexities first; only texts missing from the cache go to the model
        if use_cache:
            text_to_process = []
            for i, text in enumerate(input_texts):
                cached_val = self.cache.get(text)
                if cached_val is not None:
                    results[i] = cached_val
                else:
                    text_to_process.append(text)
        else:
            text_to_process = input_texts.copy()
        loss_list = []
        batches = len(text_to_process) // batch_size + (len(text_to_process) % batch_size != 0)
        pbar = range(batches)
        if verbose and batches:
            pbar = tqdm(range(batches))
        for j in pbar:
            a = j * batch_size
            b = (j + 1) * batch_size
            input_batch = text_to_process[a:b]
            with torch.no_grad():
                # Explicitly add sequence boundary tokens to the text
                text_with_special = [f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}" for text in input_batch]
                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                    padding=True,
                )
                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')
                model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']
                label = model_inputs['input_ids']
                # Mask padding positions so they do not contribute to the loss
                label[label == self.tokenizer.pad_token_id] = self.pad_token_label_id
                # Shift logits and labels for calculating loss
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = label[..., 1:].contiguous()  # Drop first input
                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )
                # Average per-token loss over the valid (non-padding) length of each sequence
                loss = loss.view(len(logits), -1)
                valid_length = (shift_labels != self.pad_token_label_id).sum(dim=-1)
                loss = torch.sum(loss, -1) / valid_length
                loss_list += loss.cpu().tolist()
        # Perplexity is the exponential of the mean negative log-likelihood per token
        ppl = [exp(i) for i in loss_list]
        # Fill newly computed perplexities into the result slots left empty by the cache lookup
        index_ppl = 0
        for index_el, el in enumerate(results):
            if el is None:
                results[index_el] = ppl[index_ppl]
                self.cache.set(text_to_process[index_ppl], ppl[index_ppl])
                index_ppl += 1
        return results[0] if single_input else results
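# Example usage (a minimal sketch; the model path below is an assumption and should be
# replaced with whichever causal LM checkpoint is actually available in the environment):
#
#     scorer = PerplexityCalculator("some/causal-lm-checkpoint")  # hypothetical path
#     scores = scorer.get_perplexity(
#         ["the quick brown fox jumps over the lazy dog",
#          "dog lazy the over jumps fox brown quick the"],
#         batch_size=2,
#         verbose=True,
#     )
#     # Lower perplexity means the model finds the word order more natural; repeated
#     # calls with the same strings are served from the LRU cache without a forward pass.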