Batch inference seems to be done sequentially

#50 · opened by yard1

When using a batch size larger than 1, the generation time increases almost linearly with the batch size. This is highly unexpected and not something I have seen with other transformer models; I would expect a transformer to handle batched inputs without a noticeable impact on latency.

Script:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

model = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16
)
batch_size = 1

input_prompt = [
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
] * batch_size
input_tokens = tokenizer.batch_encode_plus(
    input_prompt,
    return_tensors="pt",
)
token_num = input_tokens["input_ids"].size(-1)
for t in input_tokens:
    if torch.is_tensor(input_tokens[t]):
        input_tokens[t] = input_tokens[t].to(model.device)
input_tokens.pop("token_type_ids", None)  # generate() does not accept token_type_ids for this model

# Warmup
print(f"Batch size {batch_size}")
sequences = model.generate(
    **input_tokens, min_length=512, max_length=512, do_sample=True
)
torch.cuda.synchronize()
st = time.monotonic()
for i in range(2):
    torch.cuda.synchronize()
    sequences = model.generate(
        **input_tokens, min_length=512, max_length=512, do_sample=True
    )
    torch.cuda.synchronize()
tt = time.monotonic() - st
print(f"Time taken {tt/2} time per token {tt/512/2}")

Results with different batch_size:

BS 1: Time taken 20.67650790150003 time per token 0.04038380449511725
BS 2: Time taken 32.592279224000094 time per token 0.06365679535937518
BS 4: Time taken 48.25992262649993 time per token 0.09425766137988267
BS 8: Time taken 86.17116434899981 time per token 0.16830305536914025

That's the time it takes to process the entire batch, so it naturally increases with the number of samples. What you should expect to decrease is the time per sample, and it does decrease when I look at your numbers.
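
As a rough back-of-the-envelope check using the batch times reported above (values rounded), dividing by the batch size shows the per-sample and per-token cost going down as the batch grows:

# Per-sample latency derived from the batch times reported above (rounded)
batch_times = {1: 20.68, 2: 32.59, 4: 48.26, 8: 86.17}
for bs, t in batch_times.items():
    # each sample generates ~512 tokens, so the per-token cost is t / (bs * 512)
    print(f"BS {bs}: {t / bs:.2f} s per sample, {t / (bs * 512) * 1000:.1f} ms per token")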

This is bullshit, g-ronimo

be nice.

Here is the same benchmark, with the timing normalized by the total number of tokens generated across the whole batch:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

model = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model, 
    trust_remote_code=True, 
    device_map="auto", 
    torch_dtype=torch.bfloat16
)

for batch_size in [2, 4, 8]:
    input_prompt = [
        "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
    ] * batch_size
    input_tokens = tokenizer.batch_encode_plus(
        input_prompt,
        return_tensors="pt",
    ).to("cuda")
    input_tokens_cnt = sum([len(t) for t in input_tokens["input_ids"]])

    # Warmup
    sequences = model.generate(
        **input_tokens, min_length=512, max_length=512, do_sample=True
    )
    
    torch.cuda.synchronize()
    st = time.monotonic()
    generated_tokens_count = []
    num_trials = 2
    for i in range(num_trials):
        torch.cuda.synchronize()
        sequences = model.generate(
            **input_tokens, 
            min_length=512, 
            max_length=512, 
            do_sample=True
        )
        torch.cuda.synchronize()
        sequences_tokens_cnt = sum([len(t) for t in sequences])
        generated_tokens_count.append(sequences_tokens_cnt - input_tokens_cnt)

    tt = time.monotonic() - st
    print(f"batch_size {batch_size}: Avg. time taken {tt/num_trials}, avg. time per token {tt/sum(generated_tokens_count)}")

Output:

batch_size 2: Avg. time taken 17.3977282285, avg. time per token 0.019287947038248338
batch_size 4: Avg. time taken 17.282605756000066, avg. time per token 0.009580158401330413
batch_size 8: Avg. time taken 18.542240016500045, avg. time per token 0.005139201778409103
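
Converting those per-token numbers into throughput (a quick sanity check on the figures above) makes the near-linear scaling easier to see:

# Effective throughput from the avg. time per token reported above
time_per_token = {2: 0.019288, 4: 0.009580, 8: 0.005139}
for bs, t in time_per_token.items():
    print(f"batch_size {bs}: ~{1 / t:.0f} generated tokens/s")

So the wall-clock time per batch stays roughly constant (about 17-19 s) while throughput goes from roughly 52 to 104 to 195 generated tokens per second as the batch size doubles.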
