sagawa commited on
Commit
86844fb
1 Parent(s): bbe3baf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -89,12 +89,12 @@ class RegressionModel(nn.Module):
89
  else:
90
  self.config = torch.load(config_path)
91
  if pretrained:
92
- if 't5' in cfg.pretrained_model_name_or_path:
93
  self.model = T5EncoderModel.from_pretrained(CFG.pretrained_model_name_or_path)
94
  else:
95
  self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
96
  else:
97
- if 't5' in cfg.model_name_or_path:
98
  self.model = T5EncoderModel.from_pretrained('sagawa/ZINC-t5')
99
  else:
100
  self.model = AutoModel.from_config(self.config)
 
89
  else:
90
  self.config = torch.load(config_path)
91
  if pretrained:
92
+ if 't5' in cfg.model:
93
  self.model = T5EncoderModel.from_pretrained(CFG.pretrained_model_name_or_path)
94
  else:
95
  self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
96
  else:
97
+ if 't5' in cfg.model:
98
  self.model = T5EncoderModel.from_pretrained('sagawa/ZINC-t5')
99
  else:
100
  self.model = AutoModel.from_config(self.config)