sharpenb's picture
Upload folder using huggingface_hub (#1)
a1d0506 verified
raw
history blame contribute delete
170 Bytes
from torch import nn
# Registry of linear (fully-connected) layer classes, keyed by backend name.
# "torch" -> torch.nn.Linear is always present; "te" -> Transformer Engine's
# Linear is added only if `transformer_engine.pytorch` can be imported.
# NOTE(review): presumably callers select a layer class by key from a config
# value; verify against the call sites.
FC_CLASS_REGISTRY = {"torch": nn.Linear}

try:
    import transformer_engine.pytorch as te
    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    # Transformer Engine is an optional dependency. Any failure to load it
    # (missing package, or some other import-time error) leaves the
    # torch-only registry usable. Catch Exception rather than using a bare
    # `except:` so that KeyboardInterrupt and SystemExit still propagate.
    pass