Spaces:
Runtime error
Runtime error
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from matplotlib.ticker import MultipleLocator | |
# Markdown header rendered at the top of the Gradio page.
INTRO = (
    "# Harm's law\n"
    "The Chinchilla scaling laws focus on optimally scaling training compute but often we also care about inference cost.\n"
    "This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).\n"
)
### CHINCHILLA PARAMS:
# Coefficients of the parametric loss L(N, D) = E + A / N**alpha + B / D**beta
# (values as used by the blog post linked in INTRO -- confirm against it).
# E does not enter any of the formulas below; it is kept for completeness.
E = 1.62
A = 406.4
B = 410.7
alpha = 0.336
beta = 0.283

# One billion: the UI takes model/dataset sizes in billions.
Bn = 1_000_000_000

# Prefactor of the compute-optimal allocation: N_opt = G * (C/6)**(beta/(alpha+beta))
# and D_opt = G**-1 * (C/6)**(alpha/(alpha+beta)).
G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
### FUNCTIONS
def to_flops(N, D):
    """Return the training compute (FLOPs) for ``N`` parameters and ``D`` tokens.

    Uses the standard ``6 * N * D`` approximation; inputs are absolute counts,
    not billions.
    """
    per_token = 6 * N
    return per_token * D
def n_opt(C):
    """Return the compute-optimal parameter count for a budget of ``C`` FLOPs.

    Computed as ``G * (C / 6) ** (beta / (alpha + beta))`` using the module-level
    Chinchilla constants.
    """
    scaled = C / 6
    return G * scaled ** (beta / (alpha + beta))
def d_opt(C):
    """Return the compute-optimal token count for a budget of ``C`` FLOPs.

    Computed as ``(1 / G) * (C / 6) ** (alpha / (alpha + beta))`` using the
    module-level Chinchilla constants.
    """
    inv_g = 1 / G
    scaled = C / 6
    return inv_g * scaled ** (alpha / (alpha + beta))
def compute_kd(kn):
    """Return the dataset-size multiplier needed to match the optimal loss.

    Under the parametric loss L(N, D) = E + A / N**alpha + B / D**beta, a model
    with ``kn`` times the compute-optimal parameter count reaches the
    compute-optimal loss only when trained on ``kd`` times the optimal token
    count, where::

        kd**(-beta) = 1 - (kn**(-alpha) - 1) * (A / B) * G**(-alpha - beta)

    Args:
        kn: Model-size fraction relative to the Chinchilla-optimal size; a
            scalar or an array-like of fractions.

    Returns:
        ``kd`` as a float for scalar input, or an ndarray for array input.
        Wherever the right-hand side above is not positive (which includes
        every ``kn <= 0``), no finite amount of data reaches the target loss
        and the result is ``inf``. The raw power would instead yield a complex
        number (Python floats), NaN (numpy) or ZeroDivisionError (``kn == 0``).
    """
    kn = np.asarray(kn, dtype=float)
    frac = (A / B) * (G ** (-alpha - beta))
    # kn == 0 divides by zero and negative bases are invalid fractional powers;
    # those positions are replaced by inf via np.where, so silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        base = 1 - (kn ** -alpha - 1) * frac
        kd = np.where(base > 0, base ** (1 / (-beta)), np.inf)
    return float(kd) if kd.ndim == 0 else kd
def compute_overhead(kn, kd):
    """Return the relative training-compute overhead of a (kn, kd) setting.

    Since training FLOPs scale as N * D, scaling the optimal model size by
    ``kn`` and the optimal dataset size by ``kd`` multiplies compute by
    ``kn * kd``; the overhead is that factor minus one (0.0 = no overhead).
    """
    total = kn * kd
    return total - 1
### PRECOMPUTE CURVE:
# Range of model-size fractions (relative to Chinchilla optimal) drawn as the
# reference curve by plot_curve.
kn_min = 0.2
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
# Training-compute overhead (in %) for each fraction, with the data multiplier
# chosen so the loss matches the compute-optimal model.
overheads = [compute_overhead(kn, compute_kd(kn)) * 100 for kn in kns]
def plot_curve(kn, kd):
    """Plot the precomputed overhead curve and mark the user's setting.

    Args:
        kn: Model-size fraction relative to Chinchilla optimal.
        kd: Matching dataset-size multiplier (see ``compute_kd``).

    Returns:
        A matplotlib Figure with the module-level ``kns``/``overheads`` curve,
        a red marker at the user's point and a blue marker at (1.0, 0.0).
    """
    fig, ax = plt.subplots(dpi=200, figsize=(5, 3))
    # Draw through the Axes object rather than pyplot's implicit "current
    # figure": Gradio may run this handler on several threads at once.
    ax.plot(kns, overheads, color="black", zorder=1)
    ax.scatter([kn], [compute_overhead(kn, kd) * 100], s=100, marker="o", c="red",
               label="You are here!", zorder=2)
    ax.scatter([1.0], [0.0], marker="o", s=100, c="blue",
               label="Chinchilla optimal", zorder=2)
    ax.set_xlabel("Fraction of Chinchilla optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    ax.grid(True, which="both")
    ax.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    fig.tight_layout()
    # Drop the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the returned Figure can still be saved/rendered.
    plt.close(fig)
    return fig
def compute(N, D):
    """Report the compute budget and Chinchilla trade-off for a training setup.

    Args:
        N: Model size in billions of parameters.
        D: Dataset size in billions of tokens.

    Returns:
        A ``(markdown_text, figure)`` pair for the Gradio Markdown and Plot
        outputs.

    Raises:
        gr.Error: If either input is missing or not strictly positive. A zero
            budget makes the optimal model size zero (division by zero), and
            a negative one produces complex fractional powers.
    """
    if N is None or D is None or N <= 0 or D <= 0:
        raise gr.Error("Model size and dataset size must both be positive numbers.")
    # Inputs are in billions; to_flops expects absolute counts and returns FLOPs.
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    overhead = compute_overhead(kn, kd)
    fig = plot_curve(kn, kd)
    # Blank lines separate the bold results so Markdown renders each on its own line.
    text = f"""\
## Compute:
Your specified setting corresponds to the following training compute budget.

**Compute budget (FLOPs): {C:.2E}**

## Chinchilla optimal:
If you are optimizing for model performance and ignore inference cost this is the optimal setting for training:

**Optimal model size: {N_opt/Bn:.2f}B parameters**

**Optimal dataset size: {D_opt/Bn:.2f}B tokens**

## Your setting trade-off:
Compared to the compute optimal model.

**Training compute overhead: {100*overhead:.2f}%**

**Inference cost savings: {100 - kn*100:.2f}%** """
    return text, fig
# Gradio UI: two numeric inputs (sizes in billions), a button, and the
# markdown/plot outputs. The same handler is attached to the button click and
# to page load, so the default setting is rendered as soon as the page opens.
with gr.Blocks() as demo:
    gr.Markdown(INTRO)
    with gr.Row():
        N = gr.Number(label="Model size (in B parameters):", value=7)
        D = gr.Number(label="Dataset size (in B tokens):", value=2000)
    button = gr.Button("Compute!")
    # NOTE(review): passes the pyplot module itself as the initial value;
    # presumably Gradio renders pyplot's current figure -- confirm.
    plot = gr.Plot(value=plt)
    md = gr.Markdown("")
    for register in (button.click, demo.load):
        register(fn=compute, inputs=[N, D], outputs=[md, plot])
demo.launch()