from torch import nn

# Registry mapping a backend name to the class used to build fully-connected
# (linear) layers. "torch" is always present; other backends are registered
# only when their package can be imported.
FC_CLASS_REGISTRY = {"torch": nn.Linear}

try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    # Transformer Engine is an optional dependency: if importing it fails for
    # any reason, the "te" entry is simply left out of the registry.
    # `Exception` (rather than a bare `except:`) keeps this best-effort while
    # still letting KeyboardInterrupt / SystemExit propagate.
    pass