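"""Perplexity metric for causal language models, with an LRU cache over input texts.

Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py
"""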
import os
from math import exp
from collections import OrderedDict

import torch
import transformers
from tqdm import tqdm
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
class LRUCache:
    """Least-recently-used cache mapping input text -> perplexity.

    The default capacity of 10**11 entries is effectively unbounded for
    realistic workloads; pass a smaller capacity to bound memory use.
    """

    def __init__(self, capacity=10**11):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        if key in self.cache:
            # Mark as most recently used before returning.
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def set(self, key, value):
        self.cache[key] = value
        self.cache.move_to_end(key)
        # Evict the least recently used entry once over capacity.
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

    def __len__(self):
        return len(self.cache)
class PerplexityCalculator:
    """Computes perplexity of text with a pre-trained causal language model.

    Results are memoized in an LRUCache keyed by the input text, so repeated
    queries for the same string skip the forward pass entirely.
    """

    model_kwargs = {
        # "attn_implementation": "sdpa",  # if enabled, the score changes and runs somewhat slower
        "device_map": "auto",
        "torch_dtype": torch.float16,
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __init__(self, model_path: str, capacity=10**11):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="right")
        self.model = transformers.AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        # CrossEntropyLoss.ignore_index (-100) marks padding positions to skip.
        self.pad_token_label_id = self.loss_fct.ignore_index
        self.model.eval()
        self.cache = LRUCache(capacity=capacity)
    def get_perplexity(self, input_texts, batch_size=128, use_cache=True, verbose=False):
        """Return the perplexity of a string, or a list of perplexities for a list of strings."""
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts
        results = [None] * len(input_texts)
        if use_cache:
            # Serve cached values and queue only unseen texts for the forward pass.
            text_to_process = []
            for i, text in enumerate(input_texts):
                cached_val = self.cache.get(text)
                if cached_val is not None:
                    results[i] = cached_val
                else:
                    text_to_process.append(text)
        else:
            text_to_process = input_texts.copy()
        loss_list = []
        # Ceil division: one extra batch for any remainder.
        batches = len(text_to_process) // batch_size + (len(text_to_process) % batch_size != 0)
        pbar = range(batches)
        if verbose and batches:
            pbar = tqdm(range(batches))
        for j in pbar:
            a = j * batch_size
            b = (j + 1) * batch_size
            input_batch = text_to_process[a:b]
with torch.no_grad():
# Explicitly add sequence boundary tokens to the text
text_with_special = [f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}" for text in input_batch]
# Tokenize
model_inputs = self.tokenizer(
text_with_special,
return_tensors='pt',
add_special_tokens=False,
padding=True,
)
                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')
                model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
# Get model output
output = self.model(**model_inputs, use_cache=False)
logits = output['logits']
label = model_inputs['input_ids']
label[label == self.tokenizer.pad_token_id] = self.pad_token_label_id
# Shift logits and labels for calculating loss
shift_logits = logits[..., :-1, :].contiguous() # Drop last prediction
shift_labels = label[..., 1:].contiguous() # Drop first input
# Calculate token-wise loss
loss = self.loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
)
loss = loss.view(len(logits), -1)
valid_length = (shift_labels != self.pad_token_label_id).sum(dim=-1)
loss = torch.sum(loss, -1) / valid_length
loss_list += loss.cpu().tolist()
        ppl = [exp(i) for i in loss_list]
        # Fill the gaps left for uncached texts, caching each new result.
        index_ppl = 0
        for index_el, el in enumerate(results):
            if el is None:
                results[index_el] = ppl[index_ppl]
                self.cache.set(text_to_process[index_ppl], ppl[index_ppl])
                index_ppl += 1
        return results[0] if single_input else results
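

# A minimal usage sketch. The checkpoint name below is illustrative only
# (an assumption, not part of this module); substitute whatever model this
# metric is meant to score with.
if __name__ == '__main__':
    scorer = PerplexityCalculator('google/gemma-2-9b')
    texts = [
        'The quick brown fox jumps over the lazy dog.',
        'Colorless green ideas sleep furiously.',
    ]
    print(scorer.get_perplexity(texts, batch_size=2))  # list of two floats
    print(scorer.get_perplexity(texts[0]))             # served from the cache, no forward pass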