import os

import wandb
import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch

import pl_data_loader as dataloader
import utils
from diffusion import Diffusion
from lightning.pytorch.strategies import DDPStrategy
from transformers import AutoTokenizer
from datasets import load_from_disk, load_dataset

omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y)


def _load_from_checkpoint(config, tokenizer):
  if 'hf' in config.backbone:
    return Diffusion(
      config, tokenizer=tokenizer).to('cuda')
  else:
    model = Diffusion.load_from_checkpoint(
      config.eval.checkpoint_path,
      tokenizer=tokenizer,
      config=config)
  return model


@L.pytorch.utilities.rank_zero_only
def _print_config(
    config: omegaconf.DictConfig,
    resolve: bool = True,
    save_cfg: bool = True) -> None:
  """Prints the content of a DictConfig as a tree structure using the Rich library.

  Args:
    config (DictConfig): Configuration composed by Hydra.
    resolve (bool): Whether to resolve reference fields of DictConfig.
    save_cfg (bool): Whether to save the configuration tree to a file.
  """
  style = 'dim'
  tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

  fields = config.keys()
  for field in fields:
    branch = tree.add(field, style=style, guide_style=style)
    config_section = config.get(field)
    branch_content = str(config_section)
    if isinstance(config_section, omegaconf.DictConfig):
      branch_content = omegaconf.OmegaConf.to_yaml(
        config_section, resolve=resolve)
    branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
  rich.print(tree)
  if save_cfg:
    with fsspec.open(
        '{}/config_tree.txt'.format(
          config.checkpointing.save_dir), 'w') as fp:
      rich.print(tree, file=fp)


@L.pytorch.utilities.rank_zero_only
def _print_batch(train_ds, valid_ds, tokenizer, k=64):
  # for dl_type, dl in [
  #     ('train', train_ds), ('valid', valid_ds)]:
  for dl_type, dl in [('train', train_ds)]:
    print(f'Printing {dl_type} dataloader batch.')
    batch = next(iter(dl))
    print('Batch input_ids.shape', batch['input_ids'].shape)
    first = batch['input_ids'][0, :k]
    last = batch['input_ids'][0, -k:]
    print(f'First {k} tokens:', tokenizer.decode(first))
    print('ids:', first)
    print(f'Last {k} tokens:', tokenizer.decode(last))
    print('ids:', last)


def generate_samples(config, logger, tokenizer):
  logger.info('Generating samples.')
  model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
  model.gen_ppl_metric.reset()
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None
  stride_length = config.sampling.stride_length
  num_strides = config.sampling.num_strides
  for _ in range(config.sampling.num_sample_batches):
    if config.sampling.semi_ar:
      _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
        stride_length=stride_length,
        num_strides=num_strides,
        dt=1 / config.sampling.steps)
      text_samples = intermediate_samples[-1]
      # Note: Samples generated using the semi-ar method
      # need to be processed before computing generative perplexity,
      # since these samples contain numerous <|endoftext|> tokens
      # and diffusion.compute_generative_perplexity() discards
      # any text after the first EOS token.
    else:
      samples = model.restore_model_and_sample(
        num_steps=config.sampling.steps)
      text_samples = model.tokenizer.batch_decode(samples)
      model.compute_generative_perplexity(text_samples)
  print('Text samples:', text_samples)
  if not config.sampling.semi_ar:
    print('Generative perplexity:', model.gen_ppl_metric.compute())
  return text_samples
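

# The note in generate_samples() above says semi-ar samples must be
# post-processed before computing generative perplexity, because they contain
# many <|endoftext|> tokens and compute_generative_perplexity() discards
# everything after the first EOS token. The helper below is an illustrative
# sketch of one such post-processing step; it is not part of the original
# pipeline and is not called anywhere, and the EOS string and the split-on-EOS
# approach are assumptions, not the repo's actual implementation.
def _strip_eos_segments(text_samples, eos_token='<|endoftext|>'):
  """Split each sample on the EOS token and keep the non-empty segments."""
  cleaned = []
  for sample in text_samples:
    segments = [seg.strip() for seg in sample.split(eos_token)]
    cleaned.extend(seg for seg in segments if seg)
  return cleaned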


def _ppl_eval(config, logger, tokenizer, data_module):
  logger.info('Starting Zero Shot Eval.')
  model = _load_from_checkpoint(config=config, tokenizer=tokenizer)
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None

  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))
  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    strategy=DDPStrategy(find_unused_parameters=True),
    logger=wandb_logger)
  # _, valid_ds = dataloader.get_dataloaders(
  #   config, tokenizer, skip_train=True, valid_seed=config.seed)
  trainer.test(model, data_module)


def _train(config, logger, tokenizer, data_module):
  logger.info('Starting Training.')
  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)

  if (config.checkpointing.resume_from_ckpt
      and config.checkpointing.resume_ckpt_path is not None
      and utils.fsspec_exists(
        config.checkpointing.resume_ckpt_path)):
    ckpt_path = config.checkpointing.resume_ckpt_path
  else:
    ckpt_path = None

  # Lightning callbacks
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))

  # train_ds, valid_ds = dataloader.get_dataloaders(config, tokenizer)
  # _print_batch(train_ds, valid_ds, tokenizer)
  # model = diffusion.Diffusion(config, tokenizer=valid_ds.tokenizer)

  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    accelerator='cuda',
    strategy=DDPStrategy(find_unused_parameters=True),
    logger=wandb_logger)

  model = Diffusion(config, tokenizer=tokenizer)

  trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
  # trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
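

# Illustrative Hydra invocations (a sketch, not taken from the original repo:
# the entry-point file name `main.py` is an assumption, and which overrides
# are actually required depends on the defaults in `configs/config.yaml`):
#   python main.py mode=ppl_eval eval.checkpoint_path=/path/to/last.ckpt
#   python main.py mode=sample_eval sampling.steps=128 \
#     eval.checkpoint_path=/path/to/last.ckpt
#   python main.py backbone=membrane_esm_finetune  # any other mode trains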


@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(config):
  """Main entry point for training."""
  L.seed_everything(config.seed)
  _print_config(config, resolve=True, save_cfg=True)
  logger = utils.get_logger(__name__)
  tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')

  if config.backbone == 'vanilla_esm_pretrain':
    train_dataset = load_dataset(
      'csv', data_files=config.data.train.vanilla_esm_train_path)
    val_dataset = load_dataset(
      'csv', data_files=config.data.valid.vanilla_esm_valid_path)
    test_dataset = load_dataset(
      'csv', data_files=config.data.test.vanilla_esm_test_path)
  elif config.backbone in ('membrane_esm_finetune', 'dit'):
    train_dataset = load_dataset(
      'csv', data_files=config.data.train.membrane_esm_train_path)
    val_dataset = load_dataset(
      'csv', data_files=config.data.valid.membrane_esm_valid_path)
    test_dataset = load_dataset(
      'csv', data_files=config.data.test.membrane_esm_test_path)

  # Retained for the commented-out .select(lst) calls below.
  lst = [i for i in range(1, 200)]
  train_dataset = train_dataset['train']  # .select(lst)
  val_dataset = val_dataset['train']  # .select(lst)
  test_dataset = test_dataset['train']  # .select(lst)

  if config.training.focus_mask:
    collator = dataloader.membrane_collate_fn
  elif config.data.wrapping:
    collator = dataloader.wrap_collate_fn
  else:
    collator = dataloader.collate_fn

  data_module = dataloader.CustomDataModule(
    train_dataset, val_dataset, test_dataset, tokenizer,
    batch_size=config.loader.batch_size,
    collate_fn=collator)

  if config.mode == 'sample_eval':
    generate_samples(config, logger, tokenizer)
  elif config.mode == 'ppl_eval':
    _ppl_eval(config, logger, tokenizer, data_module)
  else:
    _train(config, logger, tokenizer, data_module)


if __name__ == '__main__':
  main()