From 4b271fcc6f72741e6e70c6aedc9d0b697a0547be Mon Sep 17 00:00:00 2001 From: wiedersehne Date: Tue, 14 May 2024 10:29:48 +0200 Subject: [PATCH] code upload --- augment.py | 380 +++ cdna_classification.py | 288 +++ config/config.yaml | 103 + config/config_ct.yaml | 165 ++ config/config_fa.yaml | 205 ++ config/config_fa_sweep.yaml | 30 + config/config_gb.yaml | 340 +++ config/config_gue.yaml | 676 ++++++ contrastive_knn.py | 120 + contrastive_pretraining.py | 505 ++++ contrastive_visualization.py | 293 +++ data/generate_pretrain_human.py | 120 + data/genome_process.py | 99 + data/genomic_benchmark.py | 72 + data_utils.py | 43 + dna.sh | 51 + genomic_classification.py | 241 ++ gue_classification.py | 331 +++ mega.txt | 281 +++ models/Other_models/S4_model.py | 83 + models/Other_models/S4_src.py | 133 ++ models/Other_models/s4.py | 1964 ++++++++++++++++ models/SwanDNA.py | 689 ++++++ models/__pycache__/DNASwan.cpython-39.pyc | Bin 0 -> 11685 bytes models/__pycache__/Short_notebook_2.ipynb | 2035 +++++++++++++++++ models/__pycache__/SwanDNA.cpython-39.pyc | Bin 0 -> 17419 bytes models/__pycache__/cdilDNA.cpython-39.pyc | Bin 0 -> 7503 bytes .../pretraining_model.cpython-39.pyc | Bin 0 -> 3158 bytes models/cdilDNA.py | 208 ++ models/deeperdeepsea.py | 55 + models/pretraining_model.py | 134 ++ models/x_formers.py | 165 ++ 32 files changed, 9809 insertions(+) create mode 100644 augment.py create mode 100644 cdna_classification.py create mode 100644 config/config.yaml create mode 100644 config/config_ct.yaml create mode 100644 config/config_fa.yaml create mode 100644 config/config_fa_sweep.yaml create mode 100644 config/config_gb.yaml create mode 100644 config/config_gue.yaml create mode 100644 contrastive_knn.py create mode 100644 contrastive_pretraining.py create mode 100644 contrastive_visualization.py create mode 100644 data/generate_pretrain_human.py create mode 100644 data/genome_process.py create mode 100644 data/genomic_benchmark.py create mode 100644 data_utils.py create mode 100644 dna.sh create mode 100644 genomic_classification.py create mode 100644 gue_classification.py create mode 100644 mega.txt create mode 100644 models/Other_models/S4_model.py create mode 100644 models/Other_models/S4_src.py create mode 100644 models/Other_models/s4.py create mode 100644 models/SwanDNA.py create mode 100644 models/__pycache__/DNASwan.cpython-39.pyc create mode 100644 models/__pycache__/Short_notebook_2.ipynb create mode 100644 models/__pycache__/SwanDNA.cpython-39.pyc create mode 100644 models/__pycache__/cdilDNA.cpython-39.pyc create mode 100644 models/__pycache__/pretraining_model.cpython-39.pyc create mode 100644 models/cdilDNA.py create mode 100644 models/deeperdeepsea.py create mode 100644 models/pretraining_model.py create mode 100644 models/x_formers.py diff --git a/augment.py b/augment.py new file mode 100644 index 0000000..0d7d42f --- /dev/null +++ b/augment.py @@ -0,0 +1,380 @@ +""" +Library of data augmentations for genomic sequence data. + +To contribute a custom augmentation, use the following syntax: + +.. code-block:: python + + class CustomAugmentation(AugmentBase): + def __init__(self, param1, param2): + self.param1 = param1 + self.param2 = param2 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + # Perform augmentation + return x_aug + +""" + +import torch + + +class AugmentBase: + """ + Base class for EvoAug augmentations for genomic sequences. + """ + def __call__(self, x): + """Return an augmented version of `x`. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Batch of one-hot sequences with random augmentation applied. + """ + raise NotImplementedError() + + +class RandomDeletion(AugmentBase): + """Randomly deletes a contiguous stretch of nucleotides from sequences in a training + batch according to a random number between a user-defined delete_min and delete_max. + A different deletion is applied to each sequence. + + Parameters + ---------- + delete_min : int, optional + Minimum size for random deletion (defaults to 0). + delete_max : int, optional + Maximum size for random deletion (defaults to 20). + """ + def __init__(self, delete_min=0, delete_max=20): + self.delete_min = delete_min + self.delete_max = delete_max + + def __call__(self, x): + """Randomly delete segments in a set of one-hot DNA sequences. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with randomly deleted segments (padded to correct shape + with random DNA) + """ + N, A, L = x.shape + + # sample random DNA + a = torch.eye(A) + p = torch.tensor([1/A for _ in range(A)]) + padding = torch.stack([a[p.multinomial(self.delete_max, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device) + + # sample deletion length for each sequence + delete_lens = torch.randint(self.delete_min, self.delete_max + 1, (N,)) + + # sample locations to delete for each sequence + delete_inds = torch.randint(L - self.delete_max + 1, (N,)) # deletion must be in boundaries of seq. + + # loop over each sequence + x_aug = [] + for seq, pad, delete_len, delete_ind in zip(x, padding, delete_lens, delete_inds): + + # get index of half delete_len (to pad random DNA at beginning of sequence) + pad_begin_index = torch.div(delete_len, 2, rounding_mode='floor').item() + + # index for other half (to pad random DNA at end of sequence) + pad_end_index = delete_len - pad_begin_index + + # removes deletion and pads beginning and end of sequence with random DNA to ensure same length + x_aug.append( torch.cat([pad[:,:pad_begin_index], # random dna padding + seq[:,:delete_ind], # sequence up to deletion start index + seq[:,delete_ind+delete_len:], # sequence after deletion end index + pad[:,self.delete_max-pad_end_index:]], # random dna padding + -1)) # concatenation axis + return torch.stack(x_aug) + + +class RandomInsertion(AugmentBase): + """Randomly inserts a contiguous stretch of nucleotides from sequences in a training + batch according to a random number between a user-defined insert_min and insert_max. + A different insertions is applied to each sequence. Each sequence is padded with random + DNA to ensure same shapes. + + Parameters + ---------- + insert_min : int, optional + Minimum size for random insertion, defaults to 0 + insert_max : int, optional + Maximum size for random insertion, defaults to 20 + """ + def __init__(self, insert_min=0, insert_max=20): + self.insert_min = insert_min + self.insert_max = insert_max + + def __call__(self, x): + """Randomly inserts segments of random DNA to a set of DNA sequences. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with randomly inserts segments of random DNA. All sequences + are padded with random DNA to ensure same shape. + """ + N, A, L = x.shape + + # sample random DNA + a = torch.eye(A) + p = torch.tensor([1/A for _ in range(A)]) + insertions = torch.stack([a[p.multinomial(self.insert_max, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device) + + # sample insertion length for each sequence + insert_lens = torch.randint(self.insert_min, self.insert_max + 1, (N,)) + + # sample locations to insertion for each sequence + insert_inds = torch.randint(L, (N,)) + + # loop over each sequence + x_aug = [] + for seq, insertion, insert_len, insert_ind in zip(x, insertions, insert_lens, insert_inds): + + # get index of half insert_len (to pad random DNA at beginning of sequence) + insert_beginning_len = torch.div((self.insert_max - insert_len), 2, rounding_mode='floor').item() + + # index for other half (to pad random DNA at end of sequence) + insert_end_len = self.insert_max - insert_len - insert_beginning_len + + # removes deletion and pads beginning and end of sequence with random DNA to ensure same length + x_aug.append( torch.cat([insertion[:,:insert_beginning_len], # random dna padding + seq[:,:insert_ind], # sequence up to insertion start index + insertion[:,insert_beginning_len:insert_beginning_len+insert_len], # random insertion + seq[:,insert_ind:], # sequence after insertion end index + insertion[:,insert_beginning_len+insert_len:self.insert_max]], # random dna padding + -1)) # concatenation axis + return torch.stack(x_aug) + + +class RandomTranslocation(AugmentBase): + """Randomly cuts sequence in two pieces and shifts the order for each in a training + batch. This is implemented with a roll transformation with a user-defined shift_min + and shift_max. A different roll (positive or negative) is applied to each sequence. + Each sequence is padded with random DNA to ensure same shapes. + + Parameters + ---------- + shift_min : int, optional + Minimum size for random shift, defaults to 0. + shift_max : int, optional + Maximum size for random shift, defaults to 20. + """ + def __init__(self, shift_min=0, shift_max=20): + self.shift_min = shift_min + self.shift_max = shift_max + + def __call__(self, x): + """Randomly shifts sequences in a batch, x. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with random translocations. + """ + N = x.shape[0] + + # determine size of shifts for each sequence + shifts = torch.randint(self.shift_min, self.shift_max + 1, (N,)) + + # make some of the shifts negative + ind_neg = torch.rand(N) < 0.5 + shifts[ind_neg] = -1 * shifts[ind_neg] + + # apply random shift to each sequence + x_rolled = [] + for i, shift in enumerate(shifts): + x_rolled.append( torch.roll(x[i], shift.item(), -1) ) + x_rolled = torch.stack(x_rolled).to(x.device) + return x_rolled + + + +class RandomInversion(AugmentBase): + """Randomly inverts a contiguous stretch of nucleotides from sequences in a training + batch according to a user-defined invert_min and invert_max. A different insertions + is applied to each sequence. Each sequence is padded with random DNA to ensure same + shapes. + + Parameters + ---------- + invert_min : int, optional + Minimum size for random insertion, defaults to 0. + invert_max : int, optional + Maximum size for random insertion, defaults to 20. + """ + def __init__(self, invert_min=0, invert_max=20): + self.invert_min = invert_min + self.invert_max = invert_max + + def __call__(self, x): + """Randomly inverts segments of random DNA to a set of one-hot DNA sequences. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with randomly inverted segments of random DNA. + """ + N, A, L = x.shape + + # set random inversion size for each seequence + inversion_lens = torch.randint(self.invert_min, self.invert_max + 1, (N,)) + + # randomly select start location for each inversion + inversion_inds = torch.randint(L - self.invert_max + 1, (N,)) # inversion must be in boundaries of seq. + + # apply random inversion to each sequence + x_aug = [] + for seq, inversion_len, inversion_ind in zip(x, inversion_lens, inversion_inds): + x_aug.append( torch.cat([seq[:,:inversion_ind], # sequence up to inversion start index + torch.flip(seq[:,inversion_ind:inversion_ind+inversion_len], dims=[0,1]), # reverse-complement transformation + seq[:,inversion_ind+inversion_len:]], # sequence after inversion + -1)) # concatenation axis + return torch.stack(x_aug) + + + +class RandomMutation(AugmentBase): + """Randomly mutates sequences in a training batch according to a user-defined + mutate_frac. A different set of mutations is applied to each sequence. + + Parameters + ---------- + mutate_frac : float, optional + Probability of mutation for each nucleotide, defaults to 0.05. + """ + def __init__(self, mutate_frac=0.05): + self.mutate_frac = mutate_frac + + def __call__(self, x): + """Randomly introduces mutations to a set of one-hot DNA sequences. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with randomly mutated DNA. + """ + N, A, L = x.shape + + # determine the number of mutations per sequence + num_mutations = round(self.mutate_frac / 0.75 * L) # num. mutations per sequence (accounting for silent mutations) + + # randomly determine the indices to apply mutations + mutation_inds = torch.argsort(torch.rand(N,L))[:, :num_mutations] # see 0 + + # create random DNA (to serve as random mutations) + a = torch.eye(A) + p = torch.tensor([1/A for _ in range(A)]) + mutations = torch.stack([a[p.multinomial(num_mutations, replacement=True)].transpose(0,1) for _ in range(N)]).to(x.device) + + # make a copy of the batch of sequences + x_aug = torch.clone(x) + + # loop over sequences and apply mutations + for i in range(N): + x_aug[i,:,mutation_inds[i]] = mutations[i] + return x_aug + + + +class RandomRC(AugmentBase): + """Randomly applies a reverse-complement transformation to each sequence in a training + batch according to a user-defined probability, rc_prob. This is applied to each sequence + independently. + + Parameters + ---------- + rc_prob : float, optional + Probability to apply a reverse-complement transformation, defaults to 0.5. + """ + def __init__(self, rc_prob=0.5): + """Creates random reverse-complement object usable by EvoAug. + """ + self.rc_prob = rc_prob + + def __call__(self, x): + """Randomly transforms sequences in a batch with a reverse-complement transformation. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with random reverse-complements applied. + """ + # make a copy of the sequence + x_aug = torch.clone(x) + + # randomly select sequences to apply rc transformation + ind_rc = torch.rand(x_aug.shape[0]) < self.rc_prob + + # apply reverse-complement transformation + x_aug[ind_rc] = torch.flip(x_aug[ind_rc], dims=[1,2]) + return x_aug + + +class RandomNoise(AugmentBase): + """Randomly add Gaussian noise to a batch of sequences with according to a user-defined + noise_mean and noise_std. A different set of noise is applied to each sequence. + + Parameters + ---------- + noise_mean : float, optional + Mean of the Gaussian noise, defaults to 0.0. + noise_std : float, optional + Standard deviation of the Gaussian noise, defaults to 0.2. + """ + def __init__(self, noise_mean=0.0, noise_std=0.2): + self.noise_mean = noise_mean + self.noise_std = noise_std + + def __call__(self, x): + """Randomly adds Gaussian noise to a set of one-hot DNA sequences. + + Parameters + ---------- + x : torch.Tensor + Batch of one-hot sequences (shape: (N, A, L)). + + Returns + ------- + torch.Tensor + Sequences with random noise. + """ + return x + torch.normal(self.noise_mean, self.noise_std, x.shape).to(x.device) \ No newline at end of file diff --git a/cdna_classification.py b/cdna_classification.py new file mode 100644 index 0000000..b3155d4 --- /dev/null +++ b/cdna_classification.py @@ -0,0 +1,288 @@ +import torch +import pandas as pd +import numpy as np +from torch import nn, optim +from omegaconf import OmegaConf +from functools import lru_cache +from sklearn.preprocessing import LabelBinarizer +from torch.utils.data import DataLoader +from torchmetrics import Accuracy, MatthewsCorrCoef, F1Score, AUROC +from torchmetrics.classification import MulticlassMatthewsCorrCoef +from models.SwanDNA import GB_Flash_Classifier, GB_Linear_Classifier +from data_utils import gb_Dataset +# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType +import pytorch_lightning as pl +from transformers import get_cosine_schedule_with_warmup +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.model_summary import ModelSummary +from ptflops import get_model_complexity_info +from flopth import flopth +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar +pl.seed_everything(42) + + +class LightningWrapper(pl.LightningModule): + def __init__(self, model, cfg, train_set, val_set, test_set, pretrained, loss, file_name): + super().__init__() + self.save_hyperparameters(cfg) + self.model_config = self.hparams.SwanDNA + self.batch_size = self.hparams.training.batch_size + self.output = self.hparams.SwanDNA.output_size + self.warm_up = self.hparams.training.n_warmup_steps + self.length = self.hparams.SwanDNA.max_len + self.model = model(**self.model_config) + self.save_every = self.hparams.training.save_every + self.train_set = train_set + self.val_set = val_set + self.test_set = test_set + self.loss = loss + self.file_name = file_name + if self.output == 2: + self.train_mcc = Accuracy(task='binary') + self.val_mcc = Accuracy(task='binary') + self.test_mcc = Accuracy(task='binary') + + if pretrained: + pretrained_path = f'./{self.file_name}' + pretrained = torch.load(pretrained_path, map_location='cpu') + pretrained = pretrained["Teacher"] + + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in pretrained.items(): + if k.startswith('encoder') or k.startswith('embedding'): + new_state_dict[k] = v + + net_dict = self.model.state_dict() + pretrained_cm = {k: v for k, v in new_state_dict.items() if k in net_dict} + net_dict.update(pretrained_cm) + self.model.load_state_dict(net_dict) + for k, v in self.model.state_dict().items(): + print(k, v) + print(self.file_name) + print("*************pretrained model loaded***************") + + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + return self.model(x) + + def _init_weights(self, m): + if isinstance(m, nn.Reear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def training_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + train_loss = self.loss(output, label.to(torch.int64)) + self.train_mcc.update(preds, label.int()) + return {"loss":train_loss, "preds":preds, "labels":label} + + def validation_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + val_loss = self.loss(output, label.to(torch.int64)) + self.val_mcc.update(preds, label.int()) + return {"loss":val_loss, "preds":preds, "labels":label} + + def test_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + test_loss = self.loss(output, label.to(torch.int64)) + self.test_mcc.update(preds, label.int()) + return {"loss":test_loss, "preds":preds, "labels":label} + + def training_epoch_end(self, outputs): + train_loss = torch.stack([x["loss"] for x in outputs]).mean() + acc = self.train_mcc.compute().mean() + self.train_mcc.reset() + self.log('train_mcc', acc, sync_dist=True) + self.log('train_loss', train_loss, sync_dist=True) + + def validation_epoch_end(self, outputs): + val_loss = torch.stack([x["loss"] for x in outputs]).mean() + # label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + # output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + acc = self.val_mcc.compute().mean() + self.val_mcc.reset() + self.log("val_mcc", acc, sync_dist=True) + self.log('val_loss', val_loss, sync_dist=True) + + def test_epoch_end(self, outputs): + test_loss = torch.stack([x["loss"] for x in outputs]).mean() + # label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + # output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + acc = self.test_mcc.compute().mean() + self.val_mcc.reset() + self.log("test_mcc", acc, sync_dist=True) + self.log('test_loss', test_loss, sync_dist=True) + + def train_dataloader(self): + return DataLoader( + dataset=self.train_set, + num_workers=1, + pin_memory=True, + shuffle=True, + drop_last=True, + batch_size=self.batch_size + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + batch_size=self.batch_size + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.test_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=True, + batch_size=self.batch_size + ) + + @lru_cache + def total_steps(self): + l = len(self.trainer._data_connector._train_dataloader_source.dataloader()) + print('Num devices', self.trainer.num_devices) + max_epochs = self.trainer.max_epochs + accum_batches = self.trainer.accumulate_grad_batches + manual_total_steps = (l // accum_batches * max_epochs)/self.trainer.num_devices + print('MANUAL Total steps', manual_total_steps) + return manual_total_steps + + def configure_optimizers(self): + optimizer = optim.AdamW( + self.parameters(), + lr=self.hparams.training.learning_rate, + weight_decay=self.hparams.training.weight_decay + ) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=int(self.total_steps()*self.warm_up), #hyperparmeter [0.3, 0.4] + num_training_steps=self.total_steps(), + num_cycles=self.hparams.training.n_cycles + ) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + + +def sequence2onehot(data_file, lb, length): + ds = pd.read_csv(data_file) + sequences, labels = [],[] + for index, data in ds.iterrows(): + gene_to_number = lb.transform(list(data["sequence"])) + if gene_to_number.shape[0] == length: + sequences.append(gene_to_number) + labels.append(data["label"]) + X = torch.from_numpy(np.array(sequences)).to(torch.int8) + y = torch.from_numpy(np.array(labels)).to(torch.float16) + + return X, y + + +def classify_main(cfg, task, branch): + """ + 1. decide which tack to run + """ + config = cfg.MTcDNA_4k + + + """ + 2. load dataset. + """ + + pretrained = config.training.pretrained + length = config.SwanDNA.max_len + loss = nn.CrossEntropyLoss(reduction='mean') + + lb = LabelBinarizer() + lb.fit(['A', 'T', 'C', 'G', 'N']) + + df = pd.read_csv(f"./data/{task}/{branch}/train.csv") + print(df.describe()) + + train_X, train_y = sequence2onehot(f"./data/{task}/{branch}/train.csv", lb, length) + val_X, val_y = sequence2onehot(f"./data/{task}/{branch}/dev.csv", lb, length) + test_X, test_y = sequence2onehot(f"./data/{task}/{branch}/test.csv", lb, length) + print("***************data******************") + # print(train_X[0]) + print(train_X.size(), test_X.size(), val_X.size()) + + train_set = gb_Dataset(train_X, train_y) + val_set = gb_Dataset(val_X, val_y) + test_set = gb_Dataset(test_X, test_y) + + test_dalaloader = DataLoader( + dataset=test_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=True, + batch_size=config.training.batch_size + ) + + """ + 3. strat training with ddp mode. + """ + + ddp = DDPStrategy(process_group_backend="nccl", find_unused_parameters=True) + pretrained_model = "model_29_1000_4l_308_512_noiseandTL.pt" + + model = LightningWrapper(GB_Linear_Classifier, config, train_set, val_set, test_set, pretrained, loss, pretrained_model) + summary = ModelSummary(model, max_depth=-1) + + dummy_inputs = torch.rand(1,1024,5).to('cuda') + flops, params = flopth(model, inputs=(dummy_inputs,)) + print("XXX", flops, params) + + """ + 4. init trainer + """ + + wandb_logger = WandbLogger(dir="./wandb/", project="Prom", entity='tonyu', name=f'{pretrained_model}_{length}_{task}') + checkpoint_callback = ModelCheckpoint(monitor="val_mcc", mode="max") + + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks_for_trainer = [TQDMProgressBar(refresh_rate=10), lr_monitor, checkpoint_callback] + if config.training.patience != -1: + early_stopping = EarlyStopping(monitor="val_mcc", mode="max", min_delta=0., patience=cfg.Fine_tuning.training.patience) + callbacks_for_trainer.append(early_stopping) + if config.training.swa_lrs != -1: + swa = StochasticWeightAveraging(swa_lrs=1e-2) + callbacks_for_trainer.append(swa) + + print(summary) + trainer = pl.Trainer( + check_val_every_n_epoch=1, + enable_progress_bar=True, + accelerator='gpu', + # strategy=ddp, + devices=[0], + max_epochs=config.training.n_epochs, + gradient_clip_val=0.5, + num_sanity_val_steps=0, + precision=16, + logger=wandb_logger, + callbacks=callbacks_for_trainer + ) + trainer.fit(model) + + trainer.test(model, test_dalaloader, "best") + + +if __name__ == "__main__": + cfg = OmegaConf.load('./config/config_gue.yaml') + OmegaConf.set_struct(cfg, False) + classify_main(cfg, "MTcDNA", "4096") diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..3aa0df5 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,103 @@ +Pretraining: + training: + n_epochs: 30 + n_cores: 28 + device: "cuda" + patience: -1 + swa_lrs: -1 + batch_size: 64 + max_len: 1000 + n_warmup_steps: 40000 + n_cycles: 0.5 + weight_decay: 0.0003 + learning_rate: 0.0003 + save_every: 2500 + SwanDNA: + input_size: 5 + embedding_size: 10 + max_len: 1000 + group_size: 1 + hidden_size: 16 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: None + norm: None + CDIL: + dim: 5 + hdim1: 128 + hdim2: 128 + kernel_size: 3 + n_layers: 9 + dropout: 0.0 +Fine_tuning: + training: + pretrained: True + batch_size: 32 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.0003 + learning_rate: 0.0003 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + Deepsea: + output_size: 49 + Transformer: + name: "transformer" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 1 + heads: 1 + max_len: 1000 + output_size: 49 + Linformer: + name: "linformer" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + Mega: + name: "mega" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + S4: + name: "s4" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + Nystromformer: + name: "nystromer" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + SwanDNA: + input_size: 5 + output_size: 49 + embedding_size: 10 + max_len: 1000 + group_size: 1 + hidden_size: 16 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + coeff: 2 diff --git a/config/config_ct.yaml b/config/config_ct.yaml new file mode 100644 index 0000000..1d27a14 --- /dev/null +++ b/config/config_ct.yaml @@ -0,0 +1,165 @@ +Pretraining: + training: + n_epochs: 30 + n_cores: 28 + device: "cuda" + patience: -1 + swa_lrs: -1 + batch_size: 64 + max_len: 1000 + cls: 10 + n_warmup_steps: 40000 + n_cycles: 0.5 + weight_decay: 0.0003 + learning_rate: 0.0003 + save_every: 2500 + out_dim: 5 + global_crops_number: 10 + local_crops_number: 10 + warmup_teacher_temp: 0.04 + teacher_temp: 0.04 + warmup_teacher_patch_temp: 0.04 + teacher_patch_temp: 0.07 + warmup_teacher_temp_epochs: 30 + lambda1: 1.0 + lambda2: 1.0 + pred_start_epoch: 0 + momentum_teacher: 0.996 + Revolution: + dim: 5 + hdim1: 8 + hdim2: 16 + kernel_size: 3 + n_layers: 9 + dropout: 0.0 + SwanDNA: + input_size: 5 + embedding_size: 308 + max_len: 1010 + group_size: 28 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: None + norm: None + Flash: + input_size: 5 + embedding_size: 512 + group_size: 256 + max_len: 1010 + Nystromformer: + seed: 3431 + dim: 5 + hdim1: 16 + hdim2: 8 + kernel_size: 3 + n_layers: 4 + n_heads: 4 + max_len: 1000 + batch_size: 64 + mask_ratio: 0.3 + real_mask: 0.8 + lr: 0.0003 + n_epochs: 50 + train_num: 40000 + n_cores: 28 + device: "cuda" +Fine_tuning: + training: + pretrained: True + batch_size: 64 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.0003 + learning_rate: 0.0003 + save_every: 2500 + n_epochs: 30 + device: "cuda" + patience: -1 + swa_lrs: -1 + Deepsea: + output_size: 49 + Revolution: + dim_in: 5 + dim_out: 8 + clf_dim: 16 + ks: 3 + layers: 10 + max_len: 1000 + output_size: 49 + dropout: 0.0 + Transformer: + name: "transformer" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 1 + heads: 1 + max_len: 1000 + output_size: 49 + Linformer: + name: "linformer" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + Cosformer: + name: "cosformer" + dim_in: 5 + dim_out: 10 + clf_dim: 10 + layers: 2 + heads: 2 + max_len: 1000 + output_size: 49 + Flash: + name: "flash" + dim_in: 5 + dim_out: 64 + clf_dim: 64 + layers: 2 + heads: 8 + max_len: 1000 + output_size: 49 + Mega: + name: "mega" + dim_in: 5 + dim_out: 16 + clf_dim: 16 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + S4: + name: "s4" + dim_in: 5 + dim_out: 15 + clf_dim: 15 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + Nystromformer: + name: "nystromer" + dim_in: 5 + dim_out: 15 + clf_dim: 15 + layers: 2 + heads: 2 + max_len: 20000 + output_size: 49 + Chordmixer: + input_size: 5 + output_size: 49 + embedding_size: 154 + max_len: 1010 + track_size: 1 + hidden_size: 256 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + coeff: 1 diff --git a/config/config_fa.yaml b/config/config_fa.yaml new file mode 100644 index 0000000..b2307ce --- /dev/null +++ b/config/config_fa.yaml @@ -0,0 +1,205 @@ +Pretraining: + training: + n_epochs: 30 + n_cores: 28 + device: "cuda" + patience: -1 + swa_lrs: -1 + batch_size: 4 + max_len: 100000 + n_warmup_steps: 40000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0003 + save_every: 2000 + Flash: + input_size: 5 + embedding_size: 144 + max_len: 100000 + group_size: 8 + hidden_size: 256 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: None + norm: None +Human_Promoter: + training: + name: human_promoter + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.003 + save_every: 2500 + n_epochs: 300 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 251 + embedding_size: 128 + group_size: 64 +Human_Enhancers_Cohn: + training: + name: human_cohn + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0003 + save_every: 2500 + n_epochs: 800 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 500 + embedding_size: 512 + group_size: 256 +Demo_Human_Or_Worm: + training: + name: human_worm + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.3 + learning_rate: 0.1 + save_every: 2500 + n_epochs: 250 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 200 + embedding_size: 512 + group_size: 256 +Demo_Mouse_Enhancers: + training: + name: mouse_enhancer + pretrained: True + batch_size: 32 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.001 + learning_rate: 0.001 + save_every: 2500 + n_epochs: 200 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 4776 + embedding_size: 512 + group_size: 256 +Demo_Coding_Inter: + training: + name: coding_inter + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.08 + save_every: 2500 + n_epochs: 300 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 200 + embedding_size: 512 + group_size: 256 +Human_Enhancers_Ensembl: + training: + name: human_ensembl + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 80 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 573 + embedding_size: 512 + group_size: 256 +Human_Regulatory: + training: + name: human_regulatory + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 400 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 3 + max_len: 802 + embedding_size: 512 + group_size: 256 +Human_Ocr_Ensembl: + training: + name: human_ocr + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.003 + save_every: 2500 + n_epochs: 40 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 593 + embedding_size: 512 + group_size: 256 +Drop_Enhancer_Stark: + training: + name: drop_enhancer + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.003 + save_every: 2500 + n_epochs: 40 + device: "cuda" + patience: -1 + swa_lrs: -1 + Flash: + input_size: 5 + output_size: 2 + max_len: 4776 + embedding_size: 512 + group_size: 256 + diff --git a/config/config_fa_sweep.yaml b/config/config_fa_sweep.yaml new file mode 100644 index 0000000..f75c8ea --- /dev/null +++ b/config/config_fa_sweep.yaml @@ -0,0 +1,30 @@ +command: + - python3 + - hyper_search.py +method: bayes +metric: + goal: maximize + name: test_mcc +parameters: + batch_size: + values: + - 128 + - 64 + learning_rate: + values: + - 0.0005 + - 0.0003 + n_epochs: + values: + - 20 + - 25 + block_num: + values: + - 2 + - 3 + - 4 + group_size: + values: + - 40 + - 44 +program: hyper_search.py \ No newline at end of file diff --git a/config/config_gb.yaml b/config/config_gb.yaml new file mode 100644 index 0000000..3e1a445 --- /dev/null +++ b/config/config_gb.yaml @@ -0,0 +1,340 @@ +Pretraining: + training: + n_epochs: 30 + n_cores: 28 + device: "cuda" + patience: -1 + swa_lrs: -1 + batch_size: 4 + max_len: 100000 + n_warmup_steps: 40000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0003 + save_every: 2000 + SwanDNA: + input_size: 5 + embedding_size: 144 + max_len: 100000 + group_size: 8 + hidden_size: 256 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: None + norm: None +Human_Promoter: + training: + name: human_promoter + pretrained: True + batch_size: 1024 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.005 + save_every: 2500 + n_epochs: 200 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 251 + group_size: 50 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 4 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 7 + dropout: 0.0 + max_len: 251 +Human_Enhancers_Cohn: + training: + name: human_cohn + pretrained: True + batch_size: 1024 + n_warmup_steps: 0.1 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.001 + save_every: 2500 + n_epochs: 50 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0.3 + layer_dropout: 0.3 + prenorm: "None" + norm: "None" + coeff: 1.3 + block_num: 3 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 8 + dropout: 0.0 + max_len: 500 +Demo_Human_Or_Worm: + training: + name: human_worm + pretrained: True + batch_size: 1024 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 100 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 200 + group_size: 50 + hidden_size: 512 + mlp_dropout: 0.2 + layer_dropout: 0.2 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 4 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 7 + dropout: 0.0 + max_len: 200 +Demo_Mouse_Enhancers: + training: + name: mouse_enhancer + pretrained: True + batch_size: 128 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 150 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 4776 + group_size: 24 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 2 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 11 + dropout: 0.0 + max_len: 4776 +Demo_Coding_Inter: + training: + name: coding_inter + pretrained: True + batch_size: 1024 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 30 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 200 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.2 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 2 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 7 + dropout: 0.0 + max_len: 200 +Human_Enhancers_Ensembl: + training: + name: human_ensembl + pretrained: True + batch_size: 512 + n_warmup_steps: 0.2 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.001 + save_every: 2500 + n_epochs: 80 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 573 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 3 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 8 + dropout: 0.0 + max_len: 573 +Human_Regulatory: + training: + name: human_regulatory + pretrained: True + batch_size: 512 + n_warmup_steps: 0.2 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 50 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 3 + embedding_size: 308 + max_len: 802 + group_size: 36 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 4 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 3 + ks: 3 + layers: 9 + dropout: 0.0 + max_len: 802 +Human_Ocr_Ensembl: + training: + name: human_ocr + pretrained: True + batch_size: 256 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0003 + save_every: 2500 + n_epochs: 50 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 593 + group_size: 30 + hidden_size: 512 + mlp_dropout: 0.3 + layer_dropout: 0.3 + prenorm: "None" + norm: "None" + # coeff: 1.2 + block_num: 2 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 8 + dropout: 0.0 + max_len: 593 +Drop_Enhancer_Stark: + training: + name: drop_enhancer + pretrained: True + batch_size: 1024 + n_warmup_steps: 50000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.003 + save_every: 2500 + n_epochs: 40 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 154 + max_len: 593 + group_size: 20 + hidden_size: 256 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + coeff: 1.2 + block_num: 4 + CDIL: + dim_in: 5 + dim_out: 128 + output_size: 2 + ks: 3 + layers: 9 + dropout: 0.0 + max_len: 593 + diff --git a/config/config_gue.yaml b/config/config_gue.yaml new file mode 100644 index 0000000..ff4ebd9 --- /dev/null +++ b/config/config_gue.yaml @@ -0,0 +1,676 @@ +Pretraining: + training: + n_epochs: 30 + n_cores: 28 + device: "cuda" + patience: -1 + swa_lrs: -1 + batch_size: 4 + max_len: 100000 + n_warmup_steps: 40000 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0003 + save_every: 2000 + SwanDNA: + input_size: 5 + embedding_size: 144 + max_len: 100000 + group_size: 8 + hidden_size: 256 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: None + norm: None +H3: + training: + name: EMP/H3 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +H3K4me1: + training: + name: H3K4me1 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + block_num: 4 +H3K4me2: + training: + name: H3K4me2 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +H3K4me3: + training: + name: H3K4me3 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +H3K14ac: + training: + name: H3K14ac + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +H3K36me3: + training: + name: H3K36me3 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + block_num: 4 +H4: + training: + name: H4 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +H4ac: + training: + name: H4ac + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + block_num: 4 +H3K79me3: + training: + name: H3K79me3 + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + block_num: 4 +H3K9ac: + training: + name: H3K9ac + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 500 + group_size: 40 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + block_num: 4 +Splice: + training: + name: Splice + pretrained: True + batch_size: 64 + n_warmup_steps: 0.1 + n_cycles: 0.5 + weight_decay: 0.01 + learning_rate: 0.00005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 3 + embedding_size: 308 + max_len: 400 + group_size: 34 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +virus: + training: + name: virus + pretrained: True + batch_size: 256 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.01 + learning_rate: 0.001 + save_every: 2500 + n_epochs: 100 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 9 + embedding_size: 308 + max_len: 999 + group_size: 28 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +Prom_notata: + training: + name: Prom_notata + pretrained: True + batch_size: 512 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 10 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 70 + group_size: 44 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +Prom_tata: + training: + name: Prom_tata + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 10 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 70 + group_size: 44 + hidden_size: 512 + mlp_dropout: 0.2 + layer_dropout: 0.2 + prenorm: "None" + norm: "None" + block_num: 4 +Prom_all: + training: + name: Prom_all + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 10 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 70 + group_size: 44 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +Prom_300_notata: + training: + name: Prom_300_notata + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 10 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 300 + group_size: 36 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +Prom_300_tata: + training: + name: Prom_300_tata + pretrained: True + batch_size: 32 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 10 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 300 + group_size: 34 + hidden_size: 512 + mlp_dropout: 0 + layer_dropout: 0 + prenorm: "None" + norm: "None" + block_num: 4 +Prom_300_all: + training: + name: Prom_300_all + pretrained: True + batch_size: 256 + n_warmup_steps: 0.3 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 300 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +tf1: + training: + name: tf1 + pretrained: True + batch_size: 128 + n_warmup_steps: 0.1 + n_cycles: 0.5 + weight_decay: 0.01 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 101 + group_size: 44 + hidden_size: 512 + mlp_dropout: 0.2 + layer_dropout: 0.2 + prenorm: "None" + norm: "None" + block_num: 2 +tf2: + training: + name: tf2 + pretrained: True + batch_size: 256 + n_warmup_steps: 0.1 + n_cycles: 0.5 + weight_decay: 0.01 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 101 + group_size: 44 + hidden_size: 512 + mlp_dropout: 0.2 + layer_dropout: 0.2 + prenorm: "None" + norm: "None" + block_num: 4 +tf3: + training: + name: tf3 + pretrained: True + batch_size: 256 + n_warmup_steps: 0.1 + n_cycles: 0.5 + weight_decay: 0.01 + learning_rate: 0.0005 + save_every: 2500 + n_epochs: 20 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 101 + group_size: 38 + hidden_size: 512 + mlp_dropout: 0.2 + layer_dropout: 0.2 + prenorm: "None" + norm: "None" + block_num: 4 +MTcDNA_16k: + training: + name: MTcDNA_16k + pretrained: True + batch_size: 16 + n_warmup_steps: 0.1 + n_cycles: 0.5 + weight_decay: 0.01 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 50 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 16384 + group_size: 20 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +MTcDNA_1k: + training: + name: MTcDNA_1k + pretrained: True + batch_size: 16 + n_warmup_steps: 0.2 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 80 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 1024 + group_size: 28 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +MTcDNA_4k: + training: + name: MTcDNA_4k + pretrained: True + batch_size: 16 + n_warmup_steps: 0.2 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 50 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 4096 + group_size: 24 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 +MTcDNA_8k: + training: + name: MTcDNA_8k + pretrained: True + batch_size: 16 + n_warmup_steps: 0.2 + n_cycles: 0.5 + weight_decay: 0.1 + learning_rate: 0.0001 + save_every: 2500 + n_epochs: 50 + device: "cuda" + patience: -1 + swa_lrs: -1 + SwanDNA: + input_size: 5 + output_size: 2 + embedding_size: 308 + max_len: 8192 + group_size: 22 + hidden_size: 512 + mlp_dropout: 0.1 + layer_dropout: 0.1 + prenorm: "None" + norm: "None" + block_num: 4 + + diff --git a/contrastive_knn.py b/contrastive_knn.py new file mode 100644 index 0000000..ad01024 --- /dev/null +++ b/contrastive_knn.py @@ -0,0 +1,120 @@ +from sklearn.neighbors import KNeighborsClassifier +from models.pretraining_model import Model4TSNE +import torch +import torch.nn as nn +import numpy as np +from torch.utils.data import DataLoader, TensorDataset +torch.manual_seed(42) + + +def cls_augment(masked_gene, local_cls_number): + N, L, D = masked_gene.shape + cls_masked = torch.zeros(N, local_cls_number, D) + + masked_gene = torch.cat((masked_gene, cls_masked), 1) + return masked_gene + +class CustomDataset(torch.utils.data.Dataset): + + def __init__(self, X): + super().__init__() + + self.X = X + + def __getitem__(self, idx): + return self.X[idx] # In case you stored your data on a list called instances + + def __len__(self): + return len(self.X) + +task_list = ['demo_coding_vs_intergenomic_seqs', 'demo_human_or_worm', 'dummy_mouse_enhancers_ensembl', 'human_enhancers_cohn', 'human_enhancers_ensembl', 'human_ensembl_regulatory', 'human_nontata_promoters', 'human_ocr_ensembl'] + + +if __name__ == "__main__": + task = "human_nontata_promoters" + pretrained_path = f'./Pretrained_models/model_29_1000_2l_154_256_noaug.pt' + # Step 1: Load the pretrained model + pretrained_model = torch.load(pretrained_path, map_location='cpu')["Teacher"] + + model = Model4TSNE(input_size=5, max_len=1000, embedding_size=154, track_size=14, hidden_size=256, mlp_dropout=0, layer_dropout=0, prenorm='None', norm='None') + + # for k, v in pretrained_model.items(): + # print(k, v) + + # Step 2: Extract the encoder + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in pretrained_model.items(): + if k.startswith('encoder') or k.startswith('embedding'): + print("*********************************************") + new_state_dict[k] = v + + net_dict = model.state_dict() + pretrained_cm = {k: v for k, v in new_state_dict.items() if k in net_dict} + + net_dict.update(pretrained_cm) + model.load_state_dict(net_dict) + + model = model.cuda() + + # print("................", torch.cuda.memory_allocated()) + + # for k, v in model.state_dict().items(): + # print(k, v) + + # Step 4: Perform inference with the encoder + train_X = torch.load(f"./data/{task}_X_train.pt") + train_y = torch.load(f"./data/{task}_y_train.pt") + test_X = torch.load(f"./data/{task}_X_test.pt") + test_y = torch.load(f"./data/{task}_y_test.pt") + + print(train_X.shape) + + # train_X = cls_augment(train_X, 10) + # test_X = cls_augment(test_X, 10) + train_dataset = CustomDataset(train_X) + test_dataset = CustomDataset(test_X) + + # Define batch sizes for training and testing + batch_size = 64 # You can adjust this to your preferred batch size + + # Create DataLoader for training data + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False) + + # Create DataLoader for test data + test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + + train_output = [] + test_output = [] + + with torch.no_grad(): + model.eval() + for x in train_loader: + output = model(x.cuda()) + train_output.extend(output) + + for x in test_loader: + output = model(x.cuda()) + test_output.extend(output) + + train_X = torch.stack(train_output, dim=0) + test_X = torch.stack(test_output, dim=0) + + # print(train_X.shape) + # with torch.no_grad(): + # model.eval() + # train_X = model(train_X.cuda()) + # test_X = model(test_X.cuda()) + + train_X = train_X.mean(dim=1) + + test_X = test_X.mean(dim=1) + + + print(train_X.shape) + + knn = KNeighborsClassifier(n_neighbors=5) + knn.fit(train_X.cpu(), train_y.cpu()) + accuracy = knn.score(test_X.cpu(), test_y.cpu()) + print(f"KNN Classifier Accuracy: {accuracy}") \ No newline at end of file diff --git a/contrastive_pretraining.py b/contrastive_pretraining.py new file mode 100644 index 0000000..4432d1c --- /dev/null +++ b/contrastive_pretraining.py @@ -0,0 +1,505 @@ +from random import random as rand +from evoaug import evoaug, augment +import torch +import sys +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import pandas as pd +from torch.utils.data import Dataset, DataLoader +from omegaconf import OmegaConf +import numpy as np +from functools import lru_cache +import pytorch_lightning as pl +from torch import nn, optim +from augment import RandomDeletion, RandomInsertion, RandomTranslocation, RandomNoise, RandomRC +from transformers import get_cosine_schedule_with_warmup +from models.pretraining_model import Model4Pretrain, Model4PretrainFlash +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.model_summary import ModelSummary +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar +pl.seed_everything(42) + + +def pretrain_loss(loss, preds, labels, masks): + masks_new = masks.repeat(5, 1, 1)#.reshape(preds.shape) + # print("losssssss", masks_new.shape, preds.shape, labels.shape) + masks_new = torch.reshape(masks_new, preds.shape) + + print(labels[0][0:10]) + print(preds[0][0:10]) + + labels = labels[masks_new == 1] + preds = preds[masks_new == 1] + + return loss(preds.float(), labels.float()) + + +class iBOTLoss(nn.Module): + def __init__(self, out_dim, patch_out_dim, ngcrops, nlcrops, warmup_teacher_temp, + teacher_temp, warmup_teacher_temp2, teacher_temp2, + warmup_teacher_temp_epochs, nepochs, student_temp=0.1, + center_momentum=0.9, center_momentum2=0.9, + lambda1=1.0, lambda2=1.0, mim_start_epoch=0, length=1000): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.center_momentum2 = center_momentum2 + self.ngcrops = ngcrops + self.nlcrops = nlcrops + self.ncrops = ngcrops + nlcrops + self.register_buffer("center", torch.zeros(1, out_dim)) + self.register_buffer("center2", torch.zeros(1, 1, out_dim)) + self.lambda1 = lambda1 + self.lambda2 = lambda2 + self.none_cls_length = length + + # we apply a warm up for the teacher temperature because + # a too high temperature makes the training instable at the beginning + self.teacher_temp_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp, + teacher_temp, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + )) + self.teacher_temp2_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp2, + teacher_temp2, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp2 + )) if mim_start_epoch == 0 else np.concatenate(( + np.ones(mim_start_epoch) * warmup_teacher_temp2, + np.linspace(warmup_teacher_temp2, + teacher_temp2, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs - mim_start_epoch) * teacher_temp2 + )) + + def forward(self, student_output, teacher_output, student_local_cls, student_mask, epoch): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + # print(student_output[0].shape) + student_patch = student_output[0][:,0: self.none_cls_length,:] + student_cls = student_output[0][:,self.none_cls_length:,:] + teacher_patch = teacher_output[:, 0:self.none_cls_length, :] + teacher_cls = teacher_output[:, self.none_cls_length:,:] + + # print("*******", student_cls.shape, student_patch.shape) + + if student_local_cls is not None: + student_cls = torch.cat([student_cls, student_local_cls]) + + # [CLS] and patch for global patches + student_cls = student_cls / self.student_temp + # student_cls_c = student_cls.chunk(self.ncrops) + student_patch = student_patch / self.student_temp + # student_patch_c = student_patch.chunk(self.ngcrops) + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + temp2 = self.teacher_temp2_schedule[epoch] + # print(teacher_cls.shape, self.center.shape) + teacher_cls_c = F.softmax((teacher_cls - self.center) / temp, dim=-1) + teacher_cls_c = teacher_cls_c.detach() + teacher_patch_c = F.softmax((teacher_patch - self.center2) / temp2, dim=-1) + teacher_patch_c = teacher_patch_c.detach() + + print(teacher_cls_c.shape, student_cls.shape) + + total_loss1 = 0 + total_loss2 = 0 + + total_loss1 = torch.sum(-teacher_cls_c * F.log_softmax(student_cls, dim=-1), dim=-1).mean() + + loss_func = nn.BCEWithLogitsLoss(reduction='mean') + + total_loss2 = pretrain_loss(loss_func, student_patch, student_output[1], student_output[2]) + + total_loss1 = total_loss1 * self.lambda1 + total_loss2 = total_loss2 * self.lambda2 + print("loss1", total_loss1, "loss2", total_loss2) + total_loss = dict(cls=total_loss1, patch=total_loss2, loss=total_loss1 + total_loss2) + self.update_center(teacher_cls, teacher_patch) + + return total_loss + + @torch.no_grad() + def update_center(self, teacher_cls): + """ + Update center used for teacher output. + """ + cls_center = torch.sum(teacher_cls, dim=0, keepdim=True) + # dist.all_reduce(cls_center) + cls_center = cls_center / len(teacher_cls) # * dist.get_world_size()) + self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum) + + +class DatasetCreator(Dataset): + """ + Class to construct a dataset for training/inference + """ + def __init__(self, original_gene, augmented_genes, masked_genes, masks): + self.genes = original_gene + self.augmented_genes = augmented_genes + self.masked_genes = masked_genes + self.masks = masks + + def __getitem__(self, index): + return (self.genes[index], self.augmented_genes[index], self.masked_genes[index], self.masks[index]) + + def __len__(self): + return len(self.genes) + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +class LightningWrapper(pl.LightningModule): + def __init__(self, model, cfg, snapshot_path, train_set, val_set, loss): + super().__init__() + self.save_hyperparameters(cfg) + self.model_config = self.hparams.training + self.arch_config = self.hparams.SwanDNA + self.batch_size = self.hparams.training.batch_size + self.length = self.model_config.max_len + self.student = model(**self.arch_config) + self.teacher = model(**self.arch_config) + self.teacher.load_state_dict(self.student.state_dict(), strict=False) + self.save_every = self.hparams.training.save_every + self.snapshot_path = snapshot_path + self.train_set = train_set + self.val_set = val_set + self.loss = iBOTLoss( + self.model_config.out_dim, + self.model_config.out_dim, + self.model_config.global_crops_number, + self.model_config.local_crops_number, + self.model_config.warmup_teacher_temp, + self.model_config.teacher_temp, + self.model_config.warmup_teacher_patch_temp, + self.model_config.teacher_patch_temp, + self.model_config.warmup_teacher_temp_epochs, + self.model_config.n_epochs, + lambda1=self.model_config.lambda1, + lambda2=self.model_config.lambda2, + mim_start_epoch=self.model_config.pred_start_epoch, + length=self.model_config.max_len + ) + self.momentum_schedule = cosine_scheduler(0.996, 1, self.model_config.n_epochs, len(self.train_dataloader())) + + for p in self.teacher.parameters(): + p.requires_grad = False + + print(self.student, self.teacher) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + return self.model(x) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def training_step(self, batch, batch_idx): + # common params + names_q, params_q, names_k, params_k = [], [], [], [] + for name_q, param_q in self.student.state_dict().items(): + names_q.append(name_q) + params_q.append(param_q) + for name_k, param_k in self.teacher.state_dict().items(): + names_k.append(name_k) + params_k.append(param_k) + names_common = list(set(names_q) & set(names_k)) + params_q = [param_q for name_q, param_q in zip(names_q, params_q) if name_q in names_common] + params_k = [param_k for name_k, param_k in zip(names_k, params_k) if name_k in names_common] + + original_gene, augmented_gene, masked_gene, masks = batch + # print("origin", original_gene.shape) + # get global views + teacher_output = self.teacher(augmented_gene) + student_output = [self.student(masked_gene), original_gene, masks] + + # get local views + # self.student.module.backbone.masked_im_modeling = False + # student_local_cls = self.student(masked_gene[self.model_config.global_crops_number:])[0] if len(masked_gene) > self.model_config.global_crops_number else None + student_local_cls = None + # self.student.module.backbone.masked_im_modeling = self.model_config.use_masked_im_modeling + + all_loss = self.loss(student_output, teacher_output, student_local_cls, masks, self.current_epoch) + loss = all_loss.pop('loss') + cls_loss = all_loss.pop('cls') + mlm_loss = all_loss.pop('patch') + + with torch.no_grad(): + # m = self.optimizers().param_groups[0]['lr']#/self.model_config.learning_rate # momentum parameter + m = self.momentum_schedule[self.global_step] + for param_q, param_k in zip(params_q, params_k): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + + self.log('train_loss', loss, sync_dist=True) + self.log('cls_loss', cls_loss, sync_dist=True) + self.log('mlm_loss', mlm_loss, sync_dist=True) + + return {"loss":loss} + + def training_epoch_end(self, outputs): + if self.current_epoch ==9 or self.current_epoch == self.model_config.n_epochs-1: + self._save_snapshot() + + def validation_step(self, batch, batch_idx): + original_gene, augmented_gene, masked_gene, masks = batch + # get global views + teacher_output = self.teacher(augmented_gene) + student_output = [self.student(masked_gene), original_gene, masks] + + # get local views + # student_local_cls = self.student(masked_gene[self.model_config.global_crops_number:])[0] if len(masked_gene) > self.model_config.global_crops_number else None + + student_local_cls = None + + all_loss = self.loss(student_output, teacher_output, student_local_cls, masks, self.current_epoch) + loss = all_loss.pop('loss') + cls_loss = all_loss.pop('cls') + mlm_loss = all_loss.pop('patch') + + return {"loss":loss, "cls_loss":cls_loss, "mlm_loss":mlm_loss} + + def validation_epoch_end(self, outputs): + val_loss = torch.stack([x["loss"] for x in outputs]).mean() + val_cls_loss = torch.stack([x["cls_loss"] for x in outputs]).mean() + val_mlm_loss = torch.stack([x["mlm_loss"] for x in outputs]).mean() + self.log('val_loss', val_loss, sync_dist=True) + self.log('val_cls_loss', val_cls_loss, sync_dist=True) + self.log('val_mlm_loss', val_mlm_loss, sync_dist=True) + + def _save_snapshot(self): + snapshot = { + "Teacher": self.teacher.state_dict(), + "Student": self.student.state_dict(), + "EPOCHS_RUN": self.current_epoch , + } + torch.save(snapshot, f"{self.snapshot_path}/model_{self.current_epoch}_{self.length}_4l_308_512_noiseandTL.pt") + print(f"Epoch {self.current_epoch } | Training snapshot saved at {self.snapshot_path}") + + def _load_snapshot(self, snapshot_path): + loc = f"cuda:0" + snapshot = torch.load(snapshot_path, map_location=loc) + self.model.load_state_dict(snapshot["MODEL_STATE"]) + self.epochs_run = snapshot["EPOCHS_RUN"] + print(f"Resuming training from snapshot at Epoch {self.epochs_run}") + + def train_dataloader(self): + return DataLoader( + dataset=self.train_set, + num_workers=1, + pin_memory=True, + shuffle=True, + drop_last=True, + batch_size=self.batch_size + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=True, + batch_size=self.batch_size + ) + + @lru_cache + def total_steps(self): + l = len(self.trainer._data_connector._train_dataloader_source.dataloader()) + print('Num devices', self.trainer.num_devices) + max_epochs = self.trainer.max_epochs + accum_batches = self.trainer.accumulate_grad_batches + manual_total_steps = (l // accum_batches * max_epochs)/self.trainer.num_devices + print('MANUAL Total steps', manual_total_steps) + return manual_total_steps + + def configure_optimizers(self): + optimizer = optim.AdamW( + self.parameters(), + lr=self.hparams.training.learning_rate, + weight_decay=self.hparams.training.weight_decay + ) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=self.total_steps()*0.3, + num_training_steps=self.total_steps(), + num_cycles=self.hparams.training.n_cycles + ) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + +# def cls_augment(gene, masked_gene, local_cls_number): +# N, L, D = gene.shape +# # random_masks = torch.zeros(local_cls_number, L) +# # cls_masked = np.eye(D)[np.random.randint(0, D, (N, local_cls_number, 1))].squeeze() +# cls_masked = torch.zeros(N, local_cls_number, D) +# # cls = np.eye(D)[np.random.randint(0, D, (N, local_cls_number, 1))].squeeze() +# cls = torch.zeros(N, local_cls_number, D) + +# gene = torch.cat((cls, gene), 1) +# masked_gene = torch.cat((cls_masked, masked_gene), 1) +# return gene, masked_gene + +def cls_augment(masked_gene, local_cls_number): + N, L, D = masked_gene.shape + cls_masked = torch.zeros(N, local_cls_number, D) + + masked_gene = torch.cat((masked_gene, cls_masked), 1) + return masked_gene + + +def pretrain_main(cgf): + """ + # 1. Load data for pretraining + """ + genes_train = torch.load(f"./data/gene_train_{cfg.Pretraining.training.max_len}_100k.pt") + masked_genes_train = torch.load(f"./data/masked_train_{cfg.Pretraining.training.max_len}_100k.pt") + masks_train = torch.load(f"./data/mask_train_{cfg.Pretraining.training.max_len}_100k.pt") + + genes_val = torch.load(f"./data/gene_val_{cfg.Pretraining.training.max_len}_100k.pt") + masked_genes_val = torch.load(f"./data/masked_val_{cfg.Pretraining.training.max_len}_100k.pt") + masks_val = torch.load(f"./data/mask_val_{cfg.Pretraining.training.max_len}_100k.pt") + + print(genes_train.shape, genes_val.shape) + print(genes_train[0][0:10], masked_genes_train[0][0:10], masks_train[0][0:10]) + original_train = genes_train + original_val = genes_val + + + # 2. Augment the Data + # 2.1 Add CLS tokens + + + print(genes_train.shape, masked_genes_train.shape, masks_train.shape) + + augment_list_1 = [ + RandomDeletion(delete_min=0, delete_max=20), + RandomInsertion(insert_min=0, insert_max=20), + RandomTranslocation(shift_min=0, shift_max=20) + ] + for augment in augment_list_1: + genes_train_aug = torch.permute(augment(torch.permute(genes_train, (0, 2, 1))), (0, 2, 1)) + + for augment in augment_list_1: + genes_val_aug = torch.permute(augment(torch.permute(genes_val, (0, 2, 1))), (0, 2, 1)) + + # genes_train_aug = genes_train + # genes_val_aug = genes_val + + + augment_list_2 = [ + # RandomDeletion(delete_min=0, delete_max=20), + # RandomInsertion(insert_min=0, insert_max=20), + RandomNoise(0, 0.2), + # RandomTranslocation(shift_min=0, shift_max=20) + RandomRC(0.5) + ] + + for augment in augment_list_2: + masked_genes_train = torch.permute(augment(torch.permute(masked_genes_train, (0, 2, 1))), (0, 2, 1)) + + for augment in augment_list_2: + masked_genes_val = torch.permute(augment(torch.permute(masked_genes_val, (0, 2, 1))), (0, 2, 1)) + + print("masked after augmentation", masked_genes_train.shape) + + masked_genes_train = cls_augment(masked_genes_train, 10) + masked_genes_val = cls_augment(masked_genes_val, 10) + + # print("before" ,genes_train_aug.shape) + genes_train_aug = cls_augment(genes_train_aug, 10) + genes_val_aug = cls_augment(genes_val_aug, 10) + + print(genes_train_aug.shape, masked_genes_train.shape) + + # import torch.nn.functional as F + + # dataset1_flat = genes_train_aug.view(-1, 5) + # dataset2_flat = masked_genes_train.view(-1, 5) + + # print(dataset1_flat[:5]) + + + # # Apply a softmax to make sure each row is a valid probability distribution + # dataset1_probs = F.softmax(dataset1_flat, dim=1) + # dataset2_probs = F.softmax(dataset2_flat, dim=1) + + # print(dataset2_probs[:5]) + + # # Calculate KL divergence + # kl_divergence = F.kl_div(torch.log(dataset1_probs), dataset2_probs, reduction='batchmean') + + # print("KL Divergence:", kl_divergence.item()) + + # sys.exit(-1) + + print("after cls augmentation", masked_genes_train.shape, genes_train_aug.shape) + + # genes_train_aug = cls_augment(original_train, 10) + # genes_val_aug = cls_augment(original_val, 10) + + # print("after", genes_train_aug.shape, original_train.shape) + + train_set = DatasetCreator(original_train, genes_train_aug, masked_genes_train, masks_train) + val_set = DatasetCreator(original_val, genes_val_aug, masked_genes_val, masks_val) + + """ + # 3. Prepare model + """ + + ddp = DDPStrategy(process_group_backend="nccl", find_unused_parameters=True) + # profiler = SimpleProfiler() + snapshot_path = "./Pretrained_models/" + + # loss = nn.CrossEntropyLoss(reduce="sum") + loss = torch.nn.BCEWithLogitsLoss(reduction='mean') + MetaArch = Model4Pretrain + model = LightningWrapper(MetaArch, cfg.Pretraining, snapshot_path, train_set, val_set, loss) + print(model) + summary = ModelSummary(model, max_depth=-1) + wandb_logger = WandbLogger(dir="./wandb/", project="Contrastive_Pretrain", entity='tonyu', name=f'Pretraining_{cfg.Pretraining.training.max_len}_4l_{cfg.Pretraining.SwanDNA.embedding_size}_{cfg.Pretraining.SwanDNA.hidden_size}') + checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min") + + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks_for_trainer = [TQDMProgressBar(refresh_rate=10), lr_monitor, checkpoint_callback] + + """ + # 4. init trainer + """ + + print(summary) + trainer = pl.Trainer( + check_val_every_n_epoch=1, + enable_progress_bar=True, + accelerator='gpu', + strategy=ddp, + devices=[0], + max_epochs=cfg.Pretraining.training.n_epochs, + gradient_clip_val=0.5, + num_sanity_val_steps=0, + precision=16, + logger=wandb_logger, + callbacks=callbacks_for_trainer + ) + trainer.fit(model) + + +if __name__ == "__main__": + cfg = OmegaConf.load('./config/config_ct.yaml') #for ve pretraining, chenge it to config.yaml + OmegaConf.set_struct(cfg, False) + pretrain_main(cfg) diff --git a/contrastive_visualization.py b/contrastive_visualization.py new file mode 100644 index 0000000..c661b7a --- /dev/null +++ b/contrastive_visualization.py @@ -0,0 +1,293 @@ +import torch +import numpy as np +from torch import nn, optim +from omegaconf import OmegaConf +from functools import lru_cache +from datetime import datetime +from torch.utils.data import DataLoader +from sklearn.metrics import roc_auc_score +from models.pretraining_model import Model4TSNE +from data_utils import vcf_Dataset +import matplotlib.pyplot as plt +import pytorch_lightning as pl +from sklearn.manifold import TSNE +# from cuml.manifold import TSNE +from augment import RandomDeletion, RandomInsertion, RandomTranslocation +from transformers import get_cosine_schedule_with_warmup +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.model_summary import ModelSummary +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar +pl.seed_everything(42) + + +class LightningWrapper(pl.LightningModule): + def __init__(self, model, cfg, snapshot_path, train_set, val_set, pretrained, loss, file_name): + super().__init__() + self.save_hyperparameters(cfg) + self.model_config = self.hparams.SwanDNA + self.batch_size = self.hparams.training.batch_size + self.length = self.hparams.SwanDNA.max_len + self.model = model(**self.model_config)#.apply(self._init_weights) + self.save_every = self.hparams.training.save_every + self.snapshot_path = snapshot_path + self.train_set = train_set + self.val_set = val_set + self.loss = loss + self.file_name = file_name + + print(self.model) + + if pretrained: + pretrained_path = f'./Pretrained_models/{self.file_name}' + pretrained = torch.load(pretrained_path, map_location='cpu') + pretrained = pretrained["MODEL_STATE"] + + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in pretrained.items(): + if k.startswith('encoder') or k.startswith('embedding'): + new_state_dict[k] = v + + net_dict = self.model.state_dict() + pretrained_cm = {k: v for k, v in new_state_dict.items() if k in net_dict} + + net_dict.update(pretrained_cm) + self.model.load_state_dict(net_dict) + for k, v in self.model.state_dict().items(): + print(k, v) + print("*************pretrained model loaded***************") + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + return self.model(x) + + def _init_weights(self, m): + if isinstance(m, nn.Reear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def training_step(self, batch, batch_idx): + ref, alt, tissue, label = batch + output = self.model(ref, alt, tissue).squeeze() + train_loss = self.loss(output, label) + return {"loss":train_loss, "preds":output, "labels":label, "tissue":tissue} + + def validation_step(self, batch, batch_idx): + ref, alt, tissue, label = batch + output = self.model(ref, alt, tissue).squeeze() + val_loss = self.loss(output, label) + return {"loss":val_loss, "preds":output, "labels":label, "tissue":tissue} + + + def training_epoch_end(self, outputs): + train_preds = [[] for _ in range(self.model_config.output_size)] + train_labels = [[] for _ in range(self.model_config.output_size)] + train_loss = torch.stack([x["loss"] for x in outputs]).mean() + tissue = torch.stack([x["tissue"] for x in outputs]).reshape((-1,)) + label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + + for t, p, l in zip(tissue, output, label): + t = t.to(torch.int8) + train_preds[t.item()].append(p.item()) + train_labels[t.item()].append(l.item()) + train_rocs = [] + for i in range(self.model_config.output_size): + rocauc = roc_auc_score(train_labels[i], train_preds[i]) + train_rocs.append(rocauc) + train_roc = np.average(train_rocs) + self.log('train_roc', train_roc, sync_dist=True) + self.log('train_loss', train_loss, sync_dist=True) + + def validation_epoch_end(self, outputs): + val_preds = [[] for _ in range(self.model_config.output_size)] + val_labels = [[] for _ in range(self.model_config.output_size)] + val_loss = torch.stack([x["loss"] for x in outputs]).mean() + tissue = torch.stack([x["tissue"] for x in outputs]).reshape((-1,)) + label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + + for t, p, l in zip(tissue, output, label): + t = t.to(torch.int8) + val_preds[t.item()].append(p.item()) + val_labels[t.item()].append(l.item()) + + val_rocs = [] + for i in range(self.model_config.output_size): + if len(val_labels[i]) != 0 and sum(val_labels[i]) != len(val_labels[i]) and sum(val_labels[i]) != 0: + rocauc = roc_auc_score(val_labels[i], val_preds[i]) + val_rocs.append(rocauc) + val_roc = np.average(val_rocs) + self.log("val_auroc", val_roc, sync_dist=True) + self.log('val_loss', val_loss, sync_dist=True) + self.val_preds = [[] for _ in range(self.model_config.output_size)] + self.val_labels = [[] for _ in range(self.model_config.output_size)] + + def train_dataloader(self): + return DataLoader( + dataset=self.train_set, + num_workers=1, + pin_memory=True, + shuffle=True, + drop_last=True, + batch_size=self.batch_size + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + batch_size=self.batch_size + ) + + @lru_cache + def total_steps(self): + l = len(self.trainer._data_connector._train_dataloader_source.dataloader()) + print('Num devices', self.trainer.num_devices) + max_epochs = self.trainer.max_epochs + accum_batches = self.trainer.accumulate_grad_batches + manual_total_steps = (l // accum_batches * max_epochs)/self.trainer.num_devices + print('MANUAL Total steps', manual_total_steps) + return manual_total_steps + + def configure_optimizers(self): + optimizer = optim.AdamW( + self.parameters(), + lr=self.hparams.training.learning_rate, + weight_decay=self.hparams.training.weight_decay + ) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=int(self.total_steps()*0.3), + num_training_steps=self.total_steps(), + num_cycles=self.hparams.training.n_cycles + ) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + + +def classify_main(cfg): + pretrained = cfg.Fine_tuning.training.pretrained + length = cfg.Fine_tuning.SwanDNA.max_len + + loss = nn.BCEWithLogitsLoss() + + train_ref = torch.load(f"./data/ref_{length}_train.pt") + train_alt = torch.load(f"./data/alt_{length}_train.pt") + train_tissue = torch.load(f"./data/tissue_{length}_train.pt") + train_label = torch.load(f"./data/label_{length}_train.pt") + + train_set = vcf_Dataset(train_ref, train_alt, train_tissue, train_label) + val_set = vcf_Dataset(torch.load(f"./data/ref_{length}_chr11_test.pt"), torch.load(f"./data/alt_{length}_chr11_test.pt"), torch.load(f"./data/tissue_{length}_chr11_test.pt"), torch.load(f"./data/label_{length}_chr11_test.pt")) + + ddp = DDPStrategy(process_group_backend="nccl", find_unused_parameters=True) + snapshot_path = "test.pt" + file_name = "SwanDNA_VE_10_16.pt" + + model = LightningWrapper(Classifier, cfg.Fine_tuning, snapshot_path, train_set, val_set, pretrained, loss, file_name) + summary = ModelSummary(model, max_depth=-1) + + + # ------------ + # init trainer + # ------------ + + wandb_logger = WandbLogger(dir="./wandb/", project="VE_classification", entity='', name=f'{file_name}_{length}_{pretrained}') + checkpoint_callback = ModelCheckpoint(monitor="val_auroc", mode="max") + + print(len(train_set), len(val_set)) + + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks_for_trainer = [TQDMProgressBar(refresh_rate=10), lr_monitor, checkpoint_callback] + if cfg.Fine_tuning.training.patience != -1: + early_stopping = EarlyStopping(monitor="val_auroc", mode="max", min_delta=0., patience=cfg.Fine_tuning.training.patience) + callbacks_for_trainer.append(early_stopping) + if cfg.Fine_tuning.training.swa_lrs != -1: + swa = StochasticWeightAveraging(swa_lrs=1e-2) + callbacks_for_trainer.append(swa) + + print(summary) + trainer = pl.Trainer( + check_val_every_n_epoch=1, + enable_progress_bar=True, + accelerator='gpu', + strategy=ddp, + devices=[0], + max_epochs=cfg.Fine_tuning.training.n_epochs, + gradient_clip_val=0.5, + num_sanity_val_steps=0, + # profiler=profiler, + precision=16, + logger=wandb_logger + ) + trainer.fit(model) + +def cls_augment(masked_gene, local_cls_number): + N, L, D = masked_gene.shape + cls_masked = torch.zeros(N, local_cls_number, D) + + masked_gene = torch.cat((masked_gene, cls_masked), 1) + return masked_gene + + +if __name__ == "__main__": + pretrained_path = f'./Pretrained_models/model_10_1000_2l_154_256.pt' + # Step 1: Load the pretrained model + pretrained_model = torch.load(pretrained_path, map_location='cpu')["Teacher"] + + model = Model4TSNE(input_size=5, max_len=1000, embedding_size=154, track_size=14, hidden_size=256, mlp_dropout=0, layer_dropout=0, prenorm='None', norm='None') + + for k, v in pretrained_model.items(): + print(k, v) + + # Step 2: Extract the encoder + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in pretrained_model.items(): + if k.startswith('encoder') or k.startswith('embedding'): + print("**************************************************************************************") + new_state_dict[k] = v + + net_dict = model.state_dict() + pretrained_cm = {k: v for k, v in new_state_dict.items() if k in net_dict} + + net_dict.update(pretrained_cm) + model.load_state_dict(net_dict) + + for k, v in model.state_dict().items(): + print(k, v) + + # Step 4: Perform inference with the encoder + genes_train = torch.load(f"./data/gene_valid_1000_10k.pt") + augment_list = [ + RandomDeletion(delete_min=0, delete_max=20), + RandomInsertion(insert_min=0, insert_max=20), + RandomTranslocation(shift_min=0, shift_max=20) + ] + for augment in augment_list: + genes_train_aug = torch.permute(augment(torch.permute(genes_train, (0, 2, 1))), (0, 2, 1)) + + genes_train_aug = cls_augment(genes_train_aug, 10) + + with torch.no_grad(): + x = model(genes_train_aug) + print(x.shape) + + x = x[:, 1000:, :] + x = x.view(x.shape[0], -1) + + # x = x[:, :1000, :] + # x = x.mean(dim=1).view(x.shape[0], -1) + print(x.shape) + + tsne = TSNE(n_components=2, random_state=42) + X_tsne = tsne.fit_transform(x) + fig = plt.scatter(x=X_tsne[:, 0], y=X_tsne[:, 1], marker='.', s=2) + plt.savefig("teacher_tsne_cls_10k_154_2l_E20.png") + plt.show() + diff --git a/data/generate_pretrain_human.py b/data/generate_pretrain_human.py new file mode 100644 index 0000000..4e7e73c --- /dev/null +++ b/data/generate_pretrain_human.py @@ -0,0 +1,120 @@ +import torch +import numpy as np +from Bio import SeqIO +from sklearn.preprocessing import LabelBinarizer, LabelEncoder +from joblib import Parallel, delayed +import pysam +import multiprocessing +from functools import partial +from numpy.random import default_rng +rng = default_rng() + + +class UniformMasking(): + """ Pre-processing steps for pretraining revolution """ + def __init__(self, mask_prob): + super().__init__() + self.mask_ratio = mask_prob + + def __call__(self, instance): + uniform_vec = np.random.uniform(size=len(instance)) + masked_vec = (uniform_vec <= self.mask_ratio).astype(int) + + uniform_vec2 = np.random.uniform(size=len(instance)) + random_vec = np.zeros(len(instance)) + same_vec = np.zeros(len(instance)) + random_vec[(masked_vec == 1) & (uniform_vec2 <= 0.1)] = 1 + same_vec[(masked_vec == 1) & (uniform_vec2 >= 0.9)] = 1 + real_vec = abs(masked_vec - random_vec - same_vec) + random_vec = np.array(random_vec).astype(bool) + real_vec = np.array(real_vec).astype(bool) + + instance[real_vec, :] = [0, 0, 0, 0, 0] + instance[random_vec, :] = np.eye(5)[np.random.choice(5, sum(random_vec))] + + return instance, masked_vec + +def generate_pairs(num): + import pandas as pd + import random + + # Read the CSV file containing chromosome lengths + chromosome_lengths_df = pd.read_csv("./data/chromosomes.csv") + + chroms = [] + poses = [] + + for _ in range(num): + # Randomly select a chromosome + random_chromosome = random.choice(chromosome_lengths_df["name"]) + + # Get the length of the selected chromosome + chromosome_length = chromosome_lengths_df.loc[ + chromosome_lengths_df["name"] == random_chromosome, "length" + ].values[0] + + # Randomly select a position within the length range of the chromosome + random_position = random.randint(1, chromosome_length) + chroms.append(random_chromosome) + poses.append(random_position) + + print(chroms[:10]) + print(poses[:10]) + return chroms, poses + + +def fetch_and_transform(position_and_chrom, length, lb, masking): + """ + fetch one cunk of DNA sequences at given positions and chromosomes. + """ + position, chrom = position_and_chrom + genome = pysam.FastaFile('./data/hg38.fa') + sequence = genome.fetch(chrom, position, position + length).lower() + if len(sequence) == 0: + print(f"Empty sequence at position {position}") + return None + # "transform sequence to one-hot encoding" + gene_to_number = lb.transform(list(sequence)).astype("int8") + # "Masking" + masked_gene, mask = masking(np.array(gene_to_number)) + return masked_gene.astype("int8"), mask.astype("int8"), gene_to_number + +def mask_chr_sequences(num, length, chroms, positions, split): + """ + Parallelize the masking over 200k sequence. + """ + masked_genes_train, genes_train, masks_train = [], [], [] + chunksize = int(num / multiprocessing.cpu_count()) # Or any other suitable value + print(multiprocessing.cpu_count(), chunksize) + with multiprocessing.Pool() as pool: + results = pool.map(partial(fetch_and_transform, length=length, lb=lb, masking=masking), zip(positions, chroms), chunksize=chunksize) + + for masked_gene, mask, gene in results: + if len(masked_gene) == length: + masked_genes_train.append(masked_gene) + masks_train.append(mask) + genes_train.append(gene) + + X_train = torch.from_numpy(np.stack(masked_genes_train)) + M_train = torch.from_numpy(np.stack(masks_train)) + O_train = torch.from_numpy(np.stack(genes_train)) + + print(X_train.shape, M_train.shape, O_train.shape) + + torch.save(X_train, f"./data/masked_{split}_{length}_10k.pt") + torch.save(M_train, f"./data/mask_{split}_{length}_10k.pt") + torch.save(O_train, f"./data/gene_{split}_{length}_10k.pt") + + + +hg38_dict = SeqIO.to_dict(SeqIO.parse("./data/hg38.fa", "fasta")) +lb = LabelBinarizer() +lb.fit(['a', 't', 'c', 'g', 'n']) +chromosomes = [f"chr{i}" for i in range(1, 23)] +chromosomes.append("chrX") +chromosomes.append("chrY") +print(chromosomes) + +masking = UniformMasking(0.3) +chroms, positions = generate_pairs(10000) +mask_chr_sequences(10000, 1000, chroms, positions, "valid") diff --git a/data/genome_process.py b/data/genome_process.py new file mode 100644 index 0000000..ca2db94 --- /dev/null +++ b/data/genome_process.py @@ -0,0 +1,99 @@ +import torch +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +import tensorflow as tf +from Bio import SeqIO + +bases = {"a": 0, "g":1, "c":2, "t":3, "n": 4, "":5} + +def merge(): + input_file_desert = "./data/MTcDNA/Chelonoidis_turtle" + fasta_sequences = SeqIO.parse(open(input_file_desert), 'fasta') + desert_sequences = [] + for fasta in fasta_sequences: + name, sequence = fasta.id, str(fasta.seq) + desert_sequences.append(sequence) + + desert_df = pd.DataFrame(desert_sequences, columns=["sequence"]) + desert_df["length"] = desert_df['sequence'].apply(lambda x: len(x)) + desert_df["sequence"] = desert_df['sequence'].apply(lambda x: x.lower()) + desert_df = desert_df[desert_df["length"] > 5000] + desert_df["class"] = 0 + + input_file_island = "./data/MTcDNA/Gopherus_turtle" + fasta_sequences_ff = SeqIO.parse(open(input_file_island), 'fasta') + island_sequences = [] + for fasta in fasta_sequences_ff: + name, sequence = fasta.id, str(fasta.seq) + island_sequences.append(sequence) + + island_df = pd.DataFrame(island_sequences, columns=["sequence"]) + island_df["length"] = island_df['sequence'].apply(lambda x: len(x)) + island_df["sequence"] = island_df['sequence'].apply(lambda x: x.lower()) + island_df = island_df[island_df["length"] > 5000] + island_df["class"] = 0 + + input_file_Musculus = "./data/MTcDNA/Mus_musculus" + fasta_sequences_ff = SeqIO.parse(open(input_file_Musculus), 'fasta') + Musculus_sequences = [] + for fasta in fasta_sequences_ff: + name, sequence = fasta.id, str(fasta.seq) + Musculus_sequences.append(sequence) + + Musculus_df = pd.DataFrame(Musculus_sequences, columns=["sequence"]) + Musculus_df["length"] = Musculus_df['sequence'].apply(lambda x: len(x)) + Musculus_df["sequence"] = Musculus_df['sequence'].apply(lambda x: x.lower()) + Musculus_df = Musculus_df[Musculus_df["length"] > 5000] + Musculus_df["class"] = 1 + + input_file_Spretus = "./data/MTcDNA/Mus_spretus" + fasta_sequences_ff = SeqIO.parse(open(input_file_Spretus), 'fasta') + Spretus_sequences = [] + for fasta in fasta_sequences_ff: + name, sequence = fasta.id, str(fasta.seq) + Spretus_sequences.append(sequence) + + Spretus_df = pd.DataFrame(Spretus_sequences, columns=["sequence"]) + Spretus_df["length"] = Spretus_df['sequence'].apply(lambda x: len(x)) + Spretus_df["sequence"] = Spretus_df['sequence'].apply(lambda x: x.lower()) + Spretus_df = Spretus_df[Spretus_df["length"] > 5000] + Spretus_df["class"] = 1 + + df = pd.concat((desert_df, island_df, Musculus_df, Spretus_df), axis=0) + print(df.head(5)) + print(df.shape) + return df + + +def padding(sequence, desired_length): + + if len(sequence) > desired_length: + return sequence[:desired_length] + else: + padding_character = 'N' + + # Pad the sequence + padding_length = desired_length - len(sequence) + padded_sequence = sequence + padding_character * padding_length + return padded_sequence + + +def get_dna_csv(max_len): + data_df = merge() + X = [padding(s.upper(), max_len) for s in data_df.sequence.values] + print(X[0]) + y = data_df["class"].values + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1) + X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=1) + + train_df = pd.DataFrame({"sequence": X_train, "label": y_train}) + test_df = pd.DataFrame({"sequence": X_test, "label": y_test}) + val_df = pd.DataFrame({"sequence": X_val, "label": y_val}) + + train_df.to_csv(f"./data/MTcDNA/{max_len}/train.csv", index=None) + val_df.to_csv(f"./data/MTcDNA/{max_len}/dev.csv", index=None) + test_df.to_csv(f"./data/MTcDNA/{max_len}/test.csv", index=None) + +get_dna_csv(8192) \ No newline at end of file diff --git a/data/genomic_benchmark.py b/data/genomic_benchmark.py new file mode 100644 index 0000000..ac71233 --- /dev/null +++ b/data/genomic_benchmark.py @@ -0,0 +1,72 @@ +from genomic_benchmarks.data_check import list_datasets +from genomic_benchmarks.data_check import info +import torch +import numpy as np +from sklearn.preprocessing import LabelBinarizer +from genomic_benchmarks.dataset_getters.pytorch_datasets import get_dataset, HumanNontataPromoters, HumanEnhancersCohn, DemoHumanOrWorm, DemoMouseEnhancers, DemoCodingVsIntergenomicSeqs, DrosophilaEnhancersStark, HumanEnhancersEnsembl + + +def encode_sequence(d, ds, dt): + """ + First get the datasets with fixed length, and one-hot encode each sequence. + Then save them as torch tensors: X, y. + """ + sequences = [] + labels = [] + for data in ds: + gene_to_number = lb.transform(list(data[0])) + sequences.append(gene_to_number) + labels.append(data[1]) + X = torch.from_numpy(np.array(sequences)).to(torch.int8) + y = torch.from_numpy(np.array(labels)).to(torch.float16) + + torch.save(X, f"./data/{d}_X_{dt}.pt") + torch.save(y, f"./data/{d}_y_{dt}.pt") + +def encode_sequence_varied(d, ds, dt, length): + """ + First get the datasets with varied length, and one-hot encode each sequence. + Then pad the sequence to maximum lengths. [mouse_enhancer:4776; Human_Enhancer_ensembl:573 + Human Regulatory: 802; Human_OCR:593] + Lastly save them as torch tensors: X, y. + """ + print(len(ds)) + sequences = [] + labels = [] + for data in ds: + gene_to_number = lb.transform(list(data[0])) + # print(gene_to_number) + padded_seq = np.pad(gene_to_number, ((0, length-len(gene_to_number)), (0, 0))) + # print(padded_seq) + # print(len(gene_to_number)) + # print(len(padded_seq)) + sequences.append(padded_seq) + labels.append(data[1]) + + X = torch.from_numpy(np.array(sequences)).to(torch.int8) + y = torch.from_numpy(np.array(labels)).to(torch.float16) + + torch.save(X, f"./data/{d}_X_{dt}.pt") + torch.save(y, f"./data/{d}_y_{dt}.pt") + +datasets = list_datasets() +lb = LabelBinarizer() +lb.fit(['A', 'T', 'C', 'G', 'N']) +varied_lengths = { + "dummy_mouse_enhancers_ensembl":4776, + "human_enhancers_ensembl": 573, + "human_ensembl_regulatory": 802, + "human_ocr_ensembl": 593 +} +for d in datasets: + print(info(d, version=0)) + train_dset = get_dataset(d, split='train', version=0) + test_dset = get_dataset(d, split='test', version=0) + if d in ["demo_coding_vs_intergenomic_seqs", "demo_human_or_worm", "human_enhancers_cohn", "human_nontata_promoters"]: + encode_sequence(d, train_dset, "train") + encode_sequence(d, test_dset, "test") + elif d in ["dummy_mouse_enhancers_ensembl", "human_enhancers_ensembl", "human_ensembl_regulatory", "human_ocr_ensembl"]: + max_length = varied_lengths[d] + encode_sequence_varied(d, train_dset, "train", max_length) + encode_sequence_varied(d, test_dset, "test", max_length) + diff --git a/data_utils.py b/data_utils.py new file mode 100644 index 0000000..c0ed5ef --- /dev/null +++ b/data_utils.py @@ -0,0 +1,43 @@ +from torch.utils.data import Dataset + + +class vcf_Dataset(Dataset): + def __init__(self, ref, alt, tissue, label): + self.ref, self.alt, self.tissue, self.label = ref, alt, tissue, label + + def __getitem__(self, index): + ref = self.ref[index] + alt = self.alt[index] + tissue = self.tissue[index] + label = self.label[index].float() + return ref, alt, tissue, label + + def __len__(self): + return len(self.label) + + +class plant_Dataset(Dataset): + def __init__(self, data, labels): + self.data = data + self.labels = labels + + def __getitem__(self, index): + X = self.data[index] + Y = self.labels[index] + return X, Y + + def __len__(self): + return len(self.labels) + + +class gb_Dataset(Dataset): + def __init__(self, seq, label): + self.X, self.y = seq, label + + def __getitem__(self, index): + X = self.X[index] + y = self.y[index].float() + return X, y + + def __len__(self): + return len(self.y) diff --git a/dna.sh b/dna.sh new file mode 100644 index 0000000..1998c55 --- /dev/null +++ b/dna.sh @@ -0,0 +1,51 @@ +#!/bin/sh +#SBATCH --partition=GPUQ +#SBATCH --gres=gpu:1 +#SBATCH --account=share-ie-idi +#SBATCH --time=120:00:00 +#SBATCH --nodes=1 +#SBATCH --mem=128000 +#SBATCH --ntasks-per-node=1 +#SBATCH --job-name="contrast" +#SBATCH --output=mega.txt +#SBATCH --mail-user=tong.yu@ntnu.no +#SBATCH --mail-type=ALL + + +WORKDIR=${SLURM_SUBMIT_DIR} +cd ${WORKDIR} +echo "we are running from this directory: $SLURM_SUBMIT_DIR" +echo " the name of the job is: $SLURM_JOB_NAME" +echo "Th job ID is $SLURM_JOB_ID" +echo "The job was run on these nodes: $SLURM_JOB_NODELIST" +echo "Number of nodes: $SLURM_JOB_NUM_NODES" +echo "We are using $SLURM_CPUS_ON_NODE cores" +echo "We are using $SLURM_CPUS_ON_NODE cores per node" +echo "Total of $SLURM_NTASKS cores" + +module purge +module load Anaconda3/2023.09-0 +source activate monica +conda activate monica + +# python3 -u enformer_veclf.py +# echo "python3" +# wandb agent tonyu/VE_20000/cvgt4a2q +# python3 -u pretraining_hg38.py --problem Pretraining --model Revolution +# python3 -u pretraining.py +#python3 pretraining_large.py +python3 genomic_classification.py +# python3 genomic_benchmark.py +# python3 hyper_search.py +# python3 cdna_classification.py +# wandb agent tonyu/VE_20000/o7qq8ejx# +# python3 genomic_classification.py +# wandb agent tonyu/Human_promoter/nsdo9xmd +# wandb agent tonyu/Human_cohn/5bu0evqi +# wandb agent tonyu/Human_worm/b0ixhb15 +# wandb agent tonyu/Human_worm/a37jprb4 +# python3 generate_pretrain_1m.py +# wandb agent tonyu/Human_cohn/x02jvqgd +# python contrastive_pretraining.py +# python contrastive_classification.py +#python pretraining.py \ No newline at end of file diff --git a/genomic_classification.py b/genomic_classification.py new file mode 100644 index 0000000..f5f17d7 --- /dev/null +++ b/genomic_classification.py @@ -0,0 +1,241 @@ +import torch +from torch import nn, optim +from omegaconf import OmegaConf +from functools import lru_cache +from torch.utils.data import DataLoader +from torchmetrics import Accuracy +from models.SwanDNA import GB_Flash_Classifier, GB_Linear_Classifier +from data_utils import gb_Dataset +# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType +import pytorch_lightning as pl +from transformers import get_cosine_schedule_with_warmup +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.model_summary import ModelSummary +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar +pl.seed_everything(42) + + +class LightningWrapper(pl.LightningModule): + def __init__(self, model, cfg, train_set, val_set, pretrained, loss, file_name): + super().__init__() + self.save_hyperparameters(cfg) + self.model_config = self.hparams.SwanDNA + self.batch_size = self.hparams.training.batch_size + self.output = self.hparams.SwanDNA.output_size + self.warm_up = self.hparams.training.n_warmup_steps + self.length = self.hparams.SwanDNA.max_len + self.model = model(**self.model_config) + self.save_every = self.hparams.training.save_every + self.train_set = train_set + self.val_set = val_set + self.loss = loss + self.file_name = file_name + if self.output == 2: + self.train_acc = Accuracy(task='binary', top_k=1) + self.val_acc = Accuracy(task='binary', top_k=1) + else: + self.train_acc = Accuracy(task='multiclass', num_classes=self.model_config.output_size, top_k=1) + self.val_acc = Accuracy(task='multiclass', num_classes=self.model_config.output_size, top_k=1) + print(self.model) + + if pretrained: + pretrained_path = f'./{self.file_name}' + pretrained = torch.load(pretrained_path, map_location='cpu') + pretrained = pretrained["Teacher"] + + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in pretrained.items(): + if k.startswith('encoder') or k.startswith('embedding'): + new_state_dict[k] = v + + net_dict = self.model.state_dict() + pretrained_cm = {k: v for k, v in new_state_dict.items() if k in net_dict} + net_dict.update(pretrained_cm) + self.model.load_state_dict(net_dict) + for k, v in self.model.state_dict().items(): + print(k, v) + print(self.file_name) + print("*************pretrained model loaded***************") + + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + return self.model(x) + + def _init_weights(self, m): + if isinstance(m, nn.Reear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def training_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + train_loss = self.loss(output, label.to(torch.int64)) + self.train_acc.update(preds, label.int()) + return {"loss":train_loss, "preds":preds, "labels":label} + + def validation_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + val_loss = self.loss(output, label.to(torch.int64)) + self.val_acc.update(preds, label.int()) + return {"loss":val_loss, "preds":preds, "labels":label} + + def training_epoch_end(self, outputs): + train_loss = torch.stack([x["loss"] for x in outputs]).mean() + acc = self.train_acc.compute().mean() + self.train_acc.reset() + self.log('train_acc', acc, sync_dist=True) + self.log('train_loss', train_loss, sync_dist=True) + + # def validation_step_end(self, outputs): + # acc = self.val_acc(outputs["preds"], outputs["labels"]) + # self.log("val_acc", acc, sync_dist=True) + # self.log('val_loss', outputs["loss"], sync_dist=True) + + def validation_epoch_end(self, outputs): + val_loss = torch.stack([x["loss"] for x in outputs]).mean() + # label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + # output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + acc = self.val_acc.compute().mean() + self.val_acc.reset() + self.log("val_acc", acc, sync_dist=True) + self.log('val_loss', val_loss, sync_dist=True) + + + def train_dataloader(self): + return DataLoader( + dataset=self.train_set, + num_workers=1, + pin_memory=True, + shuffle=True, + drop_last=False, + batch_size=self.batch_size + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + batch_size=self.batch_size + ) + + @lru_cache + def total_steps(self): + l = len(self.trainer._data_connector._train_dataloader_source.dataloader()) + print('Num devices', self.trainer.num_devices) + max_epochs = self.trainer.max_epochs + accum_batches = self.trainer.accumulate_grad_batches + manual_total_steps = (l // accum_batches * max_epochs)/self.trainer.num_devices + print('MANUAL Total steps', manual_total_steps) + return manual_total_steps + + def configure_optimizers(self): + optimizer = optim.AdamW( + self.parameters(), + lr=self.hparams.training.learning_rate, + weight_decay=self.hparams.training.weight_decay + ) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=int(self.total_steps()*self.warm_up), #hyperparmeter [0.3, 0.4] + num_training_steps=self.total_steps(), + num_cycles=self.hparams.training.n_cycles + ) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + + +def classify_main(cfg, task): + """ + 1. decide which tack to run + """ + if task == "human_nontata_promoters": + config = cfg.Human_Promoter + elif task == "human_enhancers_cohn": + config = cfg.Human_Enhancers_Cohn + elif task == "demo_human_or_worm": + config = cfg.Demo_Human_Or_Worm + elif task == "dummy_mouse_enhancers_ensembl": + config = cfg.Demo_Mouse_Enhancers + elif task == "demo_coding_vs_intergenomic_seqs": + config = cfg.Demo_Coding_Inter + elif task == "drosophila_enhancers_stark": + config = cfg.Drop_Enhancer_Stark + elif task == "human_enhancers_ensembl": + config = cfg.Human_Enhancers_Ensembl + elif task == "human_ensembl_regulatory": + config = cfg.Human_Regulatory + elif task == "human_ocr_ensembl": + config = cfg.Human_Ocr_Ensembl + + """ + 2. load dataset. + """ + + pretrained = config.training.pretrained + length = config.SwanDNA.max_len + loss = nn.CrossEntropyLoss(reduction='mean') + + train_X = torch.load(f"./data/{task}_X_train.pt") + train_y = torch.load(f"./data/{task}_y_train.pt") + test_X = torch.load(f"./data/{task}_X_test.pt") + test_y = torch.load(f"./data/{task}_y_test.pt") + print(train_X.shape) + + train_set = gb_Dataset(train_X, train_y) + val_set = gb_Dataset(test_X, test_y) + + """ + 3. strat training with ddp mode. + """ + + ddp = DDPStrategy(process_group_backend="nccl", find_unused_parameters=True) + pretrained_model = "model_29_1000_4l_308_512_noiseandTL.pt" + + model = LightningWrapper(GB_Linear_Classifier, config, train_set, val_set, pretrained, loss, pretrained_model) + summary = ModelSummary(model, max_depth=-1) + + """ + 4. init trainer + """ + + wandb_logger = WandbLogger(dir="./wandb/", project="Mouse_Enhancers", entity='tonyu', name=f'{pretrained_model}_{length}_{task}') + checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max") + + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks_for_trainer = [TQDMProgressBar(refresh_rate=10), lr_monitor, checkpoint_callback] + if config.training.patience != -1: + early_stopping = EarlyStopping(monitor="val_acc", mode="max", min_delta=0., patience=cfg.Fine_tuning.training.patience) + callbacks_for_trainer.append(early_stopping) + if config.training.swa_lrs != -1: + swa = StochasticWeightAveraging(swa_lrs=1e-2) + callbacks_for_trainer.append(swa) + + print(summary) + trainer = pl.Trainer( + check_val_every_n_epoch=1, + enable_progress_bar=True, + accelerator='gpu', + strategy=ddp, + devices=[0], + max_epochs=config.training.n_epochs, + gradient_clip_val=0.5, + num_sanity_val_steps=0, + precision=16, + logger=wandb_logger + ) + trainer.fit(model) + + +if __name__ == "__main__": + cfg = OmegaConf.load('./config/config_gb.yaml') + OmegaConf.set_struct(cfg, False) + classify_main(cfg, "human_ocr_ensembl") diff --git a/gue_classification.py b/gue_classification.py new file mode 100644 index 0000000..67cd27c --- /dev/null +++ b/gue_classification.py @@ -0,0 +1,331 @@ +import torch +import pandas as pd +import numpy as np +from torch import nn, optim +from omegaconf import OmegaConf +from functools import lru_cache +from sklearn.preprocessing import LabelBinarizer +from torch.utils.data import DataLoader +from torchmetrics import Accuracy, MatthewsCorrCoef, F1Score +from torchmetrics.classification import MulticlassMatthewsCorrCoef +from models.SwanDNA import GB_Flash_Classifier, GB_Linear_Classifier +from data_utils import gb_Dataset +# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType +import pytorch_lightning as pl +from transformers import get_cosine_schedule_with_warmup +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities.model_summary import ModelSummary +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar +pl.seed_everything(42) + + +class LightningWrapper(pl.LightningModule): + def __init__(self, model, cfg, train_set, val_set, test_set, pretrained, loss, file_name): + super().__init__() + self.save_hyperparameters(cfg) + self.model_config = self.hparams.SwanDNA + self.batch_size = self.hparams.training.batch_size + self.output = self.hparams.SwanDNA.output_size + self.warm_up = self.hparams.training.n_warmup_steps + self.length = self.hparams.SwanDNA.max_len + self.model = model(**self.model_config) + self.save_every = self.hparams.training.save_every + self.train_set = train_set + self.val_set = val_set + self.test_set = test_set + self.loss = loss + self.file_name = file_name + # if self.output == 2: + # self.train_mcc = MatthewsCorrCoef(task='binary') + # self.val_mcc = MatthewsCorrCoef(task='binary') + # self.test_mcc = MatthewsCorrCoef(task='binary') + # else: + # self.train_mcc = MulticlassMatthewsCorrCoef(num_classes=3) + # self.val_mcc = MulticlassMatthewsCorrCoef(num_classes=3) + # self.test_mcc = MulticlassMatthewsCorrCoef(num_classes=3) + + if self.hparams.training.name == "virus": + self.train_mcc = F1Score(task="multiclass", num_classes=9) + self.val_mcc = F1Score(task="multiclass", num_classes=9) + self.test_mcc = F1Score(task="multiclass", num_classes=9) + print(self.model) + + if pretrained: + pretrained_path = f'./{self.file_name}' + pretrained = torch.load(pretrained_path, map_location='cpu') + pretrained = pretrained["Teacher"] + + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in pretrained.items(): + if k.startswith('encoder') or k.startswith('embedding'): + new_state_dict[k] = v + + net_dict = self.model.state_dict() + pretrained_cm = {k: v for k, v in new_state_dict.items() if k in net_dict} + net_dict.update(pretrained_cm) + self.model.load_state_dict(net_dict) + for k, v in self.model.state_dict().items(): + print(k, v) + print(self.file_name) + print("*************pretrained model loaded***************") + + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + return self.model(x) + + def _init_weights(self, m): + if isinstance(m, nn.Reear): + nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def training_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + train_loss = self.loss(output, label.to(torch.int64)) + self.train_mcc.update(preds, label.int()) + return {"loss":train_loss, "preds":preds, "labels":label} + + def validation_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + val_loss = self.loss(output, label.to(torch.int64)) + self.val_mcc.update(preds, label.int()) + return {"loss":val_loss, "preds":preds, "labels":label} + + def test_step(self, batch, batch_idx): + seq, label = batch + output = self.model(seq).squeeze() + preds = output.argmax(dim=-1) + test_loss = self.loss(output, label.to(torch.int64)) + self.test_mcc.update(preds, label.int()) + return {"loss":test_loss, "preds":preds, "labels":label} + + def training_epoch_end(self, outputs): + train_loss = torch.stack([x["loss"] for x in outputs]).mean() + acc = self.train_mcc.compute().mean() + self.train_mcc.reset() + self.log('train_mcc', acc, sync_dist=True) + self.log('train_loss', train_loss, sync_dist=True) + + def validation_epoch_end(self, outputs): + val_loss = torch.stack([x["loss"] for x in outputs]).mean() + # label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + # output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + acc = self.val_mcc.compute().mean() + self.val_mcc.reset() + self.log("val_mcc", acc, sync_dist=True) + self.log('val_loss', val_loss, sync_dist=True) + + def test_epoch_end(self, outputs): + test_loss = torch.stack([x["loss"] for x in outputs]).mean() + # label = torch.stack([x["labels"] for x in outputs]).reshape((-1,)) + # output = torch.stack([x["preds"] for x in outputs]).reshape((-1,)) + acc = self.test_mcc.compute().mean() + self.val_mcc.reset() + self.log("test_mcc", acc, sync_dist=True) + self.log('test_loss', test_loss, sync_dist=True) + + def train_dataloader(self): + return DataLoader( + dataset=self.train_set, + num_workers=1, + pin_memory=True, + shuffle=True, + drop_last=True, + batch_size=self.batch_size + ) + + def val_dataloader(self): + return DataLoader( + dataset=self.val_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + batch_size=self.batch_size + ) + + def test_dataloader(self): + return DataLoader( + dataset=self.test_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + batch_size=self.batch_size + ) + + @lru_cache + def total_steps(self): + l = len(self.trainer._data_connector._train_dataloader_source.dataloader()) + print('Num devices', self.trainer.num_devices) + max_epochs = self.trainer.max_epochs + accum_batches = self.trainer.accumulate_grad_batches + manual_total_steps = (l // accum_batches * max_epochs)/self.trainer.num_devices + print('MANUAL Total steps', manual_total_steps) + return manual_total_steps + + def configure_optimizers(self): + optimizer = optim.AdamW( + self.parameters(), + lr=self.hparams.training.learning_rate, + weight_decay=self.hparams.training.weight_decay + ) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=int(self.total_steps()*self.warm_up), #hyperparmeter [0.3, 0.4] + num_training_steps=self.total_steps(), + num_cycles=self.hparams.training.n_cycles + ) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] + + +def sequence2onehot(data_file, lb, length): + ds = pd.read_csv(data_file) + sequences, labels = [],[] + for index, data in ds.iterrows(): + gene_to_number = lb.transform(list(data["sequence"])) + if gene_to_number.shape[0] == length: + sequences.append(gene_to_number) + labels.append(data["label"]) + X = torch.from_numpy(np.array(sequences)).to(torch.int8) + y = torch.from_numpy(np.array(labels)).to(torch.float16) + + return X, y + + +def classify_main(cfg, task, branch): + """ + 1. decide which tack to run + """ + if task == "H3": + config = cfg.H3 + elif task == "H3K4me1": + config = cfg.H3K4me1 + elif task == "H3K4me2": + config = cfg.H3K4me2 + elif task == "H3K4me3": + config = cfg.H3K4me3 + elif task == "H3K36me3": + config = cfg.H3K36me3 + elif task == "H3K14ac": + config = cfg.H3K14ac + elif task == "H4": + config = cfg.H4 + elif task == "H3K79me3": + config = cfg.H3K79me3 + elif task == "H3K9ac": + config = cfg.H3K9ac + elif task == "H4ac": + config = cfg.H4ac + elif task == "prom_core_notata": + config = cfg.Prom_notata + elif task == "prom_core_tata": + config = cfg.Prom_tata + elif task == "prom_core_all": + config = cfg.Prom_all + elif task == "prom_300_notata": + config = cfg.Prom_300_notata + elif task == "prom_300_tata": + config = cfg.Prom_300_tata + elif task == "prom_300_all": + config = cfg.Prom_300_all + elif task == "tf1": + config = cfg.tf1 + elif task == "tf3": + config = cfg.tf3 + elif task == "splice": + config = cfg.Splice + elif task == "virus": + config = cfg.virus + + + """ + 2. load dataset. + """ + + pretrained = config.training.pretrained + length = config.SwanDNA.max_len + loss = nn.CrossEntropyLoss(reduction='mean') + + lb = LabelBinarizer() + lb.fit(['A', 'T', 'C', 'G', 'N']) + + df = pd.read_csv(f"./data/GUE/GUE/virus/{branch}/train.csv") + print(df.describe()) + + train_X, train_y = sequence2onehot(f"./data/GUE/GUE/virus/{branch}/train.csv", lb, length) + val_X, val_y = sequence2onehot(f"./data/GUE/GUE/virus/{branch}/dev.csv", lb, length) + test_X, test_y = sequence2onehot(f"./data/GUE/GUE/virus/{branch}/test.csv", lb, length) + print("***************data******************") + # print(train_X[0]) + print(train_X.size(), test_X.size(), val_X.size()) + + train_set = gb_Dataset(train_X, train_y) + val_set = gb_Dataset(val_X, val_y) + test_set = gb_Dataset(test_X, test_y) + + test_dalaloader = DataLoader( + dataset=test_set, + num_workers=1, + pin_memory=True, + shuffle=False, + drop_last=False, + batch_size=config.training.batch_size + ) + + """ + 3. strat training with ddp mode. + """ + + ddp = DDPStrategy(process_group_backend="nccl", find_unused_parameters=True) + pretrained_model = "model_29_1000_4l_308_512_noiseandTL.pt" + + model = LightningWrapper(GB_Linear_Classifier, config, train_set, val_set, test_set, pretrained, loss, pretrained_model) + summary = ModelSummary(model, max_depth=-1) + + """ + 4. init trainer + """ + + wandb_logger = WandbLogger(dir="./wandb/", project="Prom", entity='tonyu', name=f'{pretrained_model}_{length}_{branch}') + checkpoint_callback = ModelCheckpoint(monitor="val_mcc", mode="max") + + lr_monitor = LearningRateMonitor(logging_interval='step') + callbacks_for_trainer = [TQDMProgressBar(refresh_rate=10), lr_monitor, checkpoint_callback] + if config.training.patience != -1: + early_stopping = EarlyStopping(monitor="val_mcc", mode="max", min_delta=0., patience=cfg.Fine_tuning.training.patience) + callbacks_for_trainer.append(early_stopping) + if config.training.swa_lrs != -1: + swa = StochasticWeightAveraging(swa_lrs=1e-2) + callbacks_for_trainer.append(swa) + + print(summary) + trainer = pl.Trainer( + check_val_every_n_epoch=1, + enable_progress_bar=True, + accelerator='gpu', + strategy=ddp, + devices=[0], + max_epochs=config.training.n_epochs, + gradient_clip_val=0.5, + num_sanity_val_steps=0, + precision=16, + logger=wandb_logger, + callbacks=callbacks_for_trainer + ) + trainer.fit(model) + + trainer.test(model, test_dalaloader, "best") + + +if __name__ == "__main__": + cfg = OmegaConf.load('./config/config_gue.yaml') + OmegaConf.set_struct(cfg, False) + classify_main(cfg, "virus", "covid") diff --git a/mega.txt b/mega.txt new file mode 100644 index 0000000..fde6574 --- /dev/null +++ b/mega.txt @@ -0,0 +1,281 @@ +we are running from this directory: /cluster/home/tonyu/VE_Pretraining/FinDNA + the name of the job is: contrast +Th job ID is 18931222 +The job was run on these nodes: idun-04-09 +Number of nodes: 1 +We are using 1 cores +We are using 1 cores per node +Total of 1 cores +2024-02-01 18:11:39.192857: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. +[rank: 0] Global seed set to 42 +torch.Size([139804, 593, 5]) +GB_Linear_Classifier( + (embedding): Linear(in_features=5, out_features=308, bias=True) + (encoder): SwanDNANetwork( + (SwanDNA_blocks): ModuleList( + (0): SwanDNAEncoder( + (SwanDNA_blocks): ModuleList( + (0): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (1): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (2): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (3): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (4): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (5): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (6): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (7): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (8): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (9): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + ) + ) + (1): SwanDNAEncoder( + (SwanDNA_blocks): ModuleList( + (0): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (1): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (2): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (3): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (4): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (5): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (6): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (7): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (8): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + (9): SwanDNABlock( + (prenorm): Identity() + (norm): Identity() + (column_transform): Mlp( + (fc1): Linear(in_features=308, out_features=512, bias=True) + (act): GELU(approximate='none') + (fc2): Linear(in_features=512, out_features=308, bias=True) + (drop): Dropout(p=0.3, inplace=False) + ) + (dropout): Dropout(p=0.3, inplace=False) + (shift): CircularShift() + ) + ) + ) + ) + ) + (decoder): Linear(in_features=308, out_features=2, bias=True) +) +Traceback (most recent call last): + File "/cluster/home/tonyu/VE_Pretraining/FinDNA/genomic_classification.py", line 241, in + classify_main(cfg, "human_ocr_ensembl") + File "/cluster/home/tonyu/VE_Pretraining/FinDNA/genomic_classification.py", line 203, in classify_main + model = LightningWrapper(GB_Linear_Classifier, config, train_set, val_set, pretrained, loss, pretrained_model) + File "/cluster/home/tonyu/VE_Pretraining/FinDNA/genomic_classification.py", line 44, in __init__ + pretrained = torch.load(pretrained_path, map_location='cpu') + File "/cluster/home/tonyu/.conda/envs/monica/lib/python3.9/site-packages/torch/serialization.py", line 771, in load + with _open_file_like(f, 'rb') as opened_file: + File "/cluster/home/tonyu/.conda/envs/monica/lib/python3.9/site-packages/torch/serialization.py", line 270, in _open_file_like + return _open_file(name_or_buffer, mode) + File "/cluster/home/tonyu/.conda/envs/monica/lib/python3.9/site-packages/torch/serialization.py", line 251, in __init__ + super(_open_file, self).__init__(open(name, mode)) +FileNotFoundError: [Errno 2] No such file or directory: './Pretrained_models/model_29_1000_4l_308_512_noiseandTL.pt' diff --git a/models/Other_models/S4_model.py b/models/Other_models/S4_model.py new file mode 100644 index 0000000..6cd39c4 --- /dev/null +++ b/models/Other_models/S4_model.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from s4 import S4Block as S4 # Can use full version instead of minimal S4D standalone below +from S4_src import S4D + +# Dropout broke in PyTorch 1.11 +if tuple(map(int, torch.__version__.split('.')[:2])) == (1, 11): + print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.") + dropout_fn = nn.Dropout +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 12): + dropout_fn = nn.Dropout1d +else: + dropout_fn = nn.Dropout2d + + +class S4Model(nn.Module): + + def __init__( + self, + d_input, + d_model=256, + n_layers=4, + dropout=0.2, + prenorm=False, + ): + super().__init__() + + self.prenorm = prenorm + + # Linear encoder (d_input = 1 for grayscale and 3 for RGB) + self.encoder = nn.Linear(d_input, d_model) + + # Stack S4 layers as residual blocks + self.s4_layers = nn.ModuleList() + self.norms = nn.ModuleList() + self.dropouts = nn.ModuleList() + for _ in range(n_layers): + self.s4_layers.append( + S4D(d_model, dropout=dropout, transposed=True, lr=min(0.001, 0.0003)) + ) + self.norms.append(nn.LayerNorm(d_model)) + self.dropouts.append(dropout_fn(dropout)) + + # Linear decoder + # self.decoder = nn.Linear(d_model, d_output) + + def forward(self, x): + """ + Input x is shape (B, L, d_input) + """ + # x = self.encoder(x) # (B, L, d_input) -> (B, L, d_model) + + x = x.transpose(-1, -2) # (B, L, d_model) -> (B, d_model, L) + for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts): + # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L) + + z = x + if self.prenorm: + # Prenorm + z = norm(z.transpose(-1, -2)).transpose(-1, -2) + + # Apply S4 block: we ignore the state input and output + z, _ = layer(z) + + # Dropout on the output of the S4 block + z = dropout(z) + + # Residual connection + x = z + x + + if not self.prenorm: + # Postnorm + x = norm(x.transpose(-1, -2)).transpose(-1, -2) + + x = x.transpose(-1, -2) + + # # Pooling: average pooling over the sequence length + # x = x.mean(dim=1) + + # # Decode the outputs + # x = self.decoder(x) # (B, d_model) -> (B, d_output) + + return x \ No newline at end of file diff --git a/models/Other_models/S4_src.py b/models/Other_models/S4_src.py new file mode 100644 index 0000000..0bc3a96 --- /dev/null +++ b/models/Other_models/S4_src.py @@ -0,0 +1,133 @@ +"""Minimal version of S4D with extra options and features stripped out, for pedagogical purposes.""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """X: (batch, dim, lengths...).""" + if self.training: + if not self.transposed: X = rearrange(X, 'b ... d -> b d ...') + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow because of CPU -> GPU copying + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + X = X * mask * (1.0/(1-self.p)) + if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') + return X + return X + + +class S4DKernel(nn.Module): + """Generate convolution kernel from diagonal SSM parameters.""" + + def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None): + super().__init__() + # Generate dt + H = d_model + log_dt = torch.rand(H) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + C = torch.randn(H, N // 2, dtype=torch.cfloat) + self.C = nn.Parameter(torch.view_as_real(C)) + self.register("log_dt", log_dt, lr) + + log_A_real = torch.log(0.5 * torch.ones(H, N//2)) + A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H) + self.register("log_A_real", log_A_real, lr) + self.register("A_imag", A_imag, lr) + + def forward(self, L): + """ + returns: (..., c, L) where c is number of channels (default 1) + """ + + # Materialize parameters + dt = torch.exp(self.log_dt) # (H) + C = torch.view_as_complex(self.C) # (H N) + A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N) + + # Vandermonde multiplication + dtA = A * dt.unsqueeze(-1) # (H N) + K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L) + C = C * (torch.exp(dtA)-1.) / A + K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real + + return K + + def register(self, name, tensor, lr=None): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {"weight_decay": 0.0} + if lr is not None: optim["lr"] = lr + setattr(getattr(self, name), "_optim", optim) + + +class S4D(nn.Module): + def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args): + super().__init__() + + self.h = d_model + self.n = d_state + self.d_output = self.h + self.transposed = transposed + + self.D = nn.Parameter(torch.randn(self.h)) + + # SSM Kernel + self.kernel = S4DKernel(self.h, N=self.n, **kernel_args) + + # Pointwise + self.activation = nn.GELU() + # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11 + dropout_fn = DropoutNd + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + self.output_linear = nn.Sequential( + nn.Conv1d(self.h, 2*self.h, kernel_size=1), + nn.GLU(dim=-2), + ) + + def forward(self, u, **kwargs): # absorbs return_output and transformer src mask + """ Input and output shape (B, H, L) """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + + # Compute SSM Kernel + k = self.kernel(L=L) # (H L) + + # Convolution + k_f = torch.fft.rfft(k, n=2*L) # (H L) + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L) + + # Compute D term in state space equation - essentially a skip connection + y = y + u * self.D.unsqueeze(-1) + + y = self.dropout(self.activation(y)) + y = self.output_linear(y) + if not self.transposed: y = y.transpose(-1, -2) + return y, None # Return a dummy state to satisfy this repo's interface, but this can be modified \ No newline at end of file diff --git a/models/Other_models/s4.py b/models/Other_models/s4.py new file mode 100644 index 0000000..bdb6c0b --- /dev/null +++ b/models/Other_models/s4.py @@ -0,0 +1,1964 @@ +"""Standalone version of Structured State Space sequence model (S4).""" + +from collections import defaultdict +from typing import Optional, Mapping, Tuple, Union +import logging +from functools import partial +import math +import numpy as np +from scipy import special as ss +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch_lightning.utilities import rank_zero_only +from einops import rearrange, repeat + +# Function aliases +contract = torch.einsum + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +_c2r = torch.view_as_real +_r2c = torch.view_as_complex +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger +log = get_logger(__name__) + +"""Structured matrix kernels""" + +# Try CUDA extension +try: + from extensions.kernels.cauchy import cauchy_mult as cauchy_cuda + from extensions.kernels.vandermonde import log_vandermonde_cuda + has_cuda_extension = True + log.info("CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) found.") +except: + log.warning( + "CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled." + ) + has_cuda_extension = False + +# Try pykeops +try: + import pykeops + from pykeops.torch import Genred + has_pykeops = True + log.info("Pykeops installation found.") + + def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] + return tensors + + def cauchy_keops(v, z, w): + expr_num = 'z * ComplexReal(v) - Real2Complex(Sum(v * w))' + expr_denom = 'ComplexMult(z-w, z-Conj(w))' + + cauchy_mult = Genred( + f'ComplexDivide({expr_num}, {expr_denom})', + [ + 'v = Vj(2)', + 'z = Vi(2)', + 'w = Vj(2)', + ], + reduction_op='Sum', + axis=1, + ) + + v, z, w = _broadcast_dims(v, z, w) + v = _c2r(v) + z = _c2r(z) + w = _c2r(w) + + r = 2*cauchy_mult(v, z, w, backend='GPU') + return _r2c(r) + + def log_vandermonde_keops(v, x, L): + expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'v = Vj(2)', + 'x = Vj(2)', + 'l = Vi(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(v, x, l, backend='GPU') + return 2*_r2c(r).real + + def log_vandermonde_transpose_keops(u, v, x, L): + """ + u: ... H L + v: ... H N + x: ... H N + Returns: ... H N + + V = Vandermonde(a, L) : (H N L) + contract_L(V * u * v) + """ + expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'u = Vj(2)', + 'v = Vi(2)', + 'x = Vi(2)', + 'l = Vj(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + u, v, x, l = _broadcast_dims(u, v, x, l) + u = _c2r(u) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(u, v, x, l, backend='GPU') + return _r2c(r) + +except ImportError: + has_pykeops = False + if not has_cuda_extension: + log.warning( + "Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency." + ) + +# Fallback versions +def cauchy_naive(v, z, w): + """ + v: (..., N) + z: (..., L) + w: (..., N) + returns: (..., L) \sum v/(z-w) + """ + v = _conj(v) + w = _conj(w) + cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) + +def log_vandermonde_naive(v, x, L, conj=True): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L) + return 2*vandermonde_prod.real + +def log_vandermonde_transpose_naive(u, v, x, L): + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... L) + return vandermonde_prod + + + +""" Simple nn.Module components """ + +def Activation(activation=None, dim=-1): + if activation in [ None, 'id', 'identity', 'linear' ]: + return nn.Identity() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'relu': + return nn.ReLU() + elif activation == 'gelu': + return nn.GELU() + elif activation == 'elu': + return nn.ELU() + elif activation in ['swish', 'silu']: + return nn.SiLU() + elif activation == 'glu': + return nn.GLU(dim=dim) + elif activation == 'sigmoid': + return nn.Sigmoid() + elif activation == 'softplus': + return nn.Softplus() + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + +def LinearActivation( + d_input, d_output, bias=True, + transposed=False, + activation=None, + activate=False, # Apply activation as part of this module + **kwargs, + ): + """Returns a linear nn.Module with control over axes order, initialization, and activation.""" + + # Construct core module + linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + if activation is not None and activation == 'glu': d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """X: (batch, dim, lengths...).""" + if self.training: + if not self.transposed: X = rearrange(X, 'b ... d -> b d ...') + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + X = X * mask * (1.0/(1-self.p)) + if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') + return X + return X + +"""Misc functional utilities""" + +def power(L, A, v=None): + """Compute A^L and the scan sum_i A^i v_i. + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: I = powers[-1] @ I + L //= 2 + if L == 0: break + l *= 2 + if v is None: + powers = [powers[-1] @ powers[-1]] + else: + powers.append(powers[-1] @ powers[-1]) + + if v is None: return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, '... (z l) -> ... z l', z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + + +"""HiPPO utilities""" + +def transition(measure, N, **measure_args): + """A, B transition matrices for different measures. + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Legendre (translated) + if measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + + # Halve again for timescale correctness + A *= 0.5 + B *= 0.5 + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure in ['fourier', 'fout']: + freqs = np.arange(N//2) + d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] + B = B[:, None] + else: + raise NotImplementedError + + return A, B + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """Return low-rank matrix L such that A + L is normal.""" + + if measure == 'legs': + assert rank >= 1 + P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == 'legt': + assert rank >= 2 + P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + P *= 2**(-0.5) # Halve the rank correct just like the original matrix was halved + elif measure in ['fourier', 'fout']: + P = torch.zeros(N) + P[0::2] = 2**.5 + P[0] = 1 + P = P.unsqueeze(0) + else: raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N) + return P + +def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True, B_clip=2.0): + """Constructs NPLR form of HiPPO matrices. + + Returns w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + + measure: Name of HiPPO method. + N: Size of recurrent A matrix (also known as `d_state` elsewhere). + dtype: Single or double precision. + diagonalize_precision: Calculate diagonalization in double precision. + B_clip: Clip values of B, can help with stability. None for no clipping. + """ + + assert dtype == torch.float or dtype == torch.double + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) + AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) + + # We require AP to be nearly skew-symmetric + _A = AP + AP.transpose(-1, -2) + if (err := torch.sum((_A - _A[0,0]*torch.eye(N))**2) / N) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): + print("WARNING: HiPPO matrix not skew symmetric", err) + + + # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately + # Imaginary part can use eigh instead of eig + W_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) + + # Diagonalize in double precision + if diagonalize_precision: AP = AP.to(torch.double) + # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + W_im, V = torch.linalg.eigh(AP*-1j) # (..., N) (..., N, N) + if diagonalize_precision: W_im, V = W_im.to(cdtype), V.to(cdtype) + W = W_re + 1j * W_im + # Check: V W V^{-1} = A + # print("check", V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)) + + + # Only keep half of each conjugate pair + _, idx = torch.sort(W.imag) + W_sorted = W[idx] + V_sorted = V[:, idx] + + # There is an edge case when eigenvalues can be 0, which requires some machinery to handle + # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) + V = V_sorted[:, :N//2] + W = W_sorted[:N//2] # Only keep negative imaginary components + assert W[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" + if W[-1].abs() < 1e-4: + V[:, -1] = 0. + V[0, -1] = 2**-0.5 + V[1, -1] = 2**-0.5 * 1j + + _AP = V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2) + if ((err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5): + print("Warning: Diagonalization of A matrix not numerically precise - error", err) + # print("check", V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)) + + V_inv = V.conj().transpose(-1, -2) + + # C = initial_C(measure, N, dtype=dtype) + B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B + # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C + P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P + + if B_clip is not None: + B = B.real + 1j*torch.clamp(B.imag, min=-B_clip, max=B_clip) + + # W represents the imaginary part of the DPLR form: A = W - PP^* + # Downstream classes just call this A for simplicity, + # which is also more consistent with the diagonal case + return W, P, B, V + +def dplr( + init='hippo', + N=64, rank=1, H=1, + dtype=torch.float, + real_random=False, + real_scale=1.0, + imag_random=False, + imag_scale=1.0, + B_random=False, + B_init='constant', + B_scale=1.0, + P_scale=1.0, + normalize=False, +): + """Directly construct a DPLR matrix. + + Args: + - init: (str) ['rand', 'lin', inv', 'real', 'hippo'] Choices for initialization of A. + Most of these affect the imaginary part of A, except for 'real'. + - real_random: (bool) Initialize A.real in -U[0, 1]. Otherwise, initialize to -1/2. + - real_scale: (float) Scaling factor of real part of A. + - imag_random: (bool) Initialize A.imag randomly. + - imag_scale: (bool) Scaling factor of imaginary part of A. + - B_init: (str) ['constant' | 'random' | 'alternating' | 'unit-cw' | 'unit-ccw' | 'hippo'] + Choices for initialization of B. + - B_scale: (float) Scaling factor for B + - P_scale: (float) Scaling factor for P + - normalize: (bool) Apply an automatic normalization factor on B + """ + assert dtype == torch.float or dtype == torch.double + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + + pi = torch.tensor(math.pi) + + # Construct real part of diagonal A (must be non-negative) + if real_random: + real_part = torch.rand(H, N//2) + else: + real_part = .5 * torch.ones(H, N//2) + real_part = real_scale * real_part + + # Construct imaginary part of diagonal A (must be non-negative) + if imag_random: + imag_part = N//2 * torch.rand(H, N//2) + else: + imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) + + if init in ['random', 'rand']: + imag_part = torch.exp(torch.randn(H, N//2)) + elif init == 'real': + imag_part = 0 * imag_part + if real_random: + real_part = torch.rand(H, N//2) * N//2 + else: + # This is the S4D-Real method described in the S4D paper + # The A matrix is diag(-1, -2, ..., -N), which are the eigenvalues of the HiPPO matrix + real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H) + elif init in ['linear', 'lin']: + imag_part = pi * imag_part + elif init in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix + imag_part = 1/pi * N * (N/(1+2*imag_part)-1) + elif init in ['inverse2', 'inv2']: + imag_part = 1/pi * N * (N/(1+imag_part)-1) + elif init in ['quadratic', 'quad']: + imag_part = 1/pi * (1+2*imag_part)**2 + elif init in ['legs', 'hippo']: + A, _, _, _ = nplr('legs', N) + imag_part = -A.imag # Positive + else: raise NotImplementedError + imag_part = imag_scale * imag_part + + # Construct diagonal A + A = -real_part - 1j * imag_part # Force negative real and imag + assert torch.all(A.real < 1e-4) and torch.all(A.imag <= 0.0) # Allow some tolerance for numerical precision on real part + + # Initialize B + if B_random: + log.warning("'B_random' is deprecated in favor of B_init='random' and will be deprecated in a future version.") + if init in ['legs', 'hippo']: + log.info(f'Initializing with S4D-LegS and ignoring argument {B_init=}') + # Special initialization using the HiPPO B matrix + # Note that theory (from S4D paper) says that B should be halved + # to match DPLR but we drop this 0.5 factor for simplicity + _, P, B, _ = nplr('legs', N, B_clip=2.0) + B = repeat(B, 'n -> h n', h=H).clone().contiguous() + elif B_init == 'constant': + B = torch.ones(H, N//2, dtype=dtype) + elif B_init == 'random': + B = torch.randn(H, N//2, dtype=dtype) + elif B_init == 'alternating': # Seems to track 'constant' exactly for some reason + B = torch.ones(H, N//4, 2, dtype=dtype) + B[:, :, 1] *= -1 + B = B.view(H, N//2) + elif B_init == 'unit-cw': + z = torch.tensor(torch.exp(-2j * pi / N), dtype=dtype) + B = z ** torch.arange(0, N // 2) + B = repeat(B, 'n -> h n', h=H).clone().contiguous() + elif B_init == 'unit-ccw': + z = torch.tensor(torch.exp(2j * pi / N), dtype=dtype) + B = z ** torch.arange(0, N // 2) + B = repeat(B, 'n -> h n', h=H).clone().contiguous() + else: raise NotImplementedError + B *= B_scale + + # Experimental feature that appeared in earlier versions of HTTYH (not extensively tested) + # Seems more principled for normalization theoretically, but seemed to hurt on PathX + if normalize: + norm = -B/A # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector + B = B / zeta**.5 + + # Initialize P + if B_init in ['legs', 'hippo']: + # P constructed earlier + P = repeat(P, 'r n -> r h n', h=H).clone().contiguous() + else: + P = torch.randn(rank, H, N//2, dtype=dtype) + P = P * P_scale + + # Initialize V (only used in testing) + V = torch.eye(N, dtype=dtype)[:, :N//2] + V = repeat(V, 'n m -> h n m', h=H) + + return A, P, B, V + +def ssm(init, N, R, H, **ssm_args): + """Dispatcher to create single SSM initialization + + N: state size + R: rank (for DPLR parameterization) + H: number of independent SSM copies + """ + + if init.startswith("diag") or init.startswith("dplr"): + if init.startswith("diag"): + ssm_args["P_scale"] = 0.0 + args = init[4:].split("-") + assert args[0] == "" + if len(args) > 1: + ssm_args["init"] = args[1] + A, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) + else: + A, P, B, V = nplr(init, N, R, **ssm_args) + A = repeat(A, 'n -> s n', s=H) + P = repeat(P, 'r n -> r s n', s=H) + B = repeat(B, 'n -> s n', s=H) + V = repeat(V, 'n m -> s n m', s=H) + return A, P, B, V + +combinations = { + 'hippo': ['legs', 'fourier'], + 'diag': ['diag-inv', 'diag-lin'], + 'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'], +} + +def combination(inits, N, R, S, **ssm_args): + if isinstance(inits, str): + inits = combinations[inits] if inits in combinations else [inits] + + assert S % len(inits) == 0, f"{S} independent trainable SSM copies must be multiple of {len(inits)} different inits" + A, P, B, V = zip( + *[ssm(init, N, R, S // len(inits), **ssm_args) for init in inits] + ) + A = torch.cat(A, dim=0) # (S N) + P = torch.cat(P, dim=1) # (R S N) + B = torch.cat(B, dim=0) # (S N) + V = torch.cat(V, dim=0) # (S N N) + return A, P, B, V + + +"""SSM convolution kernels""" + +def inv_transform(param, transform='none'): + """Initialize a (positive) parameter under a transform.""" + param = torch.clamp(param, min=1e-4) + if transform == 'none': + return param + elif transform == 'exp': + return torch.log(param) # Some of the HiPPO methods have real part 0 + elif transform == 'relu': + return param + elif transform == 'sigmoid': + return torch.logit(param) + elif transform == 'softplus': + return torch.log(torch.exp(param)-1) + else: raise NotImplementedError + +def param_transform(param, transform='none'): + """Get a (positive) parameter under a transform.""" + if transform == 'none': + p = param + elif transform == 'exp': + p = torch.exp(param) + elif transform == 'relu': + # JAX version seems to NaN if you allow 0's, although this code was fine without it + p = F.relu(param)+1e-4 + elif transform == 'sigmoid': + p = F.sigmoid(param) + elif transform == 'softplus': + p = F.softplus(param) + else: raise NotImplementedError + return p + +class Kernel(nn.Module): + """Interface for modules that produce convolution kernels. + + A main distinction between these and normal Modules is that the forward pass + does not take inputs. It is a mapping from parameters to a tensor that can + be used in other modules, in particular as a convolution kernel. + + Because of the unusual parameterization, these kernels may often want special + hyperparameter settings on their parameters. The `register` method provides + an easy interface for controlling this, and is intended to be used with an + optimizer hook that can be found in train.py or example.py. + + This class also defines an interface for interacting with kernels *statefully*, + in particular for state space models (SSMs). This interface handles the setting + when a model can be converted from a "CNN" into an "RNN". + _setup_step() + step() + default_state() + forward_state() + + See ConvKernel for the simplest instantiation of this interface. + """ + + def __init__( + self, + d_model: int = 0, + channels: int = 1, + l_max: Optional[int] = None, + lr: Union[float, Optional[Mapping]] = None, + wd: Union[float, Optional[Mapping]] = 0.0, + verbose: bool = True, + **kwargs, + ): + """General interface. + + d_model (H): Model dimension, or number of independent convolution kernels created. + channels (C): Extra dimension in the returned output (see .forward()). + - One interpretation is that it expands the input dimension giving it C separate "heads" per feature. + That is convolving by this kernel maps shape (B L D) -> (B L C D) + - This is also used to implement a particular form of bidirectionality in an efficient way. + - In general for making a more powerful model, instead of increasing C + it is recommended to set channels=1 and adjust H to control parameters instead. + l_max (L): Maximum kernel length (optional). If unspecified, most Kernel instantiations + will return kernels of arbitrary length as passed into .forward(). + lr: Optional dictionary specifying special hyperparameters for .register(). + Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt). + A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. + wd: Same as lr, but for weight decay. + """ + super().__init__() + assert d_model > 0 + self.H = self.d_model = d_model + self.L = self.l_max = l_max + self.channels = channels + self.lr = lr + self.wd = wd + self.verbose = verbose + + # Add a catch-all **kwargs to make it easier to change kernels + # without manually moving other options passed in the config. + # Good to log these just so it's explicit. + if self.verbose and len(kwargs) > 0: + log.info(f"{type(self)} extra kwargs: {kwargs}") + + # Logic for registering parameters + # Case 1: lr: None | float + # All params should have this lr (None means inherit from global lr) + # Case 2: lr: dict + # Specified params should have that lr, all others should be None + if self.lr is None or isinstance(self.lr, float): + self.lr_dict = defaultdict(lambda: self.lr) + else: + self.lr_dict = defaultdict(lambda: None) + self.lr_dict.update(self.lr) + + # Same logic for weight decay + # (but is always just set to 0.0 and hasn't been ablated) + if self.wd is None or isinstance(self.wd, float): + self.wd_dict = defaultdict(lambda: self.wd) + else: + self.wd_dict = defaultdict(lambda: None) + self.wd_dict.update(self.wd) + + def forward(self, state=None, rate=1.0, L=None): + """General interface to generate a global convolution kernel. + + state: Initial state for recurrent updates. + E.g. for SSMs, this should have shape (B, H, N) (batch, d_model, d_state). + rate: Relative sampling rate. + L: Target kernel length. + + Returns: + - (C, H, L) (channels, d_model, l_kernel) The convolution kernel. + - (B, H, L) (batch, d_model, l_kernel) + Extra information for how the state affects the output of convolving by kernel. + """ + raise NotImplementedError + + def register(self, name, tensor, lr=None, wd=0.0): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {} + if lr is not None: optim["lr"] = lr + if wd is not None: optim["weight_decay"] = wd + setattr(getattr(self, name), "_optim", optim) + + def _setup_step(self, **kwargs): + """Convert a model into a recurrent mode for autoregressive inference.""" + raise NotImplementedError + + def step(self, x, state, **kwargs): + """Step the model for one timestep with input x and recurrent state.""" + raise NotImplementedError + + def default_state(self, *args, **kwargs): + """Return a default initial state.""" + raise NotImplementedError + + @torch.no_grad() + def forward_state(self, u, state): + """Forward the state through a sequence, i.e. computes the state after passing chunk through the kernel.""" + raise NotImplementedError + + @property + def d_state(self): + """Implement this for interfaces that want to interact with a stateful layer (i.e. SSMs). + + Currently the only codepath that might use this is the StateDecoder, which is not used. + """ + raise NotImplementedError + + @property + def state_to_tensor(self): + """Same as d_state, only needed for niche codepaths involving recurrent state.""" + raise NotImplementedError + +class SSMKernel(Kernel): + """Parent class for different SSM parameterizations. + + This class is abstract and only defines some initializations and flags that are common to all SSM variants. + It is instantiated by subclasses SSMKernel{Dense,Real,Diag,DPLR}. + + Options: + d_state (N): State size (dimensionality of parameters A, B, C). Generally shouldn't need to be adjusted and doens't affect speed much for most kernels (e.g. S4, S4D). + deterministic: Use a deterministic initialization for dt, A, B, C. + Useful for debugging as well as constructing a simple exponential decay kernel (e.g. used in S4ND image->video inflation). + + dt_min, dt_max: min and max values for the step size dt + dt_tie: Keep dt tied across the N dimensions of the state. Although this theoretically makes more sense, models such as S5 and Mega have found slightly improvements by setting it to False. + dt_transform: Transform function for parameterization of dt (default 'softplus', used to be 'exp') + + rank: Rank of low-rank correction for DPLR mode. Needs to be increased for init "legt". + n_ssm: Number of independent trainable (A, B) SSMs, e.g. + `n_ssm=1` means all A/B parameters are tied across the H different instantiations of C. + `n_ssm=None` means all H SSMs are completely independent. + Generally, changing this option can save parameters but doesn't affect performance or speed much. + This parameter must divide H. + init: Options for initialization of (A, B). For DPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin). + init_args: Extra arguments passed into initialization function (see dplr.py for options). + """ + + def init_dt(self): + # Generate dt + if self.deterministic: # Meant for debugging + assert self.dt_tie, "Deterministic dt initialization is tied" + assert self.dt_transform == 'exp', "Deterministic dt transform should be 'exp' for simplicity" + inv_dt = torch.exp(torch.linspace(math.log(self.dt_min), math.log(self.dt_max), self.H)).unsqueeze(-1) # (H 1) + else: + shape = (self.H, 1) if self.dt_tie else (self.H, self.N//2) + # Initialize log dt + inv_dt = torch.rand(*shape, dtype=self.dtype) * ( + math.log(self.dt_max) - math.log(self.dt_min) + ) + math.log(self.dt_min) + if self.dt_transform != 'exp': + inv_dt = inv_transform(torch.exp(inv_dt), self.dt_transform) + + return inv_dt + + def init_ssm_real(self): + """Returns (dense, real) (A, B, C) parameters for init options.""" + # Generate A, B + A, B = transition(self.init, self.N) + A = torch.as_tensor(A, dtype=self.dtype) + B = torch.as_tensor(B, dtype=self.dtype)[:, 0] + B = repeat(B, 'n -> v n', v=self.n_ssm).clone().contiguous() + A = repeat(A, 'n m -> v n m', v=self.n_ssm).clone().contiguous() + + # Generate C + if self.deterministic: + C = torch.zeros(self.channels, self.H, self.N, dtype=self.dtype) + C[..., :1] = 1.0 + else: + C = torch.randn(self.channels, self.H, self.N, dtype=self.dtype) + + return A, B, C + + def init_ssm_dplr(self): + """Returns DPLR (A, P, B, C) parameters for init options.""" + A, P, B, V = combination(self.init, self.N, self.rank, self.n_ssm, **self.init_args) + + # Broadcast C to have H channels + if self.deterministic: + C = torch.zeros(self.channels, self.n_ssm, self.N, dtype=self.cdtype) + C[:, :, :1] = 1. + C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C) # V^* C + C = repeat(C, 'c t n -> c (v t) n', v=self.H // C.size(-2)).clone().contiguous() + else: + C = torch.randn(self.channels, self.H, self.N//2, dtype=self.cdtype) + + # Broadcast other parameters to have n_ssm copies + assert self.n_ssm % B.size(-2) == 0 \ + and self.n_ssm % P.size(-2) == 0 \ + and self.n_ssm % A.size(-2) == 0 + + # Broadcast tensors to n_ssm copies + # These will be the parameters, so make sure tensors are materialized and contiguous + B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() + P = repeat(P, 'r t n -> r (v t) n', v=self.n_ssm // P.size(-2)).clone().contiguous() + A = repeat(A, 't n -> (v t) n', v=self.n_ssm // A.size(-2)).clone().contiguous() + + # Because these complex parameterizations assume conjugate symmetry, + # halve the value of self.N for convenience + self.N //= 2 + + return A, P, B, C + + def __init__( + self, + # General Kernel arguments for parent class + d_model: int = 0, + channels: int = 1, + l_max: Optional[int] = None, + lr: Union[float, Optional[Mapping]] = None, + wd: Union[float, Optional[Mapping]] = 0.0, + verbose: bool = True, + # SSM arguments + d_state: int = 64, + deterministic: bool = False, + # dt options + dt_min: float = 0.001, + dt_max: float = 0.1, + dt_tie: bool = True, + dt_transform: str = 'exp', + # (A, B, C) options + rank: int = 1, + n_ssm: Optional[int] = None, + measure: Optional[str] = None, + init: Optional[str] = "legs", + # Extra hyperparameters for initialization + **init_args, + ): + super().__init__(d_model=d_model, channels=channels, l_max=l_max, lr=lr, wd=wd, verbose=verbose) + self.N = d_state + self.dtype, self.cdtype = torch.float, torch.cfloat + self.deterministic = deterministic + # dt options + self.dt_min = dt_min + self.dt_max = dt_max + self.dt_tie = dt_tie + self.dt_transform = dt_transform + # SSM options (A, B, C) + self.rank = rank + self.n_ssm = n_ssm if n_ssm is not None else self.H + if measure is not None: + log.warning("Warning: 'measure' option changed to 'init' and will be removed in a future version.") + assert init is None, "'measure' and 'init' cannot both be passed into SSMKernel" + init, measure = measure, init + self.init = init + self.init_args = init_args + + @torch.no_grad() + def forward_state(self, u, state): + """Forward the state through a sequence, i.e. computes the state after passing chunk through SSM + + This is a generic version of this functionality that works for SSMs. + It is currently used by SSMKernelDense and SSMKernelDPLR. + This is a suboptimal implementation; it is recommended to use SSMKernelDiag + if this functionality is desired. + + state: (B, H, N) + u: (B, H, L) + + Returns: (B, H, N) + """ + + # Construct dA, dB matrices + dA, dB = self._setup_state() # (H N N) (H N) + + conj = state.size(-1) != dA.size(-1) + if conj: state = _conj(state) + + v = contract('h n, b h l -> b h n l', dB, u.flip(-1)) + AL, v = power(u.size(-1), dA, v) + next_state = contract("h m n, b h n -> b h m", AL, state) + next_state = next_state + v + + if conj: next_state = next_state[..., : next_state.size(-1) // 2] + return next_state + + def _setup_state(self): + """Register dA and dB to module.""" + raise NotImplementedError + + @property + def d_state(self): + """d_state and state_to_tensor are used by specific decoders. + + These were used in earlier versions and should not be needed in general. + """ + return self.H * self.N + + @property + def state_to_tensor(self): + return lambda state: rearrange('... h n -> ... (h n)', state) + + +class SSMKernelDiag(SSMKernel): + """SSM kernel using diagonal state matrix (S4D model). + + Options: + disc: ['zoh' | 'bilinear' | 'dss'] Discretization options. + dt_fast: (experimental) Parameterize inv_dt under sinh function. + (Ohno et al. "Fast Saturating Gate for Learning Long Time Scales with RNNs") + real_transform, imag_transform: ['none' | 'exp' | 'relu' | 'sigmoid' | 'softplus'] + Parameterize the real/imag parts of the diagonal of A under this function. + bandlimit: Mask high frequencies of the kernel (indices corresponding to + diagonal elements with large imaginary part). Introduced in S4ND paper. + backend: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency). + is_real : Real-valued SSM; can be interpreted as EMA. + """ + + def __init__( + self, + disc: str = 'zoh', # Change to 'bilinear' to match S4, but should make little difference either way + dt_fast: bool = False, + real_transform: str = 'exp', + imag_transform: str = 'none', + bandlimit: Optional[float] = None, + backend: str = 'cuda', + is_real: bool = False, + **kwargs, + ): + # Special case: for real-valued, d_state semantics change + if is_real and 'd_state' in kwargs: + kwargs['d_state'] = kwargs['d_state'] * 2 + super().__init__(**kwargs) + self.disc = disc + self.dt_fast = dt_fast + self.real_transform = real_transform + self.imag_transform = imag_transform + self.bandlimit = bandlimit + self.backend = backend + self.is_real = is_real + + # Initialize dt, A, B, C + inv_dt = self.init_dt() + A, P, B, C = self.init_ssm_dplr() + # Note that in the Diag case, P will be ignored + # The DPLR case subclasses this and uses P + self.register_params(A, B, C, inv_dt, P) + + def register_params(self, A, B, C, inv_dt, P): + """Process the initialization into form of trainable parameters. + + A: (S, N) diagonal matrix + B: (S, N) + C: (C, H, N) + dt: (H) timescale per feature + + Dimensions: + N (or d_state): state size + H (or d_model): total SSM copies + S (or n_ssm): number of trainable copies of (A, B, dt); must divide H + C (or channels): system is 1-dim to C-dim + + The forward pass of this Module returns a tensor of shape (C, H, L) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + assert self.backend in ['cuda', 'keops', 'naive'] + + if self.dt_fast: inv_dt = torch.asinh(inv_dt) + + # Rank of low-rank correction + assert self.H == inv_dt.size(0) + assert self.N == A.size(-1) == B.size(-1) == C.size(-1) + assert self.n_ssm == A.size(-2) == B.size(-2) # Number of independent SSMs trained + self.repeat = self.H // A.size(0) + + # Check that diagonal part has negative real and imag part + # (allow some tolerance for numerical precision on real part + # since it may be constructed by a diagonalization) + assert torch.all(A.real < 1e-4) and torch.all(A.imag <= 0.0) + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) # TODO originally this was only in DPLR, check safe for Diag + B = B.unsqueeze(0) # (1, H, N) + assert self.channels == C.shape[0] + + # Register dt + self.register("inv_dt", inv_dt, self.lr_dict['dt'], self.wd_dict['dt']) + # Register ABC + if self.is_real: + self.register("C", C.real, self.lr_dict['C'], None) + self.register("B", B.real, self.lr_dict['B'], self.wd_dict['B']) + self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) + else: + self.register("C", _c2r(_resolve_conj(C)), self.lr_dict['C'], None) + self.register("B", _c2r(B), self.lr_dict['B'], self.wd_dict['B']) + self.register("A_real", inv_transform(-A.real, self.real_transform), self.lr_dict['A'], self.wd_dict['A']) + self.register("A_imag", inv_transform(-A.imag, self.imag_transform), self.lr_dict['A'], self.wd_dict['A']) + + def _get_params(self, rate=1.0): + """Process the internal parameters.""" + + # (S N) where S=n_ssm + if self.is_real: + A = -param_transform(self.A_real, self.real_transform) + B = self.B # (1 S N) + C = self.C # (C H N) + else: + A = -param_transform(self.A_real, self.real_transform) - 1j * param_transform(self.A_imag, self.imag_transform) + B = _r2c(self.B) # (1 S N) + C = _r2c(self.C) # (C H N) + + if self.dt_fast: inv_dt = torch.sinh(self.inv_dt) + else: inv_dt = self.inv_dt + dt = param_transform(inv_dt, self.dt_transform) * rate # (H N) + + if self.bandlimit is not None: + freqs = dt / rate * A.imag.abs() / (2*math.pi) # (H N) + mask = torch.where(freqs < self.bandlimit * .5, 1, 0) + C = C * mask + + # Incorporate dt into A and B + A = repeat(A, 't n -> (v t) n', v=self.repeat) # (H N) + B = repeat(B, 'b t n -> b (v t) n', v=self.repeat) # (1 H N) + + # TODO: The downstream algorithm should only need to access dt*A + # However the current DPLR kernel still uses dt and A separately + # Once that is fixed, this should return dtA instead of dt and A + dtA = dt * A # (H N) + + return dt, A, B, C + + def forward(self, L, state=None, rate=1.0): + """See Kernel.forward() for argument documentation.""" + + dt, A, B, C = self._get_params(rate) + dtA = dt * A + + # Augment B with state + if state is not None: + s = state / dt + if self.disc == 'bilinear': + s = s * (1. + dtA/2) + elif self.disc == 'zoh': + s = s * dtA * dtA.exp() / (dtA.exp() - 1.) + B = torch.cat([s, B], dim=-3) # (1+B H N) + + + # Combine B and C + C = (B[:, None, :, :] * C).view(-1, self.H, self.N) + + # Dispatch which Vandermonde kernel to use + if has_cuda_extension and C.dtype == torch.cfloat and C.device.type == 'cuda' and self.backend == 'cuda': + log_vandermonde = log_vandermonde_cuda + elif has_pykeops and self.backend in ['cuda', 'keops']: + log_vandermonde = log_vandermonde_keops + else: + log_vandermonde = log_vandermonde_naive + + # Main kernel + if self.disc == 'zoh': + # Power up + C = C * (torch.exp(dtA)-1.) / A + K = log_vandermonde(C, dtA, L) # (H L) + elif self.disc == 'bilinear': + C = C * (1. - dtA/2).reciprocal() * dt # or * dtA / A + dA = (1. + dtA/2) / (1. - dtA/2) + K = log_vandermonde(C, dA.log(), L) + elif self.disc == 'dss': + # Implementation from DSS meant for case when real eigenvalues can be positive + P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] + A_gt_0 = A.real > 0 # [N] + if A_gt_0.any(): + with torch.no_grad(): + P_max = dtA * (A_gt_0 * (L-1)) # [H N] + P = P - P_max.unsqueeze(-1) # [H N L] + S = P.exp() # [H N L] + + dtA_neg = dtA * (1 - 2*A_gt_0) # [H N] + num = dtA_neg.exp() - 1 # [H N] + den = (dtA_neg * L).exp() - 1 # [H N] + + # Inline reciprocal function for DSS logic + x = den * A + x_conj = _resolve_conj(x) + r = x_conj / (x*x_conj + 1e-7) + + C = C * num * r # [C H N] + K = contract('chn,hnl->chl', C, S).float() + else: raise ValueError(f"Discretization {self.disc} not supported") + + K = K.view(-1, self.channels, self.H, L) # (1+B C H L) + + if state is not None: + K_state = K[:-1, :, :, :] # (B C H L) + else: + K_state = None + K = K[-1, :, :, :] # (C H L) + + return K, K_state + + def _setup_step(self): + """Set up dA, dB, dC discretized parameters for stepping.""" + + dt, A, B, C, = self._get_params() + # Incorporate dt into A + dtA = dt * A # (H N) + + if self.disc == 'zoh': + self.dA = torch.exp(dtA) # (H N) + self.dB = B * (torch.exp(dtA)-1.) / A # (C H N) + elif self.disc == 'bilinear': + self.dA = (1. + dtA/2) / (1. - dtA/2) + self.dB = B * (1. - dtA/2).reciprocal() * dt # or * dtA / A + self.dB = rearrange(self.dB, '1 h n -> h n') + self.dC = C + + def default_state(self, *batch_shape): + C = _r2c(self.C) + state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + next_state = contract("h n, b h n -> b h n", self.dA, state) \ + + contract("h n, b h -> b h n", self.dB, u) + y = contract("c h n, b h n -> b c h", self.dC, next_state) + return 2*y.real, next_state + + def forward_state(self, u, state): + """Pass the state forward through an entire sequence.""" + self._setup_step() + AL = self.dA ** u.size(-1) + u = u.flip(-1).to(self.dA).contiguous() # (B H L) + # Dispatch which Vandermonde kernel to use + if has_pykeops and self.backend in ['cuda', 'keops']: + log_vandermonde_transpose = log_vandermonde_transpose_keops + else: + log_vandermonde_transpose = log_vandermonde_transpose_naive + v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) + next_state = AL * state + v + return next_state + + +class SSMKernelDPLR(SSMKernelDiag): + """SSM kernel for diagonal + low rank (DPLR) state matrices, corresponding to the original S4 model.""" + + @torch.no_grad() + def _setup_C(self, L): + """Construct C~ from C. + + Two modes are supported: go directly to length L if self.l_kernel is 1, or length is doubled + """ + + if self.l_kernel.item() == 0: + if self.verbose: log.info(f"S4: Initializing kernel to length {L}") + double_length = False + elif L > self.l_kernel.item(): # 2*int(self.l_kernel) == L: + if self.verbose: log.info(f"S4: Doubling length from L = {self.l_kernel.item()} to {2*self.l_kernel.item()}") + double_length = True + L = self.l_kernel.item() # Convenience for the math below + else: return + + C = _r2c(self.C) + dA, _ = self._setup_state() + dA_L = power(L, dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., :self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + self.l_kernel = 2*self.l_kernel if double_length else self.l_kernel+L # Preserve type/device + + def _omega(self, L, dtype, device, cache=True): + """Calculate (and cache) FFT nodes. + + This also caches a version of the nodes "unprocessed" with the bilinear transform. + This method should be called everytime the internal length self.l_kernel changes. + """ + + # Use cached if available + if cache and hasattr(self, 'omega') and self.omega.size(-1) == L//2+1: + return self.omega, self.z + + omega = torch.tensor( + np.exp(-2j * np.pi / (L)), dtype=dtype, device=device + ) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + + # Cache if necessary + if cache: + self.omega = omega + self.z = z + return omega, z + + + def register_params(self, A, B, C, inv_dt, P): + """Process the initialization into form of trainable parameters. + + The SSM state matrix is represented by diag_embed(A) - PP^* + Note that the A notation here is slightly overloaded: + normally A refers to the full SSM state matrix (DPLR in this case) + but here we're using it to refer to the diagonal part of the matrix. + This is to make variable names compatible with the SSMKernelDiag class (DSS/S4D) + and is a much simpler variable name (e.g. as opposed to Lambda). + + A: (S, N) diagonal part + P: (R, S, N) low-rank part + B: (S, N) + C: (C, H, N) + dt: (H) timescale per feature + + Dimensions: + N (or d_state): state size + H (or d_model): total SSM copies + S (or n_ssm): number of trainable copies of (A, B, dt); must divide H + R (or rank): rank of low-rank part + C (or channels): system is 1-dim to C-dim + + The forward pass of this Module returns a tensor of shape (C, H, L) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + # Print out kernel lengths; it can be tricky to make sure the length logic is correct + if self.verbose: + log.info(f"Constructing S4 (H, N, L) = ({self.H}, {self.N}, {self.l_max})") + + # Register the basic params for diagonal SSM (A, B, C, dt) + super().register_params(A, B, C, inv_dt, P) + + # Check shapes + assert self.rank == P.shape[-3] + assert self.N == P.size(-1) + assert self.n_ssm == P.size(-2) + + self.register('P', _c2r(P), self.lr_dict['A'], self.wd_dict['A']) + + # Track the current kernel length this is "attuned" to + self.register_buffer('l_kernel', torch.tensor(0)) + + def _get_params(self, rate=1.0): + dt, A, B, C = super()._get_params(rate=rate) + P = _r2c(self.P) # (R S N) + P = repeat(P, 'r t n -> r (v t) n', v=self.repeat) # (R H N) + Q = P.conj() + + return dt, A, B, C, P, Q + + def forward(self, state=None, rate=1.0, L=None): + """See Kernel.forward() for argument documentation.""" + + # Initialize C~ if necessary (done in forward pass so it's on the correct device) + if self.l_kernel.item() == 0 and self.l_max is not None and self.l_max > 0: + self._setup_C(self.l_max) + + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.l_kernel, while we are asked to provide a kernel of length L at (relative) frequency rate + if L is None: + L = round(self.l_kernel.item() / rate) + + # Increase the internal length if needed + continuous_L = round(rate*L) + while continuous_L > self.l_kernel.item(): + self._setup_C(continuous_L) + discrete_L = round(self.l_kernel.item()/rate) + + dt, A, B, C, P, Q = self._get_params(rate) + + # Get FFT nodes of right length + omega, z = self._omega(discrete_L, dtype=A.dtype, device=A.device, cache=(rate==1.0)) + + # Augment B + if state is not None: + # Have to "unbilinear" the state to put it into the same "type" as B + # Compute 1/dt * (I + dt/2 A) @ state + + # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way + s = _conj(state) if state.size(-1) == self.N else state # (B H N) + sA = ( + s * _conj(A) # (B H N) + - contract('bhm, rhm, rhn -> bhn', s, _conj(Q), _conj(P)) + ) + s = s / dt + sA / 2 + s = s[..., :self.N] + + B = torch.cat([s, B], dim=-3) # (B+1, H, N) + + # Incorporate dt into A + A = A * dt # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (B+1+R, H, N) + C = torch.cat([C, Q], dim=-3) # (C+R, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N) + v = v * dt # Incorporate dt into B + + # Dispatch which Cauchy kernel to use + if has_cuda_extension and z.dtype == torch.cfloat and z.device.type == 'cuda' and self.kernel == 'cuda': + cauchy_mult = cauchy_cuda + elif has_pykeops and self.kernel in ['cuda', 'keops']: + cauchy_mult = cauchy_keops + else: + cauchy_mult = cauchy_naive + # Calculate resolvent at omega + r = cauchy_mult(v, z, A) + + # Low-rank Woodbury correction + if self.rank == 1: + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (1 + r[-1:, -1:, :, :]) + elif self.rank == 2: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[:1, 1:, :, :] * r11[1:, :1, :, :] + s = ( + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + ) + s = s / det + k_f = r00 - s + else: + r00 = r[:-self.rank, :-self.rank, :, :] + r01 = r[:-self.rank, -self.rank:, :, :] + r10 = r[-self.rank:, :-self.rank, :, :] + r11 = r[-self.rank:, -self.rank:, :, :] + r11 = rearrange(r11, "a b h n -> h n a b") + r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) + r11 = rearrange(r11, "h n a b -> a b h n") + k_f = r00 - torch.einsum("i j h n, j k h n, k l h n -> i l h n", r01, r11, r10) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L) + + # # Truncate to target length + k = k[..., :L] + + if state is not None: + k_state = k[:-1, :, :, :] # (B, C, H, L) + else: + k_state = None + k_B = k[-1, :, :, :] # (C H L) + + return k_B, k_state + + @torch.no_grad() + def double_length(self): + self._setup_C(2*self.l_kernel) + + @torch.no_grad() + def _check(self): + """Check if A, B, C parameters and vanilla SSMKernel construction can be recovered""" + + # assert self.l_kernel > 0, "Set up module first" + + K = self.forward(L=self.l_max)[0] + + self._setup_step() + K_ = krylov(self.l_max, self.dA, self.dB, self.dC) + + diff = K - K_ + print("checking DPLR Kernel construction", torch.sum(diff ** 2)) + + @torch.no_grad() + def _setup_linear(self): + """Preprocessing that allows fast linear-time (in state dimension) stepping.""" + dt, A, B, C, P, Q = self._get_params() + + # Prepare Linear stepping + D = (2.0 / dt - A).reciprocal() # (H, N) + R = (torch.eye(self.rank, dtype=A.dtype, device=A.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H R R) + Q_D = rearrange(Q*D, 'r h n -> h r n') + try: + R = torch.linalg.solve(R, Q_D) # (H R N) + except: + R = torch.tensor(np.linalg.solve(R.to(Q_D).contiguous().detach().cpu(), Q_D.contiguous().detach().cpu())).to(Q_D) + R = rearrange(R, 'h r n -> r h n') + + self.step_params = { + "D": D, # (H N) + "R": R, # (R H N) + "P": P, # (R H N) + "Q": Q, # (R H N) + "B": B, # (1 H N) + "E": 2.0 / dt + A, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. + Perhaps a fused CUDA kernel implementation would be much faster. + + u: (H) Input + state: (H, N/2) State with conjugate pairs. Optionally, the state can have last dimension N. + + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if state.size(-1) == self.N: # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', _conj(p), _conj(x), _conj(y))[..., :self.N] # inner outer product + else: + assert state.size(-1) == 2*self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', p, x, y) # inner outer product + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (R H N) + P = step_params["P"] # (R H N) + Q = step_params["Q"] # (R H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """Construct dA and dB for discretized state equation.""" + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c(self.C) # Just returns a view that we use for finding dtype/device + + state = torch.eye(2*self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + dB = rearrange(dB, '1 h n -> h n') # (H N) + return dA, dB + + def _step_state(self, u, state): + """Must be called after self.default_state() is used to construct an initial state!""" + next_state = (torch.einsum(self.state_contraction, self.dA, state) + + torch.einsum(self.input_contraction, self.dB, u)) + return next_state + + def _setup_step(self, mode='dense'): + """Set up dA, dB, dC discretized parameters for stepping.""" + self.dA, self.dB = self._setup_state() + + # Calculate original C + C = _conj(_r2c(self.C)) # (H C N) + if self.l_kernel.item() == 0: + dC = C + else: + # self.C represents C_tilde + dA_L = power(self.l_kernel.item(), self.dA) + I = torch.eye(self.dA.size(-1)).to(dA_L) + + dC = torch.linalg.solve( + I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) + self.dC = dC + + # Do special preprocessing for different step modes + + self._step_mode = mode + if mode == 'linear': + # Linear case: special step function for the state, we need to handle output + # use conjugate symmetry by default, which affects the output projection + self.dC = 2*self.dC[:, :, :self.N] + elif mode == 'diagonal': + # Eigendecomposition of the A matrix + L, V = torch.linalg.eig(self.dA) + V_inv = torch.linalg.inv(V) + # Check that the eigendedecomposition is correct + if self.verbose: + print("Diagonalization error:", torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA)) + + # Change the parameterization to diagonalize + self.dA = L + self.dB = contract('h n m, h m -> h n', V_inv, self.dB) + self.dC = contract('h n m, c h n -> c h m', V, self.dC) + + elif mode == 'dense': + pass + else: raise NotImplementedError("DPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}") + + def default_state(self, *batch_shape): + C = _r2c(self.C) + N = C.size(-1) + H = C.size(-2) + + # Cache the tensor contractions we will later do, for efficiency + # These are put in this function because they depend on the batch size + step_mode = getattr(self, "_step_mode", "dense") # Used in default_state, which is called without _setup_step() in forward_state() + if step_mode != 'linear': + N *= 2 + + if step_mode == 'diagonal': + self.state_contraction = "h n, ... h n -> ... h n" + else: + # Dense (quadratic) case: expand all terms + self.state_contraction = "h m n, ... h n -> ... h m" + + self.input_contraction = "h n, ... h -> ... h n" + + self.output_contraction = "c h n, ... h n -> ... c h" + + state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + """Must have called self._setup_step() and created state with self.default_state() before calling this.""" + + if self._step_mode == 'linear': + new_state = self._step_state_linear(u, state) + else: + new_state = self._step_state(u, state) + y = torch.einsum(self.output_contraction, self.dC, new_state) + return y.real, new_state + + def forward_state(self, *args, **kwargs): + # Dispatch directly to generic state forwarding + # instead of using the Diag version + + # TODO design pattern is ugly. Can be fixed with an intermediate + # subclass above Diag/DPLR that has the shared logic (parameter construction) + # but not the state/step logic. + # Fine to keep like this for now since we want Diag to be the standard + # instead of having too many layers of subclassing. + + return SSMKernel.forward_state(self, *args, **kwargs) + +kernel_registry = { + 's4d': SSMKernelDiag, + 'diag': SSMKernelDiag, + 's4': SSMKernelDPLR, + 'nplr': SSMKernelDPLR, + 'dplr': SSMKernelDPLR, +} + +class FFTConv(nn.Module): + """Implements an FFT Convolution around a convolution kernel. + + d_model (H): Model dimension (in CNN terminology, this would be "channels"). + l_max (L): The maximum kernel length. Set l_max=None to always use a global kernel. + channels: Can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this; instead, increase d_model for larger models. + bidirectional: If True, convolution kernel will be two-sided. + activation: Activation after the full convolution. + transposed, dropout, tie_dropout: More general model options, see SequenceModule. + mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D. Other options can be found in the kernel registry. + + kernel_args: See the class .kernel.SSMKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode", "init", "dt_min", "dt_max", "lr" + """ + + def __init__( + self, + d_model, + l_max=None, + channels=1, + swap_channels=False, + bidirectional=False, + activation='gelu', # Activation after layer + transposed=True, + dropout=0.0, + tie_dropout=False, + drop_kernel=0.0, + mode='dplr', + kernel=None, + **kernel_args, # Arguments passed into inner convolution kernel + ): + super().__init__() + self.d_model = d_model + self.L = self.l_max = l_max + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + self.swap_channels = swap_channels + + + if activation is not None and activation.startswith('glu'): + channels *= 2 + self.activation = Activation(activation, dim=1 if self.transposed else -1) + + self.D = nn.Parameter(torch.randn(channels, self.d_model)) + + if self.bidirectional: + channels *= 2 + + # Inner convolution kernel + if mode is not None: + assert kernel is None, "Pass either mode or kernel but not both" + # log.info( + # "Argument 'mode' is deprecated and renamed to 'kernel'," + # "and will be removed in a future version." + # ) + kernel, mode = mode, kernel + kernel_cls = kernel_registry[kernel] + self.kernel = kernel_cls( + d_model=self.d_model, + l_max=self.l_max, + channels=channels, + **kernel_args, + ) + + dropout_fn = DropoutNd if tie_dropout else nn.Dropout + self.drop = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + self.drop_kernel = nn.Dropout(drop_kernel) if drop_kernel > 0.0 else nn.Identity() + + def forward(self, x, state=None, rate=1.0, **kwargs): # absorbs return_output and transformer src mask + """ + x: (B D L) if self.transposed else (B L D) + """ + + # Always work with (B D L) dimension in this module + if not self.transposed: x = x.transpose(-1, -2) + L = x.size(-1) + + # Compute SS Kernel + l_kernel = L if self.L is None else min(L, round(self.L / rate)) + k, k_state = self.kernel(L=l_kernel, rate=rate, state=state) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) + # The above has an off-by-one in the reverse direction + # This is a deliberate choice since the off-by-one should not affect any applications + # This can be amended which may be very slightly slower + # k = F.pad(k0, (0, L)) \ + # + F.pad(k1[..., 1:].flip(-1), (L+1, 0)) \ + # + F.pad(k1[..., :1], (0, l_kernel+L-1)) + + # Kernel dropout + k = self.drop_kernel(k) + + # In principle, we could pad to l_kernel+L-1 instead of l_kernel+L, but we choose the latter for + # equational simplicity. Additionally, we have not experimented to compare the efficiency of the two. + k_f = torch.fft.rfft(k, n=l_kernel+L) # (C H L) + x_f = torch.fft.rfft(x, n=l_kernel+L) # (B H L) + y_f = contract('bhl,chl->bchl', x_f, k_f) + y = torch.fft.irfft(y_f, n=l_kernel+L)[..., :L] # (B C H L) + + + # Compute D term in state space equation - essentially a skip connection + y = y + contract('bhl,ch->bchl', x, self.D) + + # Compute state update + if state is not None: + assert not self.bidirectional, "Bidirectional not supported with state forwarding" + y = y + k_state # + next_state = self.kernel.forward_state(x, state) + else: + next_state = None + + + # Reshape to flatten channels + if self.swap_channels: + y = rearrange(y, 'b c h l -> b (h c) l') + else: + y = rearrange(y, 'b c h l -> b (c h) l') + + y = self.drop(y) # DropoutNd better with transposed=True + + if not self.transposed: y = y.transpose(-1, -2) + y = self.activation(y) + + return y, next_state + + + def setup_step(self, **kwargs): + self.kernel._setup_step(**kwargs) + + def step(self, x, state): + """ Step one time step as a recurrent model. Intended to be used during validation. + + x: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + + y, next_state = self.kernel.step(x, state) # (B C H) + y = y + x.unsqueeze(-2) * self.D + y = rearrange(y, 'b c h -> b (c h)') + y = self.activation(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + # kernel is not a SequenceModule so it doesn't need to adhere to same interface + # the kernel will know the device of its own parameters + return self.kernel.default_state(*batch_shape) + + @property + def d_output(self): + return self.d_model * self.channels + + +class S4Block(nn.Module): + """General block design wrapping an inner layer. Currently only layer=FFTConv is supported, but easy to incorporate others. + + Arguments: + - bottleneck: Reduce dimension of inner layer (e.g. used in GSS). + - gate: Add multiplicative gating (e.g. used in GSS), which is essentially a multiplicative instead of additive residual branch. + - gate_act: Activation function to apply on the gate residual branch. + - mult_act: Activation function to apply after gate multiplication (e.g. GELU in GSS). + - final_act: Activation function to apply after final linear layer. 'id' for no activation, None for no linear layer at all. + + - initializer: Initializer on final linear layer. + - weight_norm: Weight normalization on final linear layer. + - dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d + + - transposed: Choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=model dimension] + + Other options are all experimental and should not need to be configured. + """ + + def __init__( + self, + d_model, + bottleneck=None, + gate=None, + gate_act=None, + mult_act=None, + final_act='glu', + postact=None, + initializer=None, + weight_norm=False, + dropout=0.0, + tie_dropout=False, + transposed=True, + **layer_args, # Arguments into inner layer (e.g. FFTConv) + ): + super().__init__() + + self.d_model = d_model + self.transposed = transposed + + self.gate = gate + self.bottleneck = bottleneck + + if bottleneck is not None: + self.d_model = self.d_model // bottleneck + self.input_linear = LinearActivation( + self.d_model, + self.d_model, + transposed=False, + activation=None, + activate=False, + ) + + if gate is not None: + self.input_gate = LinearActivation( + self.d_model, + self.d_model * gate, + transposed=False, + activation=gate_act, + activate=True, + ) + if self.layer.d_output != self.d_model * gate: + self.output_gate = LinearActivation( + self.d_model*self.channels, + self.d_model * gate, + transposed=False, + activation=None, + activate=False, + ) + + # Currently this module only uses FFTConv for its inner module + # But the options here are all agnostic to the inner block + # If other types of inner layers are desired, it is easy + # to add an option to swap a different module in + self.layer = FFTConv(d_model, transposed=False, dropout=dropout, tie_dropout=tie_dropout, **layer_args) + + # Pointwise operations + + # Activation after (optional) multiplication by gate branch + self.mult_activation = Activation(mult_act) + # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11 + dropout_fn = partial(DropoutNd, transposed=False) if tie_dropout else nn.Dropout + self.drop = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + if postact is not None: + assert final_act is None + log.warning("Warning: 'postact' option changed to 'final_act' and will be removed in a future version.") + final_act, postact = postact, final_act + if final_act is None: + self.output_linear = nn.Identity() + else: + self.output_linear = LinearActivation( + self.d_model*gate if gate is not None else self.layer.d_output, + self.d_model, + transposed=False, + activation=final_act, + activate=True, + ) + + + + def forward(self, x, lengths=None, **kwargs): # absorbs return_output and transformer src mask + """ + x: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as x + """ + if self.transposed: x = rearrange(x, 'b d ... -> b ... d') + L = x.size(1) + + # Mask out padding tokens + # TODO handle option for mask - instead of lengths, which assumes suffix padding + if isinstance(lengths, int): + if lengths != L: + lengths = torch.tensor(lengths, dtype=torch.long, device=x.device) + else: + lengths = None + if lengths is not None: + assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, x.size(0)] + mask = torch.where(torch.arange(L, device=lengths.device)[:, None] < lengths[:, None, None], 1., 0.) + x = x * mask + + if self.gate is not None: + v = self.input_gate(x) + if self.bottleneck is not None: + x = self.input_linear(x) + + y, state = self.layer(x, **kwargs) + + + if self.gate is not None: + y = self.output_gate(y) + y = y * v + y = self.mult_activation(y) + y = self.drop(y) + y = self.output_linear(y) + + if self.transposed: y = rearrange(y, 'b d ... -> b ... d') + + return y, state + + def setup_step(self, **kwargs): + self.layer.setup_step(**kwargs) + + def step(self, x, state): + """Step one time step as a recurrent model. Intended to be used during validation. + + x: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + + if self.gate is not None: + v = self.input_gate(x) + if self.bottleneck is not None: + x = self.input_linear(x) + y, next_state = self.layer.step(x, state) # (B C H) + if self.gate is not None: + y = self.output_gate(y) + y = y * v + y = self.mult_activation(y) + y = self.drop(y) + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + # kernel is not a SequenceModule so it doesn't need to adhere to same interface + # the kernel will know the device of its own parameters + return self.layer.default_state(*batch_shape) + + @property + def d_output(self): + return self.d_model \ No newline at end of file diff --git a/models/SwanDNA.py b/models/SwanDNA.py new file mode 100644 index 0000000..589c737 --- /dev/null +++ b/models/SwanDNA.py @@ -0,0 +1,689 @@ +import math +import torch +import numpy as np +import torch.nn as nn +from flash_pytorch import FLASH, FLASHTransformer +from torch import Tensor + +class GEGLU(nn.Module): + """ + References: + Shazeer et al., "GLU Variants Improve Transformer," 2020. + https://arxiv.org/abs/2002.05202 + """ + + def geglu(self, x: Tensor) -> Tensor: + assert x.shape[-1] % 2 == 0 + a, b = x.chunk(2, dim=-1) + return a * F.gelu(b) + + def forward(self, x: Tensor) -> Tensor: + + return self.geglu(x) + +class Mlp(nn.Module): + """ + A simple Multi-Layer Perceptron (MLP) module in PyTorch. + + Args: + in_features (int): Number of input features. + hidden_features (int, optional): Hidden layer size. Defaults to in_features. + out_features (int, optional): Output size. Defaults to in_features. + act_layer (nn.Module, optional): Activation function. Defaults to nn.GELU. + drop (float, optional): Dropout probability. Defaults to 0. + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class BatchNorm(nn.Module): + """ + A PyTorch module implementing 1D Batch Normalization for token embeddings. + + Args: + embedding_size (int): The size of the token embeddings. + """ + def __init__(self, embedding_size): + super().__init__() + self.bn = nn.BatchNorm1d(embedding_size) + + def forward(self, x): + x = torch.permute(x, (0, 2, 1)) + x = self.bn(x) + x = torch.permute(x, (0, 2, 1)) + return x + + +class GroupNorm(nn.Module): + """ + A PyTorch module implementing Group Normalization for token embeddings. + + Args: + embedding_size (int): The size of the token embeddings. + n_groups (int): The number of groups to divide the channels into. + """ + def __init__(self, embedding_size, n_groups): + super().__init__() + self.gn = nn.GroupNorm(n_groups, embedding_size) + + def forward(self, x): + x = torch.permute(x, (0, 2, 1)) + x = self.gn(x) + x = torch.permute(x, (0, 2, 1)) + return x + + +def map_norm(norm_type, embedding_size, group_size=None): + """ + Maps the given normalization type to the corresponding PyTorch module. + + Args: + norm_type (str): The normalization type ('LN', 'BN', 'GN', or None). + embedding_size (int): The size of the token embeddings. + group_size (int, optional): The number of groups for Group Normalization. + + Returns: + nn.Module: The corresponding normalization module. + """ + if norm_type == 'LN': + norm = nn.LayerNorm(embedding_size) + elif norm_type == 'BN': + norm = BatchNorm(embedding_size) + elif norm_type == 'GN': + norm = GroupNorm(embedding_size, group_size) + elif norm_type == 'None': + norm = nn.Identity() + return norm + + +class CircularShift(nn.Module): + """ + A PyTorch module that performs a parameter-free shift of groups within token embeddings. + + This module can be used to augment or modify the input data in a data-driven manner. The shift is + performed jointly for all sequences in a batch and is based on powers of 2. + + Args: + group_size (int): The size of groups to be shifted. + """ + def __init__(self, group_size): + super().__init__() + self.group_size = group_size + + def forward(self, x): + y = torch.split( + tensor=x, + split_size_or_sections=self.group_size, + dim=-1 + ) + + # Roll sequences in a batch jointly + # The first group remains unchanged + z = [y[0]] + for i in range(1, len(y)): + offset = - 2 ** (i - 1) + z.append(torch.roll(y[i], shifts=offset, dims=1)) + + z = torch.cat(z, -1) + return z + + +class SwanDNABlock(nn.Module): + """ + A PyTorch module implementing the SwanDNABlock. + + This module combines two main steps in the SwanDNA layer: circular-shift and column_transform. + The dropout between too is added. + + Args: + embedding_size (int): The size of the token embeddings. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLP. + mlp_dropout (float): The dropout probability for the MLP. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + + def __init__( + self, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm + ): + super().__init__() + self.prenorm = map_norm(prenorm, embedding_size, group_size) + self.norm = map_norm(norm, embedding_size, group_size) + + self.column_transform = Mlp( + embedding_size, + hidden_size, + embedding_size, + act_layer=nn.GELU, + drop=mlp_dropout + ) + + self.dropout = nn.Dropout(layer_dropout) + self.shift = CircularShift(group_size) + + def forward(self, x): + res_con = x + x = self.prenorm(x) + x = self.column_transform(x) + x = self.dropout(x) + x = self.shift(x) + x = x + res_con + # x = self.norm(self.shift(self.dropout(self.column_transform(self.prenorm(x)))) + res_con) + return x + + +class SwanDNAEncoder(nn.Module): + """ + A PyTorch module implementing a SwanDNA Encoder as a stack of SwanDNA layers. + The number of layers in the stack is determined by the maximum sequence length in the batch. + The number of layers is fixed for the equal lengths mode. + + Args: + max_len (int): The maximum sequence length of the input tensor. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLP. + mlp_dropout (float): The dropout probability for the MLP. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__( + self, + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm + ): + super().__init__() + self.max_len = max_len + self.max_n_layers = math.ceil(np.log2(max_len)) + self.SwanDNA_blocks = nn.ModuleList( + [ + SwanDNABlock( + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm + ) + for _ in range(self.max_n_layers) + ] + ) + + def forward(self, x): + # If var_len, use a variable number of layers + + for layer in range(self.max_n_layers): + x = self.SwanDNA_blocks[layer](x) + return x + + +class SwanDNANetwork(nn.Module): + """ + A PyTorch module implementing a SwanDNA Encoder as a stack of SwanDNA layers. + The number of layers in the stack is determined by the maximum sequence length in the batch. + The number of layers is fixed for the equal lengths mode. + + Args: + max_len (int): The maximum sequence length of the input tensor. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLP. + mlp_dropout (float): The dropout probability for the MLP. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__( + self, + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + block_num + ): + super().__init__() + self.blocks = block_num + self.SwanDNA_blocks = nn.ModuleList( + [ + SwanDNAEncoder(max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm) + for _ in range(self.blocks) + ] + ) + + def forward(self, x): + # If var_len, use a variable number of layers + + for block in range(self.blocks): + x = self.SwanDNA_blocks[block](x) + return x + + +class Classifier(nn.Module): + """ + The SwanDNA model. Encoder is a stack of SwanDNA blocks. Decoder a global average pooling, followed by a linear layer. + + Args: + input_size (int): The input size of the embedding layer. + output_size (int): The output size of the decoder layer. + decoder (str): The type of decoder layer. We use 'linear'. + max_len (int): The maximum sequence length in the data. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLPs. + mlp_dropout (float): The dropout probability for the MLPs. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__(self, + input_size, + output_size, + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + coeff + ): + super().__init__() + self.max_n_layers = math.ceil(np.log2(max_len)) + self.embedding = nn.Linear( + input_size, + embedding_size + ).apply(self._init_weights) + + self.encoder = SwanDNAEncoder( + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm + ).apply(self._init_weights) + + self.cm_clf = SwanDNAEncoder( + max_len, + embedding_size, + group_size, + int(hidden_size*coeff), + mlp_dropout, + layer_dropout, + prenorm, + norm + ).apply(self._init_weights) + self.decoder = nn.Linear(embedding_size, output_size) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1.0) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + + def forward(self, x1, x2, idx_linear): + x1, x2 = x1.float(), x2.float() + x1, x2 = x1.permute(0, 2, 1), x2.permute(0, 2, 1) + x1, x2 = self.embedding(x1), self.embedding(x2) + y1 = self.encoder(x1) + y2 = self.encoder(x2) + y = y1 - y2 + y = self.cm_clf(y) + y = torch.mean(y, dim=1) + y = self.decoder(y) + idx_linear = idx_linear.unsqueeze(0).t().type(torch.int64) + y = torch.gather(y, 1, idx_linear) + return y + + +class GB_Classifier(nn.Module): + """ + The SwanDNA model. Encoder is a stack of SwanDNA blocks. Decoder a global average pooling, followed by a linear layer. + + Args: + input_size (int): The input size of the embedding layer. + output_size (int): The output size of the decoder layer. + max_len (int): The maximum sequence length in the data. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLPs. + mlp_dropout (float): The dropout probability for the MLPs. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__(self, + input_size, + output_size, + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + coeff + ): + super().__init__() + self.max_n_layers = math.ceil(np.log2(max_len)) + self.group_size = group_size + self.embedding_size = embedding_size + self.embedding = nn.Linear( + input_size, + self.embedding_size + ).apply(self._init_weights) + + self.encoder = SwanDNAEncoder( + max_len, + self.embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm + )#.apply(self._init_weights) + + self.cm_clf = SwanDNAEncoder( + max_len, + self.embedding_size, + group_size, + int(hidden_size*coeff), + mlp_dropout, + layer_dropout, + prenorm, + norm + ).apply(self._init_weights) + + self.decoder = nn.Linear(self.embedding_size, output_size) + self.freeze_encoder() + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1.0) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + + def forward(self, x): + x = x.float() + x = self.embedding(x) + x = self.encoder(x) + x = self.cm_clf(x) + x = torch.mean(x, dim=1) + x = self.decoder(x) + return x + + +class GB_Linear_Classifier(nn.Module): + """ + The SwanDNA model. Encoder is a stack of SwanDNA blocks. Decoder a global average pooling, followed by a linear layer. + + Args: + input_size (int): The input size of the embedding layer. + output_size (int): The output size of the decoder layer. + max_len (int): The maximum sequence length in the data. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLPs. + mlp_dropout (float): The dropout probability for the MLPs. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__(self, + input_size, + output_size, + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + block_num + ): + super().__init__() + self.max_n_layers = math.ceil(np.log2(max_len)) + self.group_size = group_size + self.embedding_size = embedding_size + self.embedding = nn.Linear( + input_size, + self.embedding_size + ) + + self.encoder = SwanDNANetwork( + max_len, + self.embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + block_num=block_num + ) + + self.decoder = nn.Linear(self.embedding_size, output_size).apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1.0) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + + def forward(self, x): + x = x.float() + x = self.embedding(x) + x = self.encoder(x) + x = torch.mean(x, dim=1) + x = self.decoder(x) + return x + + +class GB_CM_Classifier(nn.Module): + """ + The SwanDNA model. Encoder is a stack of SwanDNA blocks. Decoder a global average pooling, followed by a linear layer. + + Args: + input_size (int): The input size of the embedding layer. + output_size (int): The output size of the decoder layer. + max_len (int): The maximum sequence length in the data. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLPs. + mlp_dropout (float): The dropout probability for the MLPs. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__(self, + input_size, + output_size, + max_len, + embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + coeff + ): + super().__init__() + self.max_n_layers = math.ceil(np.log2(max_len)) + self.group_size = group_size + self.embedding_size = embedding_size + self.embedding = nn.Linear( + input_size, + self.embedding_size + ).apply(self._init_weights) + + self.encoder = SwanDNANetwork( + max_len, + self.embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + block_num=5 + ) + + self.decoder = nn.Linear(self.embedding_size, output_size) + self.freeze_encoder() + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1.0) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + + def forward(self, x): + x = x.float() + x = self.embedding(x) + x = self.encoder(x) + x = torch.mean(x, dim=1) + x = self.decoder(x) + return x + + +class GB_Flash_Classifier(nn.Module): + def __init__(self, + input_size, + output_size, + max_len, + embedding_size, + group_size + ): + super().__init__() + self.max_n_layers = 8 + self.group_size = group_size + self.embedding_size = embedding_size + self.embedding = nn.Linear( + input_size, + self.embedding_size + ) + self.max_len = max_len + + self.pos_enc = nn.Embedding(max_len, embedding_size) + + self.encoder = nn.ModuleList( + [ + FLASH( + dim = embedding_size, + group_size = group_size, # group size + causal = True, # autoregressive or not + query_key_dim = int(embedding_size/4), # query / key dimension + expansion_factor = 2., # hidden dimension = dim * expansion_factor + laplace_attn_fn = True # new Mega paper claims this is more stable than relu squared as attention function + ) + for _ in range(self.max_n_layers) + ] + ) + + self.decoder = nn.Linear(self.embedding_size, output_size) + # self.freeze_encoder() + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + + def forward(self, x): + x = x.float() + print(x.shape) + positions = torch.arange(0, self.max_len).expand(x.size(0), self.max_len).cuda() + x = self.embedding(x) + pos_enc = self.pos_enc(positions) + x = pos_enc + x + print(x.shape) + for layer in range(self.max_n_layers): + x = self.encoder[layer](x) + x = torch.mean(x, dim=1) + x = self.decoder(x) + return x \ No newline at end of file diff --git a/models/__pycache__/DNASwan.cpython-39.pyc b/models/__pycache__/DNASwan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02fe3c65bf4e65ab5bc3aa0884481e98d51d0c29 GIT binary patch literal 11685 zcmeHN+m9SqTCZDIUuSHO9mi&q-69)e_Y#jcCToEe0={IOEOryt*$o;9m71>VnHqOj zwNF*so}sN3j3rwu{IC)NY4?Gi7oK51@TO^#(|w@Wd0#Z3(~MIn~u&<92LA zvh!7aob%PWeCPLFPI*g94Gq`EU;dT; z*L3-8teN;X`uf0HvvlpM=9|8CU-PZN+N$iDYc?nq-v-6bDOFIaehrjbPO14X_;tT= zUtg=Etm(H<*2>Eope*^zpe*NDp>^7NW<*h!`T2{Z<(0gNB+xj*HN)32Lc=%6$2WV{1Ffg~mS4GFU$cGNuin?z zst+_vTdRS8Eiq~A(DTsE#7SfDEMBM>#-p^L$S#z) z{pu)XAD)4a*G*mZ($i72dR>mZ#Oh)<4R3qwr_&onc!;BN*U;?hJJ-IYy5);_=$!8L zV=panxdJ-6>Dt&F z9v%fqx@4CgxZl8+*oBbFHzptKq*Qsm9opR2xs#ab-_ONzK=8X#0AnUC%1X zXc&mB=DJ}NrmiatyePA;g;C&%%i82xG&w-gL~h>@>`|uQ z6~|a0lNU&-_F{3z6aG)(A!#9DucyDACi|S%z*=FEB5x45Zq{&JwO^5Mx$doz*Uw9; zuItC$={i~zQaLT+1Yft5&viHG_fcoonH7`@_DFChD0vB!nr<~t)FrKp$I``$ypq#M z=7CCrli{29YipJSDQ&GH!N~?V*;)OPmv%Qhu^8<9QI1bJ;uIJ}v>F6aiq&@BxZ=nL z4qJHra3|lEu>eHe3?e501o=LyCXXQ&(XS|W2@F%<`d~AVKtbqAk$;>O+b$FnjyvIc z&9ViYnW8GMB=m8t>Kv9nSz zda^{b@@2Mq4M{5H*w^y-yA(@oU$lYLP3uR21jtYgG zmqCsC;qB0eh@-f>=|vGh6RI_hmGbFMr;5|Xe4%om!V@|k>CHP#0TYMRBL7%uC#3?4+Q%{uRizqxG$FqLrax92^3+0q^ z@HKMCL9_}vC9O2s(}}*!iX9|l*tRzfaZK$?Xc#j(51IkmYbmCFmy}mQF+sVIN{zp3 zfOZ%(3$$;brM0KQE-(SL6>k?iSW5AA4`M6&Gs-6 zx;JMyUoyrP?hY)ZBbT+jhx3tsGk~!h<<>9x7sjF*X16q1e3h|+(;Nxd#CmutmYKPV%y5<#r&#$)R+x1vsgAKe$t?*4ZNYtf_EuZ zL{6<$0}R&y#_GSOi#I{-vS|L|-sknakhyd@6x~tZ6E`-)Ug~{3mAQ)t#0T=_!4?h# zCVh}No-_0WtQjbjGd&SN9kaRA$6Uq4xf7cSP0+}3H+I9}vC=+?T&p1^k)#&-gda7^bGZl2j!C%YB+ zIc8HmPcP8fK>S~Lpns+~_h#RYrCPJ<|kMRq(H!sU{1 z;KssD0_n#kndyfEY2}EYK?lXFNZLkb%Nb7K+Ju9oU6XnO-Vs^I`N=B6i#CGH><3Y1 zd&6N6`I#l+em^t2UfQ;mwtErWGJQMKcQQT9?6}uU;JM6fnZi&SpkhKRO4BXf;QeE} z*3ehrue@ZyY@Os{T-nw$Ja|4dnrGDJK>3s^3 zdj42{X>74+2FO3mY0Q4X%ft{qfUxYUZ@Sn2+%GIHCJm_dVaYOUn zx!-)yOHt&x7>R@+&J4mfGu1ih58df-RVX5V`II;pgjmc$21K5pM|pKQ5MDzOPzgD6 zMu9XEFbkv~@Z&HTG&oZ(p%l!<_nh3~B>hZww>{xvBRi+pto$DP2mc5eLPg zx84uVpw33R30ahpr77m;C@kF^wV5>L>A*-7{x&i*6o%<$GdQzr0*y^PRf{JptsoCW zF15F+fIj_Eh2fdhp49VDO(Sg*AVF;r$J=nx48a(Oj4rg7<`A7)9bTH6dGT|s!`?_s zL%htw!w0vDV8|h><}xDj0+>fd3Ym z5xuFX&eL`h?=krX6OLu-%@E($s<4sWZX7Mb^*=#{ucv&X<$zU^sj@( zcgQRU^dqg4g0wE~(#4tJ^}K+-ocHIQsEbIN@GgMSGeEjG1#~BG=y(W&y zAU$im)en=j8xMwW{}2_W|6wq@oV z;K<}?<)Gw{sDY76llctI=FWLxl(;PEixqu*+b3;kru0f}En6S0P5UMpo4iK~lr(%r4!6X*0{_g zp|BSQ;;(-QriB)nY8QC+L4UPSLA2^0S3zpMdFCcH9I9;eaTM2gyxY)s8?ZWI{UOlx zIy6MT4`*F!9?t;{r(a5gEz}Sa?`MIg+|@EbdZF+qro1G2A%Rx{i}EoJid;IG z@bkXQ#hvMvg7ag+x2c_7)X05@)29+dT$11v10$!@5Kqko-=;5GV6sn(R3DrVWvr#( zG3(TGD9+rwV-#lA{wJ+4bL)hQGH1``N4d}DKh3d_v-5w_fhq-BpW;e!pMqhcwVUfl1k)!G17P$ck+ECP>Fh~ZyCFWcr!J(^esw& zzGXA~hySXtjdbw`%xz(nxkAx;-ll;zL)P7@Z8d2_`Q{@-wt^3xw?Y|ev`%Vy+t>28 zOK4l-s`yNneS>eN6zp3a!LDxL#`TZC1nbEf1YVSxN$T$(!m1=;1h)p^VdZ>N4pczV zq7Y<h7wPV}eW8Q9ol_?YLQDZubx8#ZGec5J~Epby@{@;3Bi_=op>Vt%3KCH666 z6IxlV7LReU%jBmiFkl+@bRQvK*SOV?vt@L?zHbVTy=)XA|0G)y@Ww(q(i_6_r#-L8 z8V5Tbpz(QjoYStM%R49Mx>?>DU0H|qa53UQyGKgZpCU8XC)%DiHkcX+0>IePC+II_ zWDlRy%<_4uHj#(rKmnL!dwM}tpMY4PqZN+o5V#&!PH6XRP^>Ye^q!qo5W2ScDcU1L z(yNM&uyzfBYn_reX`oh#T$?(%bkj^N_ME*{iIA$SFQm_wYQX^@HTp)V4NRJ;xhqb<>i%? Rm7|RpSDcmeD_>i2{}WMJP`dyC literal 0 HcmV?d00001 diff --git a/models/__pycache__/Short_notebook_2.ipynb b/models/__pycache__/Short_notebook_2.ipynb new file mode 100644 index 0000000..35c6517 --- /dev/null +++ b/models/__pycache__/Short_notebook_2.ipynb @@ -0,0 +1,2035 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Short 2\n", + "\n", + "### Ensamble model of AutoGluon and Catboost\n", + "\n", + "Name: Erlend Lokna, Student ID: 528564\n", + "\n", + "Name: Johan Vik Mathisen, Student ID: 508258\n", + "\n", + "\n", + "Team name: Shaky Warriors" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from autogluon.tabular import TabularPredictor\n", + "import catboost as cb\n", + "%matplotlib inline\n", + "\n", + "pd.set_option('display.max_rows', 200)\n", + "pd.set_option('display.max_columns', 200)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + " \n", + "class DataSet:\n", + " def __init__(self):\n", + " \"\"\"\n", + " kind: observerd, estimated, train\n", + " \"\"\"\n", + "\n", + " train_a = pd.read_parquet('data/A/train_targets.parquet')\n", + " train_b = pd.read_parquet('data/B/train_targets.parquet')\n", + " train_c = pd.read_parquet('data/C/train_targets.parquet')\n", + "\n", + " # Estimated training data for each location\n", + " X_train_estimated_a = pd.read_parquet('data/A/X_train_estimated.parquet')\n", + " X_train_estimated_b = pd.read_parquet('data/B/X_train_estimated.parquet')\n", + " X_train_estimated_c = pd.read_parquet('data/C/X_train_estimated.parquet')\n", + "\n", + " # Observed training data for each location\n", + " X_train_observed_b = pd.read_parquet('data/B/X_train_observed.parquet')\n", + " X_train_observed_a = pd.read_parquet('data/A/X_train_observed.parquet')\n", + " X_train_observed_c = pd.read_parquet('data/C/X_train_observed.parquet')\n", + "\n", + " # Estimated test data for each location\n", + " X_test_estimated_b = pd.read_parquet('data/B/X_test_estimated.parquet')\n", + " X_test_estimated_a = pd.read_parquet('data/A/X_test_estimated.parquet')\n", + " X_test_estimated_c = pd.read_parquet('data/C/X_test_estimated.parquet')\n", + "\n", + " Y_train = {\n", + " 'a': train_a, \n", + " 'b':train_b, \n", + " 'c':train_c\n", + " }\n", + " X_train_estimated = {\n", + " 'a':X_train_estimated_a,\n", + " 'b':X_train_estimated_b,\n", + " 'c':X_train_estimated_c\n", + " }\n", + " X_train_observed = {\n", + " 'a':X_train_observed_a,\n", + " 'b':X_train_observed_b,\n", + " 'c':X_train_observed_c\n", + " }\n", + " X_test_estimated = {\n", + " 'a':X_test_estimated_a,\n", + " 'b':X_test_estimated_b,\n", + " 'c':X_test_estimated_c\n", + " }\n", + " self.X_train_observed = X_train_observed\n", + " self.X_train_estimated = X_train_estimated\n", + " self.X_test_estimated = X_test_estimated\n", + " self.Y_train = Y_train\n", + "\n", + " def resample_to_hourly(self):\n", + " for loc in ['a','b','c']:\n", + " self.X_train_observed[loc] = to_hourly(self.X_train_observed[loc])\n", + " self.X_train_estimated[loc] = to_hourly(self.X_train_estimated[loc])\n", + " self.X_test_estimated[loc] = to_hourly(self.X_test_estimated[loc])\n", + "\n", + "\n", + " def select_features(self, features):\n", + " \"\"\" \n", + " Reduces dim by selecting only features from \"features\"\n", + " This will remove \"date_calc\" from est.\n", + " \"\"\"\n", + " for loc in ['a','b','c']:\n", + " self.X_train_observed[loc] = self.X_train_observed[loc][features]\n", + " self.X_train_estimated[loc] = self.X_train_estimated[loc][features]\n", + " self.X_test_estimated[loc] = self.X_test_estimated[loc][features]\n", + "\n", + " def add_type(self):\n", + " \"\"\"\n", + " 0: Estimated data\n", + " 1: Observed data\n", + " \"\"\"\n", + " for loc in ['a','b','c']:\n", + " type_vec_X_tr = [1] * len(self.X_train_observed[loc])\n", + " self.X_train_observed[loc]['type'] = type_vec_X_tr\n", + "\n", + " type_vec_X_tr_e = [0] * len(self.X_train_estimated[loc])\n", + " self.X_train_estimated[loc]['type'] = type_vec_X_tr_e\n", + "\n", + " type_vec_X_te = [0] * len(self.X_test_estimated[loc])\n", + " self.X_test_estimated[loc]['type'] = type_vec_X_te\n", + "\n", + "\n", + " def add_location(self):\n", + " \"\"\"\n", + " Adds a categorical feature \"location\" equal to the input string location.\n", + " \"\"\"\n", + " for loc in ['a','b','c']:\n", + " loc_vec_X_tr = [loc] * len(self.X_train_observed[loc])\n", + " self.X_train_observed[loc]['location'] = loc_vec_X_tr\n", + "\n", + " loc_vec_X_tr_e = [loc] * len(self.X_train_estimated[loc])\n", + " self.X_train_estimated[loc]['location'] = loc_vec_X_tr_e\n", + "\n", + " loc_vec_X_te = [loc] * len(self.X_test_estimated[loc])\n", + " self.X_test_estimated[loc]['location'] = loc_vec_X_te\n", + "\n", + " def remove_nans(self, feature):\n", + " for loc in ['a','b','c']:\n", + " cols = self.X_train_observed['a'].columns\n", + " if feature in cols:\n", + " self.X_train_observed[loc] = self.X_train_observed[loc].dropna(subset = [feature], how = 'all')\n", + " self.X_train_estimated[loc] = self.X_train_estimated[loc].dropna(subset = [feature], how = 'all')\n", + " self.X_test_estimated[loc] = self.X_test_estimated[loc].dropna(subset = [feature], how = 'all')\n", + " else:\n", + " print(\"Feature not in data frame.\")\n", + "\n", + " def combine_obs_est(self):\n", + " \"\"\"\n", + " Concatinates the estimated and observed data. \n", + " Removes data_calc from est.\n", + " \"\"\"\n", + "\n", + " obs_a = self.X_train_observed['a']\n", + " est_a = self.X_train_estimated['a']\n", + "\n", + " obs_b = self.X_train_observed['b']\n", + " est_b = self.X_train_estimated['b']\n", + "\n", + " obs_c = self.X_train_observed['c']\n", + " est_c = self.X_train_estimated['c']\n", + "\n", + " self.X_train = {\n", + " 'a':pd.concat([obs_a, est_a]),\n", + " 'b':pd.concat([obs_b, est_b]),\n", + " 'c':pd.concat([obs_c, est_c])\n", + " }\n", + "\n", + " self.X_train['a'] = self.X_train['a'].reset_index(drop=True)\n", + " self.X_train['b'] = self.X_train['b'].reset_index(drop=True)\n", + " self.X_train['c'] = self.X_train['c'].reset_index(drop=True)\n", + "\n", + " self.X_train['a'], self.Y_train['a'] = match_X_Y(self.X_train['a'], self.Y_train['a'])\n", + " self.X_train['b'], self.Y_train['b'] = match_X_Y(self.X_train['b'], self.Y_train['b'])\n", + " self.X_train['c'], self.Y_train['c'] = match_X_Y(self.X_train['c'], self.Y_train['c'])\n", + " \n", + " def train_test(self):\n", + " \"\"\"\n", + " Vanilla split. \n", + " \"\"\"\n", + " X_a = self.X_train['a']\n", + " X_b = self.X_train['b']\n", + " X_c = self.X_train['c']\n", + "\n", + " y_a = self.Y_train['a']\n", + " y_b = self.Y_train['b']\n", + " y_c = self.Y_train['c']\n", + "\n", + " y_train = pd.concat([y_a, y_b, y_c])\n", + " y_train = y_train.reset_index(drop=True)\n", + "\n", + " X_train = pd.concat([X_a, X_b, X_c])\n", + " X_test = pd.concat([self.X_test_estimated['a'], self.X_test_estimated['b'],self.X_test_estimated['c']])\n", + " \n", + " return X_train, X_test, y_train\n", + "\n", + " def scale_y_train(self, k_b = 5, k_c = 6):\n", + "\n", + " self.Y_train['b'] = self.Y_train['b'] * k_b \n", + " self.Y_train['c'] = self.Y_train['c']* k_c\n", + "\n", + " def drop_bad_data(self):\n", + " for loc in ['a', 'b', 'c']:\n", + " y_ind = get_constant_indices(self.Y_train[loc])\n", + " self.Y_train[loc].drop(y_ind, errors='ignore')\n", + " self.X_train[loc].drop(y_ind, errors='ignore')\n", + "\n", + "\n", + " def cyclic_time_encoding(self):\n", + " for loc in ['a', 'b', 'c']:\n", + " for time_feature in [\"time\", \"date_forecast\"]:\n", + " if time_feature in self.X_train[loc].columns:\n", + " self.X_train[loc]['sin_hour'] = np.sin(2*np.pi*self.X_train[loc][time_feature].dt.hour/24)\n", + " self.X_train[loc]['sin_month'] = np.sin(2*np.pi*self.X_train[loc][time_feature].dt.month/12)\n", + "\n", + " self.X_train[loc]['cos_hour'] = np.cos(2*np.pi*self.X_train[loc][time_feature].dt.hour/24)\n", + " self.X_train[loc]['cos_month'] = np.cos(2*np.pi*self.X_train[loc][time_feature].dt.month/12)\n", + " if time_feature in self.X_test_estimated[loc].columns: \n", + " self.X_test_estimated[loc]['sin_hour'] = np.sin(2*np.pi*self.X_test_estimated[loc][time_feature].dt.hour/24)\n", + " self.X_test_estimated[loc]['sin_month'] = np.sin(2*np.pi*self.X_test_estimated[loc][time_feature].dt.month/12)\n", + "\n", + " self.X_test_estimated[loc]['cos_hour'] = np.cos(2*np.pi*self.X_test_estimated[loc][time_feature].dt.hour/24)\n", + " self.X_test_estimated[loc]['cos_month'] = np.cos(2*np.pi*self.X_test_estimated[loc][time_feature].dt.month/12)\n", + "\n", + "#Helper functions\n", + "\n", + "def match_X_Y(X,Y):\n", + " \"\"\" \n", + " date_forecast and time must be unique!\n", + " Matches the timestamps of X to the timestamps of Y. \n", + " Makes sure that the length of X and Y are equal.\n", + " \"\"\"\n", + " Y = Y.dropna()\n", + " X = X.rename(columns={'date_forecast': 'time'})\n", + " merge_df = Y.merge(X, on=\"time\", how='inner')\n", + " Y = merge_df['pv_measurement']\n", + " X = merge_df.drop(columns = ['pv_measurement'])\n", + " return X,Y\n", + "\n", + "def to_hourly(df):\n", + " df['date_forecast']\n", + " df.set_index('date_forecast', inplace=True)\n", + " df = df.resample('H').mean()\n", + " df.reset_index(inplace=True)\n", + " return df\n", + "\n", + "def make_categorical(data, feature_list):\n", + " for feature in feature_list:\n", + " data[feature] = data[feature].astype('category')\n", + "\n", + "\n", + "def ReLU(x):\n", + " return x * (x > 0)\n", + "\n", + "def remap(x):\n", + " if x<0.5:\n", + " return 0\n", + " else:\n", + " return 1\n", + "\n", + "\n", + "def get_constant_indices(ser):\n", + " mask = (ser != 0)\n", + " constant_periods = ser[mask].groupby((ser[mask] != ser[mask].shift()).cumsum()).cumcount().add(1)\n", + " \n", + " drop_mask = constant_periods >= 12\n", + " return constant_periods[drop_mask].index" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "selected_features = ['date_forecast', 'absolute_humidity_2m:gm3',\n", + " 'clear_sky_energy_1h:J', 'clear_sky_rad:W',\n", + " 'cloud_base_agl:m', 'dew_or_rime:idx', 'dew_point_2m:K',\n", + " 'diffuse_rad:W', 'diffuse_rad_1h:J', 'direct_rad:W', 'direct_rad_1h:J',\n", + " 'effective_cloud_cover:p', 'elevation:m', 'fresh_snow_12h:cm',\n", + " 'fresh_snow_1h:cm', 'fresh_snow_24h:cm', 'fresh_snow_3h:cm',\n", + " 'fresh_snow_6h:cm', 'is_in_shadow:idx', 'is_day:idx', \n", + " 'msl_pressure:hPa', 'precip_5min:mm', 'precip_type_5min:idx',\n", + " 'pressure_100m:hPa', 'pressure_50m:hPa', 'prob_rime:p',\n", + " 'rain_water:kgm2', 'relative_humidity_1000hPa:p', 'sfc_pressure:hPa',\n", + " 'snow_depth:cm', 'snow_drift:idx',\n", + " 'snow_melt_10min:mm', 'snow_water:kgm2', 'sun_azimuth:d',\n", + " 'sun_elevation:d', 'super_cooled_liquid_water:kgm2', 't_1000hPa:K',\n", + " 'total_cloud_cover:p', 'visibility:m', 'wind_speed_10m:ms',\n", + " 'wind_speed_u_10m:ms', 'wind_speed_v_10m:ms', 'wind_speed_w_1000hPa:ms']\n", + "\n", + "made_features = ['location', 'type', 'is_day:idx', 'is_in_shadow:idx', 'dew_or_rime:idx']\n", + "\n", + "drop_feature = 'diffuse_rad:W'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "data_collection = DataSet()\n", + "data_collection.select_features(selected_features)\n", + "data_collection.resample_to_hourly()\n", + "data_collection.remove_nans(drop_feature)\n", + "data_collection.add_location()\n", + "data_collection.add_type()\n", + "data_collection.combine_obs_est()\n", + "data_collection.drop_bad_data()\n", + "data_collection.cyclic_time_encoding()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "X_a = data_collection.X_train['a']\n", + "X_b = data_collection.X_train['b']\n", + "X_c = data_collection.X_train['c']\n", + "\n", + "y_a = data_collection.Y_train['a']\n", + "y_b = data_collection.Y_train['b']\n", + "y_c = data_collection.Y_train['c']\n", + "\n", + "for f in made_features:\n", + " if f not in ['location', 'type']:\n", + " X_a[f] = X_a[f].map(remap)\n", + " X_b[f] = X_b[f].map(remap)\n", + " X_c[f] = X_c[f].map(remap)\n", + "\n", + "make_categorical(X_a,made_features)\n", + "make_categorical(X_b,made_features)\n", + "make_categorical(X_c,made_features)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "drop_cols = ['location', 'time']\n", + "\n", + "df_a = pd.concat([X_a, y_a], axis=1).drop(columns=drop_cols)\n", + "df_b = pd.concat([X_b, y_b], axis=1).drop(columns=drop_cols)\n", + "df_c = pd.concat([X_c, y_c], axis=1).drop(columns=drop_cols)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "seed = 246\n", + "\n", + "data = dict()\n", + "\n", + "# sample 50% of the data for each building with type = 0\n", + "df_a_tune = df_a[df_a['type'] == 0].sample(frac=0.5, random_state=seed)\n", + "df_b_tune = df_b[df_b['type'] == 0].sample(frac=0.5, random_state=seed) \n", + "df_c_tune = df_c[df_c['type'] == 0].sample(frac=0.5, random_state=seed)\n", + "\n", + "# drop these rows from the original data\n", + "df_a_train = df_a.drop(df_a_tune.index)\n", + "df_b_train = df_b.drop(df_b_tune.index)\n", + "df_c_train = df_c.drop(df_c_tune.index)\n", + "\n", + "data['a'] = [df_a_train, df_a_tune]\n", + "data['b'] = [df_b_train, df_b_tune]\n", + "data['c'] = [df_c_train, df_c_tune]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#3 hours (per model)\n", + "time_in_sek = 60*60*2\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No path specified. Models will be saved in: \"AutogluonModels/ag-20231116_082447/\"\n", + "Presets specified: ['best_quality']\n", + "Stack configuration (auto_stack=True): num_stack_levels=0, num_bag_folds=8, num_bag_sets=20\n", + "Beginning AutoGluon training ... Time limit = 7200s\n", + "AutoGluon will save models to \"AutogluonModels/ag-20231116_082447/\"\n", + "AutoGluon Version: 0.8.2\n", + "Python Version: 3.8.8\n", + "Operating System: Linux\n", + "Platform Machine: x86_64\n", + "Platform Version: #98~20.04.1-Ubuntu SMP Mon Oct 9 16:43:45 UTC 2023\n", + "Disk Space Avail: 35.80 GB / 339.99 GB (10.5%)\n", + "Train Data Rows: 31864\n", + "Train Data Columns: 47\n", + "Tuning Data Rows: 2197\n", + "Tuning Data Columns: 47\n", + "Label Column: pv_measurement\n", + "Preprocessing data ...\n", + "AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == float and many unique label-values observed).\n", + "\tLabel info (max, min, mean, stddev): (5733.42, 0.0, 650.65332, 1179.83452)\n", + "\tIf 'regression' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n", + "Using Feature Generators to preprocess the data ...\n", + "Fitting AutoMLPipelineFeatureGenerator...\n", + "\tAvailable Memory: 20282.92 MB\n", + "\tTrain Data (Original) Memory Usage: 6.54 MB (0.0% of available memory)\n", + "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n", + "\tStage 1 Generators:\n", + "\t\tFitting AsTypeFeatureGenerator...\n", + "\t\t\tNote: Converting 4 features to boolean dtype as they only contain 2 unique values.\n", + "\tStage 2 Generators:\n", + "\t\tFitting FillNaFeatureGenerator...\n", + "\tStage 3 Generators:\n", + "\t\tFitting IdentityFeatureGenerator...\n", + "\tStage 4 Generators:\n", + "\t\tFitting DropUniqueFeatureGenerator...\n", + "\tStage 5 Generators:\n", + "\t\tFitting DropDuplicatesFeatureGenerator...\n", + "\tUseless Original Features (Count: 2): ['elevation:m', 'snow_drift:idx']\n", + "\t\tThese features carry no predictive signal and should be manually investigated.\n", + "\t\tThis is typically a feature which has the same value for all rows.\n", + "\t\tThese features do not need to be present at inference time.\n", + "\tTypes of features in original data (raw dtype, special dtypes):\n", + "\t\t('category', []) : 4 | ['dew_or_rime:idx', 'is_in_shadow:idx', 'is_day:idx', 'type']\n", + "\t\t('float', []) : 41 | ['absolute_humidity_2m:gm3', 'clear_sky_energy_1h:J', 'clear_sky_rad:W', 'cloud_base_agl:m', 'dew_point_2m:K', ...]\n", + "\tTypes of features in processed data (raw dtype, special dtypes):\n", + "\t\t('float', []) : 41 | ['absolute_humidity_2m:gm3', 'clear_sky_energy_1h:J', 'clear_sky_rad:W', 'cloud_base_agl:m', 'dew_point_2m:K', ...]\n", + "\t\t('int', ['bool']) : 4 | ['dew_or_rime:idx', 'is_in_shadow:idx', 'is_day:idx', 'type']\n", + "\t0.1s = Fit runtime\n", + "\t45 features in original data used to generate 45 features in processed data.\n", + "\tTrain Data (Processed) Memory Usage: 6.27 MB (0.0% of available memory)\n", + "Data preprocessing and feature engineering runtime = 0.15s ...\n", + "AutoGluon will gauge predictive performance using evaluation metric: 'mean_absolute_error'\n", + "\tThis metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.\n", + "\tTo change this, specify the eval_metric parameter of Predictor()\n", + "use_bag_holdout=True, will use tuning_data as holdout (will not be used for early stopping).\n", + "User-specified model hyperparameters to be fit:\n", + "{\n", + "\t'NN_TORCH': {},\n", + "\t'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],\n", + "\t'CAT': {},\n", + "\t'XGB': {},\n", + "\t'FASTAI': {},\n", + "\t'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n", + "\t'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n", + "\t'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],\n", + "}\n", + "Fitting 11 L1 models ...\n", + "Fitting model: KNeighborsUnif_BAG_L1 ... Training model for up to 7199.85s of the 7199.84s of remaining time.\n", + "\t-170.5398\t = Validation score (-mean_absolute_error)\n", + "\t0.03s\t = Training runtime\n", + "\t17.73s\t = Validation runtime\n", + "Fitting model: KNeighborsDist_BAG_L1 ... Training model for up to 7180.33s of the 7180.33s of remaining time.\n", + "\t-170.2365\t = Validation score (-mean_absolute_error)\n", + "\t0.03s\t = Training runtime\n", + "\t18.01s\t = Validation runtime\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 7160.75s of the 7160.75s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "/home/dashuo/anaconda3/lib/python3.8/site-packages/dask/dataframe/utils.py:369: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", + " _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)\n", + "/home/dashuo/anaconda3/lib/python3.8/site-packages/dask/dataframe/utils.py:369: FutureWarning: pandas.Float64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", + " _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)\n", + "/home/dashuo/anaconda3/lib/python3.8/site-packages/dask/dataframe/utils.py:369: FutureWarning: pandas.UInt64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", + " _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)\n", + "\t-91.6585\t = Validation score (-mean_absolute_error)\n", + "\t54.59s\t = Training runtime\n", + "\t63.89s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 7087.44s of the 7087.44s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.8131\t = Validation score (-mean_absolute_error)\n", + "\t64.32s\t = Training runtime\n", + "\t38.5s\t = Validation runtime\n", + "Fitting model: RandomForestMSE_BAG_L1 ... Training model for up to 7013.5s of the 7013.49s of remaining time.\n", + "\t-112.6177\t = Validation score (-mean_absolute_error)\n", + "\t20.96s\t = Training runtime\n", + "\t0.9s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 6990.88s of the 6990.88s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-105.1783\t = Validation score (-mean_absolute_error)\n", + "\t334.18s\t = Training runtime\n", + "\t0.12s\t = Validation runtime\n", + "Fitting model: ExtraTreesMSE_BAG_L1 ... Training model for up to 6654.4s of the 6654.4s of remaining time.\n", + "\t-112.2127\t = Validation score (-mean_absolute_error)\n", + "\t5.35s\t = Training runtime\n", + "\t0.89s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 6647.39s of the 6647.39s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-109.9873\t = Validation score (-mean_absolute_error)\n", + "\t51.92s\t = Training runtime\n", + "\t0.59s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 6592.2s of the 6592.19s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-105.5894\t = Validation score (-mean_absolute_error)\n", + "\t89.45s\t = Training runtime\n", + "\t2.95s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 6497.97s of the 6497.97s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-98.0603\t = Validation score (-mean_absolute_error)\n", + "\t124.26s\t = Training runtime\n", + "\t0.39s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 6371.66s of the 6371.65s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t-96.7658\t = Validation score (-mean_absolute_error)\n", + "\t161.68s\t = Training runtime\n", + "\t84.26s\t = Validation runtime\n", + "Repeating k-fold bagging: 2/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 6184.33s of the 6184.33s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-91.5613\t = Validation score (-mean_absolute_error)\n", + "\t109.18s\t = Training runtime\n", + "\t127.3s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 6112.06s of the 6112.05s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.4679\t = Validation score (-mean_absolute_error)\n", + "\t121.69s\t = Training runtime\n", + "\t68.31s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 6040.9s of the 6040.9s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.7195\t = Validation score (-mean_absolute_error)\n", + "\t671.84s\t = Training runtime\n", + "\t0.23s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 5701.18s of the 5701.18s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-109.2595\t = Validation score (-mean_absolute_error)\n", + "\t101.17s\t = Training runtime\n", + "\t1.12s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 5649.27s of the 5649.27s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.9263\t = Validation score (-mean_absolute_error)\n", + "\t192.25s\t = Training runtime\n", + "\t9.44s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 5539.97s of the 5539.97s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-97.3511\t = Validation score (-mean_absolute_error)\n", + "\t256.21s\t = Training runtime\n", + "\t0.8s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 5405.88s of the 5405.87s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-96.7628\t = Validation score (-mean_absolute_error)\n", + "\t327.75s\t = Training runtime\n", + "\t183.65s\t = Validation runtime\n", + "Repeating k-fold bagging: 3/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 5205.24s of the 5205.24s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-91.3144\t = Validation score (-mean_absolute_error)\n", + "\t166.54s\t = Training runtime\n", + "\t200.24s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 5126.24s of the 5126.24s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.183\t = Validation score (-mean_absolute_error)\n", + "\t189.13s\t = Training runtime\n", + "\t114.32s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 5042.86s of the 5042.86s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.652\t = Validation score (-mean_absolute_error)\n", + "\t1004.69s\t = Training runtime\n", + "\t0.34s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 4707.7s of the 4707.7s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-109.0019\t = Validation score (-mean_absolute_error)\n", + "\t150.13s\t = Training runtime\n", + "\t1.65s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 4655.55s of the 4655.55s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.6682\t = Validation score (-mean_absolute_error)\n", + "\t278.94s\t = Training runtime\n", + "\t12.64s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 4562.22s of the 4562.21s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-97.2953\t = Validation score (-mean_absolute_error)\n", + "\t385.99s\t = Training runtime\n", + "\t1.18s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 4430.12s of the 4430.12s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-96.8981\t = Validation score (-mean_absolute_error)\n", + "\t480.79s\t = Training runtime\n", + "\t264.39s\t = Validation runtime\n", + "Repeating k-fold bagging: 4/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 4240.05s of the 4240.05s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-91.2554\t = Validation score (-mean_absolute_error)\n", + "\t217.48s\t = Training runtime\n", + "\t256.44s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 4166.9s of the 4166.89s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.3312\t = Validation score (-mean_absolute_error)\n", + "\t248.49s\t = Training runtime\n", + "\t154.18s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 4090.09s of the 4090.09s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.6171\t = Validation score (-mean_absolute_error)\n", + "\t1326.25s\t = Training runtime\n", + "\t0.45s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 3766.47s of the 3766.46s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-108.872\t = Validation score (-mean_absolute_error)\n", + "\t197.91s\t = Training runtime\n", + "\t2.15s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 3715.34s of the 3715.34s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-103.8981\t = Validation score (-mean_absolute_error)\n", + "\t384.72s\t = Training runtime\n", + "\t24.53s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 3600.63s of the 3600.63s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-97.3585\t = Validation score (-mean_absolute_error)\n", + "\t503.8s\t = Training runtime\n", + "\t1.55s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 3480.03s of the 3480.03s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-96.9215\t = Validation score (-mean_absolute_error)\n", + "\t634.79s\t = Training runtime\n", + "\t330.64s\t = Validation runtime\n", + "Repeating k-fold bagging: 5/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 3287.36s of the 3287.35s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-91.2525\t = Validation score (-mean_absolute_error)\n", + "\t267.33s\t = Training runtime\n", + "\t317.37s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 3212.27s of the 3212.27s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.2688\t = Validation score (-mean_absolute_error)\n", + "\t306.23s\t = Training runtime\n", + "\t213.86s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 3130.89s of the 3130.89s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.6233\t = Validation score (-mean_absolute_error)\n", + "\t1646.05s\t = Training runtime\n", + "\t0.56s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 2808.58s of the 2808.57s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-108.7237\t = Validation score (-mean_absolute_error)\n", + "\t246.04s\t = Training runtime\n", + "\t2.72s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 2756.91s of the 2756.91s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-103.8802\t = Validation score (-mean_absolute_error)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t479.57s\t = Training runtime\n", + "\t31.7s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 2651.99s of the 2651.99s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-97.541\t = Validation score (-mean_absolute_error)\n", + "\t620.89s\t = Training runtime\n", + "\t1.9s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 2531.82s of the 2531.81s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-96.8442\t = Validation score (-mean_absolute_error)\n", + "\t786.3s\t = Training runtime\n", + "\t416.08s\t = Validation runtime\n", + "Repeating k-fold bagging: 6/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 2331.17s of the 2331.17s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-91.1536\t = Validation score (-mean_absolute_error)\n", + "\t317.54s\t = Training runtime\n", + "\t379.38s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 2252.57s of the 2252.57s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.2429\t = Validation score (-mean_absolute_error)\n", + "\t358.49s\t = Training runtime\n", + "\t232.25s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 2180.21s of the 2180.21s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.6122\t = Validation score (-mean_absolute_error)\n", + "\t1965.84s\t = Training runtime\n", + "\t0.66s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 1857.79s of the 1857.79s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-108.8813\t = Validation score (-mean_absolute_error)\n", + "\t294.21s\t = Training runtime\n", + "\t3.23s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 1805.89s of the 1805.89s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.07\t = Validation score (-mean_absolute_error)\n", + "\t562.96s\t = Training runtime\n", + "\t34.54s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 1713.44s of the 1713.44s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-97.6575\t = Validation score (-mean_absolute_error)\n", + "\t750.29s\t = Training runtime\n", + "\t2.27s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 1580.89s of the 1580.89s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-96.7983\t = Validation score (-mean_absolute_error)\n", + "\t936.88s\t = Training runtime\n", + "\t487.22s\t = Validation runtime\n", + "Repeating k-fold bagging: 7/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 1378.02s of the 1378.02s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-91.1458\t = Validation score (-mean_absolute_error)\n", + "\t368.14s\t = Training runtime\n", + "\t438.82s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 1297.2s of the 1297.2s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-99.1905\t = Validation score (-mean_absolute_error)\n", + "\t416.21s\t = Training runtime\n", + "\t287.19s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 1211.71s of the 1211.71s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.5463\t = Validation score (-mean_absolute_error)\n", + "\t2287.62s\t = Training runtime\n", + "\t0.77s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 887.66s of the 887.66s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-108.7659\t = Validation score (-mean_absolute_error)\n", + "\t342.38s\t = Training runtime\n", + "\t3.75s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 835.48s of the 835.48s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-104.0103\t = Validation score (-mean_absolute_error)\n", + "\t647.28s\t = Training runtime\n", + "\t37.33s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 741.58s of the 741.57s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-97.8036\t = Validation score (-mean_absolute_error)\n", + "\t868.81s\t = Training runtime\n", + "\t2.66s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 619.56s of the 619.56s of remaining time.\n", + "\tFitting 8 child models (S7F1 - S7F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-96.8534\t = Validation score (-mean_absolute_error)\n", + "\t1085.66s\t = Training runtime\n", + "\t586.59s\t = Validation runtime\n", + "Completed 7/20 k-fold bagging repeats ...\n", + "Fitting model: WeightedEnsemble_L2 ... Training model for up to 719.98s of the 403.44s of remaining time.\n", + "\t-90.2284\t = Validation score (-mean_absolute_error)\n", + "\t0.27s\t = Training runtime\n", + "\t0.01s\t = Validation runtime\n", + "AutoGluon training complete, total runtime = 6796.86s ... Best model: \"WeightedEnsemble_L2\"\n", + "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"AutogluonModels/ag-20231116_082447/\")\n" + ] + } + ], + "source": [ + "label = 'pv_measurement'\n", + "predictor_a = TabularPredictor(label=label, eval_metric='mae').fit(\n", + " train_data = data['a'][0], \n", + " time_limit = time_in_sek,\n", + " presets='best_quality',\n", + " num_bag_folds=8,\n", + " num_stack_levels=0,\n", + " tuning_data = data['a'][1],\n", + " use_bag_holdout= True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No path specified. Models will be saved in: \"AutogluonModels/ag-20231116_101804/\"\n", + "Presets specified: ['best_quality']\n", + "Stack configuration (auto_stack=True): num_stack_levels=0, num_bag_folds=8, num_bag_sets=20\n", + "Beginning AutoGluon training ... Time limit = 7200s\n", + "AutoGluon will save models to \"AutogluonModels/ag-20231116_101804/\"\n", + "AutoGluon Version: 0.8.2\n", + "Python Version: 3.8.8\n", + "Operating System: Linux\n", + "Platform Machine: x86_64\n", + "Platform Version: #98~20.04.1-Ubuntu SMP Mon Oct 9 16:43:45 UTC 2023\n", + "Disk Space Avail: 25.30 GB / 339.99 GB (7.4%)\n", + "Train Data Rows: 31019\n", + "Train Data Columns: 47\n", + "Tuning Data Rows: 1800\n", + "Tuning Data Columns: 47\n", + "Label Column: pv_measurement\n", + "Preprocessing data ...\n", + "AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == float and many unique label-values observed).\n", + "\tLabel info (max, min, mean, stddev): (1152.3, -0.0, 99.69624, 196.54802)\n", + "\tIf 'regression' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n", + "Using Feature Generators to preprocess the data ...\n", + "Fitting AutoMLPipelineFeatureGenerator...\n", + "\tAvailable Memory: 17326.02 MB\n", + "\tTrain Data (Original) Memory Usage: 6.3 MB (0.0% of available memory)\n", + "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n", + "\tStage 1 Generators:\n", + "\t\tFitting AsTypeFeatureGenerator...\n", + "\t\t\tNote: Converting 4 features to boolean dtype as they only contain 2 unique values.\n", + "\tStage 2 Generators:\n", + "\t\tFitting FillNaFeatureGenerator...\n", + "\tStage 3 Generators:\n", + "\t\tFitting IdentityFeatureGenerator...\n", + "\tStage 4 Generators:\n", + "\t\tFitting DropUniqueFeatureGenerator...\n", + "\tStage 5 Generators:\n", + "\t\tFitting DropDuplicatesFeatureGenerator...\n", + "\tUseless Original Features (Count: 1): ['elevation:m']\n", + "\t\tThese features carry no predictive signal and should be manually investigated.\n", + "\t\tThis is typically a feature which has the same value for all rows.\n", + "\t\tThese features do not need to be present at inference time.\n", + "\tTypes of features in original data (raw dtype, special dtypes):\n", + "\t\t('category', []) : 4 | ['dew_or_rime:idx', 'is_in_shadow:idx', 'is_day:idx', 'type']\n", + "\t\t('float', []) : 42 | ['absolute_humidity_2m:gm3', 'clear_sky_energy_1h:J', 'clear_sky_rad:W', 'cloud_base_agl:m', 'dew_point_2m:K', ...]\n", + "\tTypes of features in processed data (raw dtype, special dtypes):\n", + "\t\t('float', []) : 42 | ['absolute_humidity_2m:gm3', 'clear_sky_energy_1h:J', 'clear_sky_rad:W', 'cloud_base_agl:m', 'dew_point_2m:K', ...]\n", + "\t\t('int', ['bool']) : 4 | ['dew_or_rime:idx', 'is_in_shadow:idx', 'is_day:idx', 'type']\n", + "\t0.2s = Fit runtime\n", + "\t46 features in original data used to generate 46 features in processed data.\n", + "\tTrain Data (Processed) Memory Usage: 6.17 MB (0.0% of available memory)\n", + "Data preprocessing and feature engineering runtime = 0.18s ...\n", + "AutoGluon will gauge predictive performance using evaluation metric: 'mean_absolute_error'\n", + "\tThis metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.\n", + "\tTo change this, specify the eval_metric parameter of Predictor()\n", + "use_bag_holdout=True, will use tuning_data as holdout (will not be used for early stopping).\n", + "User-specified model hyperparameters to be fit:\n", + "{\n", + "\t'NN_TORCH': {},\n", + "\t'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],\n", + "\t'CAT': {},\n", + "\t'XGB': {},\n", + "\t'FASTAI': {},\n", + "\t'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n", + "\t'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n", + "\t'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],\n", + "}\n", + "Fitting 11 L1 models ...\n", + "Fitting model: KNeighborsUnif_BAG_L1 ... Training model for up to 7199.82s of the 7199.82s of remaining time.\n", + "\t-30.5762\t = Validation score (-mean_absolute_error)\n", + "\t0.03s\t = Training runtime\n", + "\t17.16s\t = Validation runtime\n", + "Fitting model: KNeighborsDist_BAG_L1 ... Training model for up to 7181.33s of the 7181.33s of remaining time.\n", + "\t-30.6245\t = Validation score (-mean_absolute_error)\n", + "\t0.03s\t = Training runtime\n", + "\t19.88s\t = Validation runtime\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 7160.23s of the 7160.23s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.5998\t = Validation score (-mean_absolute_error)\n", + "\t47.9s\t = Training runtime\n", + "\t59.59s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 7098.66s of the 7098.66s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.4243\t = Validation score (-mean_absolute_error)\n", + "\t60.54s\t = Training runtime\n", + "\t45.93s\t = Validation runtime\n", + "Fitting model: RandomForestMSE_BAG_L1 ... Training model for up to 7028.85s of the 7028.84s of remaining time.\n", + "\t-16.8377\t = Validation score (-mean_absolute_error)\n", + "\t21.1s\t = Training runtime\n", + "\t0.75s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 7006.45s of the 7006.45s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-16.506\t = Validation score (-mean_absolute_error)\n", + "\t315.8s\t = Training runtime\n", + "\t0.09s\t = Validation runtime\n", + "Fitting model: ExtraTreesMSE_BAG_L1 ... Training model for up to 6688.78s of the 6688.78s of remaining time.\n", + "\t-16.2079\t = Validation score (-mean_absolute_error)\n", + "\t4.55s\t = Training runtime\n", + "\t0.82s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 6682.83s of the 6682.83s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.436\t = Validation score (-mean_absolute_error)\n", + "\t47.4s\t = Training runtime\n", + "\t0.52s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 6633.14s of the 6633.14s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.6074\t = Validation score (-mean_absolute_error)\n", + "\t192.36s\t = Training runtime\n", + "\t68.13s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 6426.98s of the 6426.98s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.7558\t = Validation score (-mean_absolute_error)\n", + "\t195.22s\t = Training runtime\n", + "\t0.39s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 6229.67s of the 6229.67s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-14.5535\t = Validation score (-mean_absolute_error)\n", + "\t154.47s\t = Training runtime\n", + "\t59.74s\t = Validation runtime\n", + "Repeating k-fold bagging: 2/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 6057.83s of the 6057.83s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.5251\t = Validation score (-mean_absolute_error)\n", + "\t97.16s\t = Training runtime\n", + "\t121.16s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 5991.76s of the 5991.76s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3098\t = Validation score (-mean_absolute_error)\n", + "\t115.53s\t = Training runtime\n", + "\t103.05s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 5920.37s of the 5920.37s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t-16.4631\t = Validation score (-mean_absolute_error)\n", + "\t632.73s\t = Training runtime\n", + "\t0.18s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 5601.69s of the 5601.69s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3561\t = Validation score (-mean_absolute_error)\n", + "\t95.51s\t = Training runtime\n", + "\t1.08s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 5551.11s of the 5551.11s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3887\t = Validation score (-mean_absolute_error)\n", + "\t383.59s\t = Training runtime\n", + "\t131.42s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 5343.84s of the 5343.84s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.686\t = Validation score (-mean_absolute_error)\n", + "\t422.13s\t = Training runtime\n", + "\t0.76s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 5114.84s of the 5114.84s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-14.5084\t = Validation score (-mean_absolute_error)\n", + "\t302.79s\t = Training runtime\n", + "\t131.11s\t = Validation runtime\n", + "Repeating k-fold bagging: 3/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 4940.11s of the 4940.11s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.5462\t = Validation score (-mean_absolute_error)\n", + "\t145.94s\t = Training runtime\n", + "\t181.63s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 4872.58s of the 4872.57s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3066\t = Validation score (-mean_absolute_error)\n", + "\t171.31s\t = Training runtime\n", + "\t161.9s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 4798.32s of the 4798.32s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-16.4738\t = Validation score (-mean_absolute_error)\n", + "\t947.13s\t = Training runtime\n", + "\t0.28s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 4481.78s of the 4481.78s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.2079\t = Validation score (-mean_absolute_error)\n", + "\t143.69s\t = Training runtime\n", + "\t1.6s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 4430.85s of the 4430.85s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3376\t = Validation score (-mean_absolute_error)\n", + "\t575.0s\t = Training runtime\n", + "\t195.8s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 4219.55s of the 4219.55s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.8074\t = Validation score (-mean_absolute_error)\n", + "\t589.97s\t = Training runtime\n", + "\t1.16s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 4049.39s of the 4049.39s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-14.4566\t = Validation score (-mean_absolute_error)\n", + "\t451.61s\t = Training runtime\n", + "\t201.47s\t = Validation runtime\n", + "Repeating k-fold bagging: 4/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 3868.07s of the 3868.06s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.5483\t = Validation score (-mean_absolute_error)\n", + "\t194.77s\t = Training runtime\n", + "\t240.86s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 3798.24s of the 3798.24s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3225\t = Validation score (-mean_absolute_error)\n", + "\t226.79s\t = Training runtime\n", + "\t219.18s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 3721.72s of the 3721.71s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-16.4439\t = Validation score (-mean_absolute_error)\n", + "\t1263.45s\t = Training runtime\n", + "\t0.38s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 3403.08s of the 3403.08s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.1587\t = Validation score (-mean_absolute_error)\n", + "\t190.91s\t = Training runtime\n", + "\t2.12s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 3352.69s of the 3352.69s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3024\t = Validation score (-mean_absolute_error)\n", + "\t764.72s\t = Training runtime\n", + "\t260.44s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 3140.84s of the 3140.83s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.8824\t = Validation score (-mean_absolute_error)\n", + "\t774.4s\t = Training runtime\n", + "\t1.53s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 2953.53s of the 2953.52s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-14.4728\t = Validation score (-mean_absolute_error)\n", + "\t606.76s\t = Training runtime\n", + "\t267.65s\t = Validation runtime\n", + "Repeating k-fold bagging: 5/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 2764.43s of the 2764.42s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.4936\t = Validation score (-mean_absolute_error)\n", + "\t242.66s\t = Training runtime\n", + "\t299.65s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 2692.8s of the 2692.79s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.3229\t = Validation score (-mean_absolute_error)\n", + "\t283.99s\t = Training runtime\n", + "\t271.5s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 2614.12s of the 2614.12s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-16.4525\t = Validation score (-mean_absolute_error)\n", + "\t1579.92s\t = Training runtime\n", + "\t0.48s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 2295.57s of the 2295.56s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.1589\t = Validation score (-mean_absolute_error)\n", + "\t237.61s\t = Training runtime\n", + "\t2.63s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 2245.6s of the 2245.6s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.2693\t = Validation score (-mean_absolute_error)\n", + "\t954.25s\t = Training runtime\n", + "\t324.65s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 2030.22s of the 2030.21s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.9306\t = Validation score (-mean_absolute_error)\n", + "\t984.88s\t = Training runtime\n", + "\t1.92s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 1816.62s of the 1816.62s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-14.4551\t = Validation score (-mean_absolute_error)\n", + "\t754.5s\t = Training runtime\n", + "\t338.8s\t = Validation runtime\n", + "Repeating k-fold bagging: 6/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 1621.33s of the 1621.33s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.4713\t = Validation score (-mean_absolute_error)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t297.53s\t = Training runtime\n", + "\t371.61s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 1537.14s of the 1537.14s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.329\t = Validation score (-mean_absolute_error)\n", + "\t344.02s\t = Training runtime\n", + "\t343.78s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 1447.95s of the 1447.95s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-16.4714\t = Validation score (-mean_absolute_error)\n", + "\t1920.33s\t = Training runtime\n", + "\t0.59s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 1104.92s of the 1104.91s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.2013\t = Validation score (-mean_absolute_error)\n", + "\t288.37s\t = Training runtime\n", + "\t3.17s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 1049.88s of the 1049.88s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-15.2445\t = Validation score (-mean_absolute_error)\n", + "\t1159.18s\t = Training runtime\n", + "\t397.99s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 813.38s of the 813.38s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.9073\t = Validation score (-mean_absolute_error)\n", + "\t1205.42s\t = Training runtime\n", + "\t2.3s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 589.72s of the 589.72s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-14.4539\t = Validation score (-mean_absolute_error)\n", + "\t915.29s\t = Training runtime\n", + "\t416.78s\t = Validation runtime\n", + "Completed 6/20 k-fold bagging repeats ...\n", + "Fitting model: WeightedEnsemble_L2 ... Training model for up to 719.98s of the 373.7s of remaining time.\n", + "\t-12.7457\t = Validation score (-mean_absolute_error)\n", + "\t0.28s\t = Training runtime\n", + "\t0.01s\t = Validation runtime\n", + "AutoGluon training complete, total runtime = 6826.61s ... Best model: \"WeightedEnsemble_L2\"\n", + "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"AutogluonModels/ag-20231116_101804/\")\n" + ] + } + ], + "source": [ + "predictor_b = TabularPredictor(label=label, eval_metric='mae').fit(\n", + " train_data = data['b'][0], \n", + " time_limit = time_in_sek,\n", + " presets='best_quality',\n", + " num_bag_folds=8,\n", + " num_stack_levels=0,\n", + " tuning_data = data['b'][1],\n", + " use_bag_holdout=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No path specified. Models will be saved in: \"AutogluonModels/ag-20231116_121151/\"\n", + "Presets specified: ['best_quality']\n", + "Stack configuration (auto_stack=True): num_stack_levels=0, num_bag_folds=8, num_bag_sets=20\n", + "Beginning AutoGluon training ... Time limit = 7200s\n", + "AutoGluon will save models to \"AutogluonModels/ag-20231116_121151/\"\n", + "AutoGluon Version: 0.8.2\n", + "Python Version: 3.8.8\n", + "Operating System: Linux\n", + "Platform Machine: x86_64\n", + "Platform Version: #98~20.04.1-Ubuntu SMP Mon Oct 9 16:43:45 UTC 2023\n", + "Disk Space Avail: 13.29 GB / 339.99 GB (3.9%)\n", + "Train Data Rows: 24606\n", + "Train Data Columns: 47\n", + "Tuning Data Rows: 1465\n", + "Tuning Data Columns: 47\n", + "Label Column: pv_measurement\n", + "Preprocessing data ...\n", + "AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == float and label-values can't be converted to int).\n", + "\tLabel info (max, min, mean, stddev): (999.6, 0.0, 79.70535, 168.37633)\n", + "\tIf 'regression' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n", + "Using Feature Generators to preprocess the data ...\n", + "Fitting AutoMLPipelineFeatureGenerator...\n", + "\tAvailable Memory: 14196.73 MB\n", + "\tTrain Data (Original) Memory Usage: 5.01 MB (0.0% of available memory)\n", + "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n", + "\tStage 1 Generators:\n", + "\t\tFitting AsTypeFeatureGenerator...\n", + "\t\t\tNote: Converting 4 features to boolean dtype as they only contain 2 unique values.\n", + "\tStage 2 Generators:\n", + "\t\tFitting FillNaFeatureGenerator...\n", + "\tStage 3 Generators:\n", + "\t\tFitting IdentityFeatureGenerator...\n", + "\tStage 4 Generators:\n", + "\t\tFitting DropUniqueFeatureGenerator...\n", + "\tStage 5 Generators:\n", + "\t\tFitting DropDuplicatesFeatureGenerator...\n", + "\tUseless Original Features (Count: 2): ['elevation:m', 'snow_drift:idx']\n", + "\t\tThese features carry no predictive signal and should be manually investigated.\n", + "\t\tThis is typically a feature which has the same value for all rows.\n", + "\t\tThese features do not need to be present at inference time.\n", + "\tTypes of features in original data (raw dtype, special dtypes):\n", + "\t\t('category', []) : 4 | ['dew_or_rime:idx', 'is_in_shadow:idx', 'is_day:idx', 'type']\n", + "\t\t('float', []) : 41 | ['absolute_humidity_2m:gm3', 'clear_sky_energy_1h:J', 'clear_sky_rad:W', 'cloud_base_agl:m', 'dew_point_2m:K', ...]\n", + "\tTypes of features in processed data (raw dtype, special dtypes):\n", + "\t\t('float', []) : 41 | ['absolute_humidity_2m:gm3', 'clear_sky_energy_1h:J', 'clear_sky_rad:W', 'cloud_base_agl:m', 'dew_point_2m:K', ...]\n", + "\t\t('int', ['bool']) : 4 | ['dew_or_rime:idx', 'is_in_shadow:idx', 'is_day:idx', 'type']\n", + "\t0.1s = Fit runtime\n", + "\t45 features in original data used to generate 45 features in processed data.\n", + "\tTrain Data (Processed) Memory Usage: 4.8 MB (0.0% of available memory)\n", + "Data preprocessing and feature engineering runtime = 0.13s ...\n", + "AutoGluon will gauge predictive performance using evaluation metric: 'mean_absolute_error'\n", + "\tThis metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.\n", + "\tTo change this, specify the eval_metric parameter of Predictor()\n", + "use_bag_holdout=True, will use tuning_data as holdout (will not be used for early stopping).\n", + "User-specified model hyperparameters to be fit:\n", + "{\n", + "\t'NN_TORCH': {},\n", + "\t'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],\n", + "\t'CAT': {},\n", + "\t'XGB': {},\n", + "\t'FASTAI': {},\n", + "\t'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n", + "\t'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n", + "\t'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],\n", + "}\n", + "Fitting 11 L1 models ...\n", + "Fitting model: KNeighborsUnif_BAG_L1 ... Training model for up to 7199.87s of the 7199.87s of remaining time.\n", + "\t-20.8779\t = Validation score (-mean_absolute_error)\n", + "\t0.02s\t = Training runtime\n", + "\t12.51s\t = Validation runtime\n", + "Fitting model: KNeighborsDist_BAG_L1 ... Training model for up to 7186.34s of the 7186.34s of remaining time.\n", + "\t-20.7787\t = Validation score (-mean_absolute_error)\n", + "\t0.02s\t = Training runtime\n", + "\t11.17s\t = Validation runtime\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 7174.16s of the 7174.16s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-11.6025\t = Validation score (-mean_absolute_error)\n", + "\t46.66s\t = Training runtime\n", + "\t42.45s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 7115.92s of the 7115.92s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.2\t = Validation score (-mean_absolute_error)\n", + "\t61.44s\t = Training runtime\n", + "\t42.36s\t = Validation runtime\n", + "Fitting model: RandomForestMSE_BAG_L1 ... Training model for up to 7042.75s of the 7042.75s of remaining time.\n", + "\t-16.8133\t = Validation score (-mean_absolute_error)\n", + "\t15.48s\t = Training runtime\n", + "\t0.58s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 7026.35s of the 7026.34s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4489\t = Validation score (-mean_absolute_error)\n", + "\t329.41s\t = Training runtime\n", + "\t0.09s\t = Validation runtime\n", + "Fitting model: ExtraTreesMSE_BAG_L1 ... Training model for up to 6694.53s of the 6694.53s of remaining time.\n", + "\t-15.6165\t = Validation score (-mean_absolute_error)\n", + "\t3.16s\t = Training runtime\n", + "\t0.57s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 6690.43s of the 6690.43s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.5703\t = Validation score (-mean_absolute_error)\n", + "\t42.1s\t = Training runtime\n", + "\t0.48s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 6645.83s of the 6645.83s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.7789\t = Validation score (-mean_absolute_error)\n", + "\t186.51s\t = Training runtime\n", + "\t43.18s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 6448.19s of the 6448.18s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.1541\t = Validation score (-mean_absolute_error)\n", + "\t103.16s\t = Training runtime\n", + "\t0.36s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 6342.93s of the 6342.93s of remaining time.\n", + "\tFitting 8 child models (S1F1 - S1F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4775\t = Validation score (-mean_absolute_error)\n", + "\t169.76s\t = Training runtime\n", + "\t48.27s\t = Validation runtime\n", + "Repeating k-fold bagging: 2/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 6153.04s of the 6153.04s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-11.5668\t = Validation score (-mean_absolute_error)\n", + "\t95.44s\t = Training runtime\n", + "\t99.29s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 6088.12s of the 6088.12s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.1276\t = Validation score (-mean_absolute_error)\n", + "\t122.41s\t = Training runtime\n", + "\t75.62s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 6017.15s of the 6017.15s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t-13.3955\t = Validation score (-mean_absolute_error)\n", + "\t648.07s\t = Training runtime\n", + "\t0.17s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 5696.47s of the 5696.47s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.5631\t = Validation score (-mean_absolute_error)\n", + "\t82.16s\t = Training runtime\n", + "\t0.89s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 5653.43s of the 5653.43s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.8469\t = Validation score (-mean_absolute_error)\n", + "\t293.35s\t = Training runtime\n", + "\t54.1s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 5537.34s of the 5537.34s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.0076\t = Validation score (-mean_absolute_error)\n", + "\t208.92s\t = Training runtime\n", + "\t0.7s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 5429.07s of the 5429.07s of remaining time.\n", + "\tFitting 8 child models (S2F1 - S2F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4505\t = Validation score (-mean_absolute_error)\n", + "\t344.22s\t = Training runtime\n", + "\t102.15s\t = Validation runtime\n", + "Repeating k-fold bagging: 3/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 5231.09s of the 5231.09s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-11.5904\t = Validation score (-mean_absolute_error)\n", + "\t148.87s\t = Training runtime\n", + "\t147.71s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 5158.8s of the 5158.8s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.0893\t = Validation score (-mean_absolute_error)\n", + "\t181.61s\t = Training runtime\n", + "\t109.31s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 5083.85s of the 5083.85s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.3713\t = Validation score (-mean_absolute_error)\n", + "\t984.68s\t = Training runtime\n", + "\t0.27s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 4744.58s of the 4744.58s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.5302\t = Validation score (-mean_absolute_error)\n", + "\t126.32s\t = Training runtime\n", + "\t1.45s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 4696.13s of the 4696.12s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.794\t = Validation score (-mean_absolute_error)\n", + "\t480.14s\t = Training runtime\n", + "\t88.45s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 4496.86s of the 4496.86s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.9877\t = Validation score (-mean_absolute_error)\n", + "\t321.64s\t = Training runtime\n", + "\t1.07s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 4381.42s of the 4381.42s of remaining time.\n", + "\tFitting 8 child models (S3F1 - S3F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4855\t = Validation score (-mean_absolute_error)\n", + "\t522.58s\t = Training runtime\n", + "\t153.72s\t = Validation runtime\n", + "Repeating k-fold bagging: 4/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 4171.8s of the 4171.8s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-11.6009\t = Validation score (-mean_absolute_error)\n", + "\t199.24s\t = Training runtime\n", + "\t194.81s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 4102.06s of the 4102.05s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.0673\t = Validation score (-mean_absolute_error)\n", + "\t238.52s\t = Training runtime\n", + "\t149.22s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 4028.27s of the 4028.27s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.3635\t = Validation score (-mean_absolute_error)\n", + "\t1295.92s\t = Training runtime\n", + "\t0.35s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 3714.88s of the 3714.88s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.46\t = Validation score (-mean_absolute_error)\n", + "\t165.66s\t = Training runtime\n", + "\t1.87s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 3672.37s of the 3672.37s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.7027\t = Validation score (-mean_absolute_error)\n", + "\t654.02s\t = Training runtime\n", + "\t124.14s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 3484.21s of the 3484.21s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.054\t = Validation score (-mean_absolute_error)\n", + "\t423.54s\t = Training runtime\n", + "\t1.39s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 3379.34s of the 3379.34s of remaining time.\n", + "\tFitting 8 child models (S4F1 - S4F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4396\t = Validation score (-mean_absolute_error)\n", + "\t676.59s\t = Training runtime\n", + "\t195.0s\t = Validation runtime\n", + "Repeating k-fold bagging: 5/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 3194.97s of the 3194.97s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-11.5539\t = Validation score (-mean_absolute_error)\n", + "\t244.34s\t = Training runtime\n", + "\t244.21s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 3128.71s of the 3128.71s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.0811\t = Validation score (-mean_absolute_error)\n", + "\t295.11s\t = Training runtime\n", + "\t184.52s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 3055.59s of the 3055.59s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.3485\t = Validation score (-mean_absolute_error)\n", + "\t1601.67s\t = Training runtime\n", + "\t0.43s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 2747.61s of the 2747.61s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4013\t = Validation score (-mean_absolute_error)\n", + "\t203.86s\t = Training runtime\n", + "\t2.29s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 2705.94s of the 2705.94s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.6805\t = Validation score (-mean_absolute_error)\n", + "\t827.11s\t = Training runtime\n", + "\t158.2s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 2517.12s of the 2517.12s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.9132\t = Validation score (-mean_absolute_error)\n", + "\t535.82s\t = Training runtime\n", + "\t1.7s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 2402.07s of the 2402.07s of remaining time.\n", + "\tFitting 8 child models (S5F1 - S5F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4832\t = Validation score (-mean_absolute_error)\n", + "\t830.69s\t = Training runtime\n", + "\t248.48s\t = Validation runtime\n", + "Repeating k-fold bagging: 6/20\n", + "Fitting model: LightGBMXT_BAG_L1 ... Training model for up to 2205.61s of the 2205.61s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-11.5847\t = Validation score (-mean_absolute_error)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t289.1s\t = Training runtime\n", + "\t293.21s\t = Validation runtime\n", + "Fitting model: LightGBM_BAG_L1 ... Training model for up to 2137.92s of the 2137.92s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.0548\t = Validation score (-mean_absolute_error)\n", + "\t349.32s\t = Training runtime\n", + "\t221.89s\t = Validation runtime\n", + "Fitting model: CatBoost_BAG_L1 ... Training model for up to 2063.47s of the 2063.47s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.3504\t = Validation score (-mean_absolute_error)\n", + "\t1905.03s\t = Training runtime\n", + "\t0.51s\t = Validation runtime\n", + "Fitting model: NeuralNetFastAI_BAG_L1 ... Training model for up to 1757.64s of the 1757.63s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4056\t = Validation score (-mean_absolute_error)\n", + "\t241.46s\t = Training runtime\n", + "\t2.72s\t = Validation runtime\n", + "Fitting model: XGBoost_BAG_L1 ... Training model for up to 1716.4s of the 1716.4s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.6622\t = Validation score (-mean_absolute_error)\n", + "\t986.42s\t = Training runtime\n", + "\t177.56s\t = Validation runtime\n", + "Fitting model: NeuralNetTorch_BAG_L1 ... Training model for up to 1540.37s of the 1540.37s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-12.8645\t = Validation score (-mean_absolute_error)\n", + "\t670.29s\t = Training runtime\n", + "\t2.04s\t = Validation runtime\n", + "Fitting model: LightGBMLarge_BAG_L1 ... Training model for up to 1403.04s of the 1403.03s of remaining time.\n", + "\tFitting 8 child models (S6F1 - S6F8) | Fitting with ParallelLocalFoldFittingStrategy\n", + "\t-13.4564\t = Validation score (-mean_absolute_error)\n", + "\t987.48s\t = Training runtime\n", + "\t296.62s\t = Validation runtime\n", + "Completed 6/20 k-fold bagging repeats ...\n", + "Fitting model: WeightedEnsemble_L2 ... Training model for up to 719.99s of the 1205.81s of remaining time.\n", + "\t-11.2468\t = Validation score (-mean_absolute_error)\n", + "\t0.24s\t = Training runtime\n", + "\t0.0s\t = Validation runtime\n", + "AutoGluon training complete, total runtime = 5994.46s ... Best model: \"WeightedEnsemble_L2\"\n", + "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"AutogluonModels/ag-20231116_121151/\")\n" + ] + } + ], + "source": [ + "predictor_c = TabularPredictor(label=label, eval_metric='mae').fit(\n", + " train_data = data['c'][0], \n", + " time_limit = time_in_sek,\n", + " presets='best_quality',\n", + " num_bag_folds=8,\n", + " num_stack_levels=0,\n", + " tuning_data = data['c'][1],\n", + " use_bag_holdout=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...\n", + "\tModels trained in this way will have the suffix \"_FULL\" and have NaN validation score.\n", + "\tThis process is not bound by time_limit, but should take less time than the original `predictor.fit` call.\n", + "\tTo learn more, refer to the `.refit_full` method docstring which explains how \"_FULL\" models differ from normal models.\n", + "Fitting 1 L1 models ...\n", + "Fitting model: KNeighborsUnif_BAG_L1_FULL ...\n", + "\t0.02s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: KNeighborsDist_BAG_L1_FULL ...\n", + "\t0.02s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBMXT_BAG_L1_FULL ...\n", + "\t12.98s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBM_BAG_L1_FULL ...\n", + "\t12.65s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: RandomForestMSE_BAG_L1_FULL ...\n", + "\t20.98s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: CatBoost_BAG_L1_FULL ...\n", + "\t64.94s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: ExtraTreesMSE_BAG_L1_FULL ...\n", + "\t5.06s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: NeuralNetFastAI_BAG_L1_FULL ...\n", + "\tStopping at the best epoch learned earlier - 26.\n", + "\t22.94s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: XGBoost_BAG_L1_FULL ...\n", + "\t6.77s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: NeuralNetTorch_BAG_L1_FULL ...\n", + "\t57.36s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBMLarge_BAG_L1_FULL ...\n", + "\t42.41s\t = Training runtime\n", + "Fitting model: WeightedEnsemble_L2_FULL | Skipping fit via cloning parent ...\n", + "\t0.27s\t = Training runtime\n", + "Updated best model to \"WeightedEnsemble_L2_FULL\" (Previously \"WeightedEnsemble_L2\"). AutoGluon will default to using \"WeightedEnsemble_L2_FULL\" for predict() and predict_proba().\n", + "Refit complete, total runtime = 296.26s\n", + "Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...\n", + "\tModels trained in this way will have the suffix \"_FULL\" and have NaN validation score.\n", + "\tThis process is not bound by time_limit, but should take less time than the original `predictor.fit` call.\n", + "\tTo learn more, refer to the `.refit_full` method docstring which explains how \"_FULL\" models differ from normal models.\n", + "Fitting 1 L1 models ...\n", + "Fitting model: KNeighborsUnif_BAG_L1_FULL ...\n", + "\t0.02s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: KNeighborsDist_BAG_L1_FULL ...\n", + "\t0.02s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBMXT_BAG_L1_FULL ...\n", + "\t12.08s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBM_BAG_L1_FULL ...\n", + "\t13.57s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: RandomForestMSE_BAG_L1_FULL ...\n", + "\t21.97s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: CatBoost_BAG_L1_FULL ...\n", + "\t66.24s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: ExtraTreesMSE_BAG_L1_FULL ...\n", + "\t4.56s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: NeuralNetFastAI_BAG_L1_FULL ...\n", + "\tStopping at the best epoch learned earlier - 26.\n", + "\t23.52s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: XGBoost_BAG_L1_FULL ...\n", + "\t28.69s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: NeuralNetTorch_BAG_L1_FULL ...\n", + "\t108.24s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBMLarge_BAG_L1_FULL ...\n", + "\t43.19s\t = Training runtime\n", + "Fitting model: WeightedEnsemble_L2_FULL | Skipping fit via cloning parent ...\n", + "\t0.28s\t = Training runtime\n", + "Updated best model to \"WeightedEnsemble_L2_FULL\" (Previously \"WeightedEnsemble_L2\"). AutoGluon will default to using \"WeightedEnsemble_L2_FULL\" for predict() and predict_proba().\n", + "Refit complete, total runtime = 375.4s\n", + "Refitting models via `predictor.refit_full` using all of the data (combined train and validation)...\n", + "\tModels trained in this way will have the suffix \"_FULL\" and have NaN validation score.\n", + "\tThis process is not bound by time_limit, but should take less time than the original `predictor.fit` call.\n", + "\tTo learn more, refer to the `.refit_full` method docstring which explains how \"_FULL\" models differ from normal models.\n", + "Fitting 1 L1 models ...\n", + "Fitting model: KNeighborsUnif_BAG_L1_FULL ...\n", + "\t0.02s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: KNeighborsDist_BAG_L1_FULL ...\n", + "\t0.02s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBMXT_BAG_L1_FULL ...\n", + "\t14.71s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBM_BAG_L1_FULL ...\n", + "\t13.49s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: RandomForestMSE_BAG_L1_FULL ...\n", + "\t15.17s\t = Training runtime\n", + "Fitting 1 L1 models ...\n", + "Fitting model: CatBoost_BAG_L1_FULL ...\n", + "\tWarning: Exception caused CatBoost_BAG_L1_FULL to fail during training... Skipping this model.\n", + "\t\t[Errno 28] No space left on device\n", + "Detailed Traceback:\n", + "Traceback (most recent call last):\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1733, in _train_and_save\n", + " model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1684, in _train_single\n", + " model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 829, in fit\n", + " out = self._fit(**kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/stacker_ensemble_model.py\", line 169, in _fit\n", + " return super()._fit(X=X, y=y, time_limit=time_limit, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 250, in _fit\n", + " self._fit_single(X=X, y=y, model_base=model_base, use_child_oof=use_child_oof, skip_oof=_skip_oof, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 442, in _fit_single\n", + " self.save_child(model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 792, in save_child\n", + " child.save(verbose=verbose)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 1026, in save\n", + " save_pkl.save(path=file_path, object=self, verbose=verbose)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 27, in save\n", + " save_with_fn(validated_path, object, pickle_fn, format=format, verbose=verbose, compression_fn=compression_fn, compression_fn_kwargs=compression_fn_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 47, in save_with_fn\n", + " pickle_fn(object, fout)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 25, in pickle_fn\n", + " return pickle.dump(o, buffer, protocol=4)\n", + "OSError: [Errno 28] No space left on device\n", + "Fitting 1 L1 models ...\n", + "Fitting model: ExtraTreesMSE_BAG_L1_FULL ...\n", + "\tWarning: Exception caused ExtraTreesMSE_BAG_L1_FULL to fail during training... Skipping this model.\n", + "\t\t[Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/ExtraTreesMSE_BAG_L1_FULL'\n", + "Detailed Traceback:\n", + "Traceback (most recent call last):\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1733, in _train_and_save\n", + " model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1684, in _train_single\n", + " model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 829, in fit\n", + " out = self._fit(**kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/stacker_ensemble_model.py\", line 169, in _fit\n", + " return super()._fit(X=X, y=y, time_limit=time_limit, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 228, in _fit\n", + " self.save_model_base(self.model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 998, in save_model_base\n", + " save_pkl.save(path=os.path.join(self.path + \"utils\", \"model_template.pkl\"), object=model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 27, in save\n", + " save_with_fn(validated_path, object, pickle_fn, format=format, verbose=verbose, compression_fn=compression_fn, compression_fn_kwargs=compression_fn_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 41, in save_with_fn\n", + " os.makedirs(path_parent, exist_ok=True)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 213, in makedirs\n", + " makedirs(head, exist_ok=exist_ok)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 223, in makedirs\n", + " mkdir(name, mode)\n", + "OSError: [Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/ExtraTreesMSE_BAG_L1_FULL'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting 1 L1 models ...\n", + "Fitting model: NeuralNetFastAI_BAG_L1_FULL ...\n", + "\tWarning: Exception caused NeuralNetFastAI_BAG_L1_FULL to fail during training... Skipping this model.\n", + "\t\t[Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/NeuralNetFastAI_BAG_L1_FULL'\n", + "Detailed Traceback:\n", + "Traceback (most recent call last):\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1733, in _train_and_save\n", + " model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1684, in _train_single\n", + " model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 829, in fit\n", + " out = self._fit(**kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/stacker_ensemble_model.py\", line 169, in _fit\n", + " return super()._fit(X=X, y=y, time_limit=time_limit, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 228, in _fit\n", + " self.save_model_base(self.model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 998, in save_model_base\n", + " save_pkl.save(path=os.path.join(self.path + \"utils\", \"model_template.pkl\"), object=model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 27, in save\n", + " save_with_fn(validated_path, object, pickle_fn, format=format, verbose=verbose, compression_fn=compression_fn, compression_fn_kwargs=compression_fn_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 41, in save_with_fn\n", + " os.makedirs(path_parent, exist_ok=True)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 213, in makedirs\n", + " makedirs(head, exist_ok=exist_ok)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 223, in makedirs\n", + " mkdir(name, mode)\n", + "OSError: [Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/NeuralNetFastAI_BAG_L1_FULL'\n", + "Fitting 1 L1 models ...\n", + "Fitting model: XGBoost_BAG_L1_FULL ...\n", + "\tWarning: Exception caused XGBoost_BAG_L1_FULL to fail during training... Skipping this model.\n", + "\t\t[Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/XGBoost_BAG_L1_FULL'\n", + "Detailed Traceback:\n", + "Traceback (most recent call last):\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1733, in _train_and_save\n", + " model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1684, in _train_single\n", + " model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 829, in fit\n", + " out = self._fit(**kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/stacker_ensemble_model.py\", line 169, in _fit\n", + " return super()._fit(X=X, y=y, time_limit=time_limit, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 228, in _fit\n", + " self.save_model_base(self.model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 998, in save_model_base\n", + " save_pkl.save(path=os.path.join(self.path + \"utils\", \"model_template.pkl\"), object=model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 27, in save\n", + " save_with_fn(validated_path, object, pickle_fn, format=format, verbose=verbose, compression_fn=compression_fn, compression_fn_kwargs=compression_fn_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 41, in save_with_fn\n", + " os.makedirs(path_parent, exist_ok=True)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 213, in makedirs\n", + " makedirs(head, exist_ok=exist_ok)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 223, in makedirs\n", + " mkdir(name, mode)\n", + "OSError: [Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/XGBoost_BAG_L1_FULL'\n", + "Fitting 1 L1 models ...\n", + "Fitting model: NeuralNetTorch_BAG_L1_FULL ...\n", + "\tWarning: Exception caused NeuralNetTorch_BAG_L1_FULL to fail during training... Skipping this model.\n", + "\t\t[Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/NeuralNetTorch_BAG_L1_FULL'\n", + "Detailed Traceback:\n", + "Traceback (most recent call last):\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1733, in _train_and_save\n", + " model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1684, in _train_single\n", + " model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 829, in fit\n", + " out = self._fit(**kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/stacker_ensemble_model.py\", line 169, in _fit\n", + " return super()._fit(X=X, y=y, time_limit=time_limit, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 228, in _fit\n", + " self.save_model_base(self.model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 998, in save_model_base\n", + " save_pkl.save(path=os.path.join(self.path + \"utils\", \"model_template.pkl\"), object=model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 27, in save\n", + " save_with_fn(validated_path, object, pickle_fn, format=format, verbose=verbose, compression_fn=compression_fn, compression_fn_kwargs=compression_fn_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 41, in save_with_fn\n", + " os.makedirs(path_parent, exist_ok=True)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 213, in makedirs\n", + " makedirs(head, exist_ok=exist_ok)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/os.py\", line 223, in makedirs\n", + " mkdir(name, mode)\n", + "OSError: [Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/NeuralNetTorch_BAG_L1_FULL'\n", + "Fitting 1 L1 models ...\n", + "Fitting model: LightGBMLarge_BAG_L1_FULL ...\n", + "\tWarning: Exception caused LightGBMLarge_BAG_L1_FULL to fail during training... Skipping this model.\n", + "\t\t[Errno 28] No space left on device\n", + "Detailed Traceback:\n", + "Traceback (most recent call last):\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1733, in _train_and_save\n", + " model = self._train_single(X, y, model, X_val, y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py\", line 1684, in _train_single\n", + " model = model.fit(X=X, y=y, X_val=X_val, y_val=y_val, total_resources=total_resources, **model_fit_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 829, in fit\n", + " out = self._fit(**kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/stacker_ensemble_model.py\", line 169, in _fit\n", + " return super()._fit(X=X, y=y, time_limit=time_limit, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 250, in _fit\n", + " self._fit_single(X=X, y=y, model_base=model_base, use_child_oof=use_child_oof, skip_oof=_skip_oof, **kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 442, in _fit_single\n", + " self.save_child(model_base)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py\", line 792, in save_child\n", + " child.save(verbose=verbose)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py\", line 1026, in save\n", + " save_pkl.save(path=file_path, object=self, verbose=verbose)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 27, in save\n", + " save_with_fn(validated_path, object, pickle_fn, format=format, verbose=verbose, compression_fn=compression_fn, compression_fn_kwargs=compression_fn_kwargs)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 47, in save_with_fn\n", + " pickle_fn(object, fout)\n", + " File \"/home/dashuo/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py\", line 25, in pickle_fn\n", + " return pickle.dump(o, buffer, protocol=4)\n", + "OSError: [Errno 28] No space left on device\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fitting model: WeightedEnsemble_L2_FULL | Skipping fit via cloning parent ...\n", + "\t0.24s\t = Training runtime\n" + ] + }, + { + "ename": "OSError", + "evalue": "[Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/WeightedEnsemble_L2_FULL'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m predictor_a\u001b[38;5;241m.\u001b[39mrefit_full()\n\u001b[1;32m 2\u001b[0m predictor_b\u001b[38;5;241m.\u001b[39mrefit_full()\n\u001b[0;32m----> 3\u001b[0m \u001b[43mpredictor_c\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrefit_full\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/tabular/predictor/predictor.py:2617\u001b[0m, in \u001b[0;36mTabularPredictor.refit_full\u001b[0;34m(self, model, set_best_to_refit_full)\u001b[0m\n\u001b[1;32m 2609\u001b[0m model \u001b[38;5;241m=\u001b[39m model_best\n\u001b[1;32m 2610\u001b[0m logger\u001b[38;5;241m.\u001b[39mlog(\n\u001b[1;32m 2611\u001b[0m \u001b[38;5;241m20\u001b[39m,\n\u001b[1;32m 2612\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRefitting models via `predictor.refit_full` using all of the data (combined train and validation)...\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2615\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124mTo learn more, refer to the `.refit_full` method docstring which explains how \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_FULL\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m models differ from normal models.\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 2616\u001b[0m )\n\u001b[0;32m-> 2617\u001b[0m refit_full_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_learner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrefit_ensemble_full\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2619\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m set_best_to_refit_full:\n\u001b[1;32m 2620\u001b[0m model_full_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer\u001b[38;5;241m.\u001b[39mget_model_full_dict()\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/tabular/learner/abstract_learner.py:448\u001b[0m, in \u001b[0;36mAbstractTabularLearner.refit_ensemble_full\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 447\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrefit_ensemble_full\u001b[39m(\u001b[38;5;28mself\u001b[39m, model\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mall\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 448\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload_trainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrefit_ensemble_full\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py:1334\u001b[0m, in \u001b[0;36mAbstractTrainer.refit_ensemble_full\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m 1332\u001b[0m ensemble_set_valid\u001b[38;5;241m.\u001b[39mappend(model)\n\u001b[1;32m 1333\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ensemble_set_valid:\n\u001b[0;32m-> 1334\u001b[0m models_trained_full \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrefit_single_full\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensemble_set_valid\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1335\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1336\u001b[0m models_trained_full \u001b[38;5;241m=\u001b[39m []\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py:1268\u001b[0m, in \u001b[0;36mAbstractTrainer.refit_single_full\u001b[0;34m(self, X, y, X_val, y_val, X_unlabeled, models)\u001b[0m\n\u001b[1;32m 1266\u001b[0m logger\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m20\u001b[39m, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFitting model: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_full\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m | Skipping fit via cloning parent ...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1267\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_add_model(model_full, stack_name\u001b[38;5;241m=\u001b[39mREFIT_FULL_NAME, level\u001b[38;5;241m=\u001b[39mlevel)\n\u001b[0;32m-> 1268\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_full\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m models_trained \u001b[38;5;241m=\u001b[39m [model_full\u001b[38;5;241m.\u001b[39mname]\n\u001b[1;32m 1270\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/core/trainer/abstract_trainer.py:1411\u001b[0m, in \u001b[0;36mAbstractTrainer.save_model\u001b[0;34m(self, model, reduce_memory)\u001b[0m\n\u001b[1;32m 1409\u001b[0m model\u001b[38;5;241m.\u001b[39mreduce_memory_size(remove_fit\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_info\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, requires_save\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 1410\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlow_memory:\n\u001b[0;32m-> 1411\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1412\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1413\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodels[model\u001b[38;5;241m.\u001b[39mname] \u001b[38;5;241m=\u001b[39m model\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/core/models/ensemble/bagged_ensemble_model.py:1022\u001b[0m, in \u001b[0;36mBaggedEnsembleModel.save\u001b[0;34m(self, path, verbose, save_oof, save_children)\u001b[0m\n\u001b[1;32m 1020\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlow_memory:\n\u001b[1;32m 1021\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_child_model_names(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodels)\n\u001b[0;32m-> 1022\u001b[0m path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodels \u001b[38;5;241m=\u001b[39m _models\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m path\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/core/models/abstract/abstract_model.py:1026\u001b[0m, in \u001b[0;36mAbstractModel.save\u001b[0;34m(self, path, verbose)\u001b[0m\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiler\u001b[38;5;241m.\u001b[39msave_in_pkl:\n\u001b[1;32m 1025\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;66;03m# Don't save model in pkl\u001b[39;00m\n\u001b[0;32m-> 1026\u001b[0m \u001b[43msave_pkl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfile_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mobject\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m _model\n\u001b[1;32m 1028\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m path\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py:27\u001b[0m, in \u001b[0;36msave\u001b[0;34m(path, object, format, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpickle_fn\u001b[39m(o, buffer):\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mdump(o, buffer, protocol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m)\n\u001b[0;32m---> 27\u001b[0m \u001b[43msave_with_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalidated_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mobject\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpickle_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompression_fn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcompression_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompression_fn_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcompression_fn_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/site-packages/autogluon/common/savers/save_pkl.py:41\u001b[0m, in \u001b[0;36msave_with_fn\u001b[0;34m(path, object, pickle_fn, format, verbose, compression_fn, compression_fn_kwargs)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m path_parent \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 40\u001b[0m path_parent \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m# Allows saving to working directory root without crashing\u001b[39;00m\n\u001b[0;32m---> 41\u001b[0m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmakedirs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_parent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexist_ok\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m compression_fn_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 44\u001b[0m compression_fn_kwargs \u001b[38;5;241m=\u001b[39m {}\n", + "File \u001b[0;32m~/anaconda3/lib/python3.8/os.py:223\u001b[0m, in \u001b[0;36mmakedirs\u001b[0;34m(name, mode, exist_ok)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 223\u001b[0m \u001b[43mmkdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m:\n\u001b[1;32m 225\u001b[0m \u001b[38;5;66;03m# Cannot rely on checking for EEXIST, since the operating system\u001b[39;00m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;66;03m# could give priority to other errors like EACCES or EROFS\u001b[39;00m\n\u001b[1;32m 227\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exist_ok \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m path\u001b[38;5;241m.\u001b[39misdir(name):\n", + "\u001b[0;31mOSError\u001b[0m: [Errno 28] No space left on device: 'AutogluonModels/ag-20231116_121151/models/WeightedEnsemble_L2_FULL'" + ] + } + ], + "source": [ + "predictor_a.refit_full()\n", + "predictor_b.refit_full()\n", + "predictor_c.refit_full()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictor_a.leaderboard(silent=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictor_b.leaderboard(silent=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictor_c.leaderboard(silent=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_a = data_collection.X_test_estimated['a'].drop(columns=['location', 'date_forecast'])\n", + "test_b = data_collection.X_test_estimated['b'].drop(columns=['location', 'date_forecast'])\n", + "test_c = data_collection.X_test_estimated['c'].drop(columns=['location', 'date_forecast'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "y_pred_a = predictor_a.predict(test_a)\n", + "y_pred_b = predictor_b.predict(test_b)\n", + "y_pred_c = predictor_c.predict(test_c)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_pred = pd.concat([y_pred_a, y_pred_b, y_pred_c]).reset_index(drop=True)\n", + "final_pred_AutoGluon = ReLU(final_pred)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CatBoost" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selected_features = ['date_forecast', 'absolute_humidity_2m:gm3',\n", + " 'air_density_2m:kgm3', 'clear_sky_energy_1h:J',\n", + " 'clear_sky_rad:W', 'dew_or_rime:idx',\n", + " 'dew_point_2m:K', 'diffuse_rad:W', 'diffuse_rad_1h:J', 'direct_rad:W',\n", + " 'direct_rad_1h:J', 'effective_cloud_cover:p', 'elevation:m',\n", + " 'fresh_snow_6h:cm', 'is_day:idx',\n", + " 'is_in_shadow:idx', 'msl_pressure:hPa', 'precip_5min:mm',\n", + " 'pressure_100m:hPa', 'pressure_50m:hPa',\n", + " 'prob_rime:p', 'rain_water:kgm2', 'relative_humidity_1000hPa:p',\n", + " 'sfc_pressure:hPa', 'snow_depth:cm',\n", + " 'sun_azimuth:d', 'sun_elevation:d', 'super_cooled_liquid_water:kgm2',\n", + " 't_1000hPa:K', 'total_cloud_cover:p', 'visibility:m',\n", + " 'wind_speed_10m:ms', 'wind_speed_u_10m:ms', 'wind_speed_v_10m:ms',\n", + " 'wind_speed_w_1000hPa:ms']\n", + "\n", + "made_features = ['location', 'type', 'is_day:idx', 'is_in_shadow:idx', 'dew_or_rime:idx']\n", + "\n", + "drop_feature = 'diffuse_rad:W'\n", + "\n", + "\n", + "#Loading all data\n", + "data_collection = DataSet()\n", + "#Preprocessing\n", + "data_collection.select_features(selected_features)\n", + "data_collection.resample_to_hourly()\n", + "data_collection.remove_nans(drop_feature)\n", + "data_collection.add_location()\n", + "data_collection.add_type()\n", + "data_collection.combine_obs_est()\n", + "data_collection.drop_bad_data()\n", + "data_collection.cyclic_time_encoding()\n", + "\n", + "k_b = 5\n", + "k_c = 6\n", + "data_collection.scale_y_train(k_b = k_b, k_c = k_c)\n", + "\n", + "X_train, X_test, y_train = data_collection.train_test()\n", + "\n", + "for f in made_features:\n", + " if f not in ['location', 'type']:\n", + " X_train[f] = X_train[f].map(remap)\n", + " X_test[f] = X_test[f].map(remap)\n", + "\n", + "make_categorical(X_train,made_features)\n", + "X_train = X_train.drop('time', axis=1)\n", + "\n", + "make_categorical(X_test,made_features)\n", + "X_test = X_test.drop('date_forecast', axis=1)\n", + "\n", + "train_pool = cb.Pool(\n", + " X_train,\n", + " y_train,\n", + " cat_features = made_features\n", + ")\n", + "test_pool = cb.Pool(\n", + " X_test,\n", + " cat_features = made_features\n", + ")\n", + "\n", + "model = cb.CatBoostRegressor(\n", + " iterations = 10000,\n", + " depth = 9,\n", + " learning_rate =0.005,\n", + " loss_function ='MAE',\n", + " cat_features = made_features\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#train the model\n", + "model.fit(train_pool, silent=True)\n", + "# make the prediction using the resulting model\n", + "preds = model.predict(test_pool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#scale back\n", + "length = int((X_test.shape[0]/3))\n", + "pred_a = preds[:length]\n", + "pred_b = preds[length:2*length] / k_b\n", + "pred_c = preds[2*length:3*length] / k_c\n", + "preds = np.concatenate([pred_a,pred_b, pred_c])\n", + "#Drop negative values\n", + "final_pred_cb = ReLU(preds)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Combining for final result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_pred_AutoGluon = pd.DataFrame({'predictions':final_pred_AutoGluon})\n", + "final_pred_AutoGluon['predictions'] = final_pred_AutoGluon['predictions'].apply(lambda x: 0 if x < 5 else x)\n", + "\n", + "final_pred_cb = pd.DataFrame({'predictions':final_pred_cb})\n", + "final_pred_cb['predictions'] = final_pred_cb['predictions'].apply(lambda x: 0 if x < 5 else x)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_pred = 0.5*(final_pred_AutoGluon + final_pred_cb)\n", + "\n", + "final_pred = final_pred.reset_index()\n", + "final_pred = final_pred.rename(columns={'index': 'id'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_pred.to_csv('Short_notebook_2.csv', index=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/models/__pycache__/SwanDNA.cpython-39.pyc b/models/__pycache__/SwanDNA.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bea57a4d902a50f5d479ad0b93167a158661b2d4 GIT binary patch literal 17419 zcmeHOS&tmob*`$eu3n~R;c&QVEy-Hg7CEG7*>MzE5f_QJK#m2A-o`fCHC@#+J?yS- z-l`@!lO_^`0!ScX0K-NC1VKFG#0duRdmi%cKM(`~>P<+HfNzE$BE^B@eBZg%>-2C@ zXa#{>gSu0FtM0A4ob#RUoO{cgo-P_VzIk&aIPguw_!s^dK5YDW184Mk(=dF)Hyg&X z{%_^T){}6Hxr8x`8IqV-n&XF|d0CFDlA4ksPY0g39 zJmEiyoF~(qL&$l`e;PSYr#Xjx^IfC-%w3M@QrSwfZ(TV3?iZ6e`FBZqVN`3WW}wPu z;#>;CsHLiWboj{N#~V1KlSl%C!@g|!=5i*m1KYR$!tyh|{hqy??f=gD4t`@`@#=h% zJ^$wU3t!#*gQAP$D?u$#L0AoRuQ}?w6Wv z)w&wELybN5tb6L@sgsNS8rI^t9ld(ugr~08uP(OK>IrWpI&tdc$y19bKZ{4@LG<3d zo;WIqJ1VSN=uGdRknuc=Gx{8ouDK3TI|5NUX!x1SR&?-f;}1>bnm&5&+jp#QTQ{s5 znXVOETbXWV+gvcv*X+{WU*WjRx)tQr*fEogUvHM}Bpa=H?I6ik*E-?l#C$8USA#|e zH|!{A)RcpV6Vpq~l^sLv$B$pUcB0zoL~)=_thJiKiMSPRbWVKj&B{Ayzw+u~JzPEU z?lmtww{-eMv*ibksP}WRy^&;BgVjc778OOT+At5A^XBH1%*3MH8yNsLnA0o+K!xlY zTNs3G$j*+rR5sNCbl=&y1l3wwEb7u~HI@H* zSg8dbBr=HHqxCQ@zv?b^nk%TRRYSRUCw6=HwYaa?wMNT}M`}5Td^FhwD6DuZ^+rA37_B9R zg{6O&8AnnT4M>0Sn$kdNkCu?c#!X|(+%`<3n`zH=O)7Q!F#c=RHoMlg)io-1*S-m@ zsANUqEvP(m9sJ65v)d-F^FAOz9YmGo0@S_J4ip6*ni@mMtuRcS3-vJYRFbJxUrsU@ zFJ7HWa%qf8lJVJ&>wsCBHj+F zg)^E#VwgoUkG~l?H%pN2#a?Uo#s&e$6dP`!*INL^ZBwqS{#B-b)yo5p=#jA%8=#(J z!ltU{nH*v=P;hB1mHrscsDy-#?fq>|v8gGllnd$zZmY+bJkI0^CeI*AYzmWd`J0s` z)9IL0gGQq_F7ncs_=vppNnG*S9-tHt7Zme$0Sd*3M|2d)FrWyCz#_hvU$!;bSk4NF zH~e>=a?&c3u97&O(j}=6fW)Q{zc=zRVt`u-t!fVtwQ}t;nsPg3?bOS(z z@8i*E7qZZ@_HS3Hz&(Juv=)d$6Jlaq@2-kvCkg@k9dWP+Y9k5(A)t~BR3NLK##4Iq zWj3bzJ%8Cxtd+2=tw}O7+D0XlA=A>A+1L0|fp7@H=6t`Rz2%h;pJk~Rki;UTzPW7; z%jrli?VPn#&WO=g1r)kV^Igs)*_g1MXllva+&5@! zDzQ7ci4yCgkT!ISypjPJ6V4O3@mZWb;dDy2V&0<@&hx6(X^U|F+XD!vJpF$mnUv13 zvdYHxjV5I}>>JMX9?YFzzgqVp;<#O1^TH532_`UZX(LnZ=`n^;OhVr@5JtvioM0xV zU~ef`pGUF8S`8se!j(?(cfQxyNy7bQJTWBP{V8~)!mU1ox1dcFA4LpTzs+PX+u2hcIZ?T3+a>_9{&l+fJ!OC)wQ|65fc5k3XOKFF@Al#QNwV?G9HYF2&4tW;x zUWhUDH<|MsazMw`t1)a$*SZM`9?l`pM&8>fX>S{_-x*SI?`DsxUGyL(6oA$GRX}Dq z0;Tarn*gWDD4R#lN2jz+Np_Yx`GUH{P+FD$)y%zg2Ur01Ce z!IxTLP}Xuti5|tlhr?=?7nvX6{0M29s2!L{SEhAThtr;ZMckwzFRf4NbT_LbgEcy+ zn)d3hCDw(d#5%LIY@J_9Z0U1W+vkFiX2OlYnAV;Am%wggkUF?Vh5q;|(WBy^i;2y9 zL*hqck>s0RyTTjG_!Z3~F-m#Ba2{Z+|L09biF;&_IeoUSs-1?X-d(HLV(+^>nL9lu zK9DaDoHbAw^k(FGZrfAvgkVsP*Hi#=%;L1i(#6QVR*%=QRp`RK3t$ z3EWN;_>^O>vr6(n`NAyJYa1fj+L7?R*rPAvxx6~=E76B0QBf_5?CNIJBY8{OTGYGV z0%SHeTuO}BXt+^ur6WGPu5E>!#`QwqMYC2swgeo%-MSX2hdX)$IgKY%0-|_537CHm`2%1+I5N?b!CCMQ zr6^(18j{#p2S4r_>kJxLcQQAuE~w!VV-;8yBWM!az%cWsbENK*DBQZAJu*EnrJmU$=3@JH-y71B^hF`go@A-6>QEyK04)TdFu`26MRAtWYN!n@KVrq1vgI^4n# z)N$UN8y9g!r!L8*b1Qnz@G1-_kgLzSHEvA}305*SBqz!A$ds!~e5t&Cf&=;^+>D4@ zMzIKpnz1Z^7XJWT|CP@Zutp?dVjP8o%#%Hi$cGtY>cv^_8%qJ(J!Igip-74}`Kzc% zaq869k@VaS!r*caWVTXmg_8jIAMijL2D1B+mAtjNFw&3v;w~M9@3E>Bcjrn+6!WS$ zM-FzJ$HT@;X43Y&8CDVFR`37(!I0PM;jWt&bUl);DE6wC$w<;#HLxIbLailc4m?Pa zjy?|Sfj@*jC25RLs z(n^Hn1rxy@;`NJZIe0&rS(oXO^n)Wcmk`&fuyLnOx)|S%tg)ox?ZG*tl)nHg8+E?c15#&h6~Y z-1n?2HiEr;63a%!f*mWd7EdN-1(7ehahh;{gI9FZ6jpsxa`M;v6~HD;^)!hAe>yx720zZPZbvIQI@-^W+PnGf0Lss?gJ@QT@4< ztL0M-0{K}7+39PLvzG?Z^`dmktcV6ZCK_a?I+TSDjakpwb0xM{jN!TG%co#j{)~tk>#+@_q+8&}TzEmV`i3&{*tK zDLRq6sgxE5F2xE{!y~H=tk5=G?<(qDg_F~2H4s-l2Bye*>F({nC`IQ5podUc5vMgk`oOW3<=(AgskVSd_n0lNW5>Sa&x=X$qv3HR#~ zh;v`moqBOl#Y51w)b(eGZczT?rfDfB<`51!P3e04gE`kjRmkXkBD`x^tcE)B;B}v< z{V{$F%nH{`V|xohzLcUlac(OQ`etui@S=_NqO7M*#U=RJb~g)OJlD;43*F*PP`Up7 z>E8W#22V^afW%FFZdPhXuu;~p=4hbwa{RsloXfs1<)fJ@U{YfuWZ)TIkx3+3ME)8Z zSV)&m4cCJD>RKEng?`{j>~@kvXqe7gl0nosajH#hMX14F*UGh164Q_r2YQ`M4XLOS zdv<^7gsoJqpjMj{x%~nmn0DNlW=v5m-hbN}` zhUp`~Qn9WSWwflNw{|irW-E*R<(ep&R&deljIFFH&+oN7D^H9yJelF*&Sp?aHKM~l zP*@DKe)Gssi$|ZU$}@Zp2Lp8E6Cz+g!9~~HGPaGb#neJf2&CTJLenWO+t|G{xK8gH zTe1%au_jReZL^=N_vzTkM=7khV*R0;J#5@?kYjftus57Iixmim8$9n=lAqK0Se*d> z-8LzNTSYuI#TDB)yH(=TfVB*s=9ZW2(qK=8dIhalWZLCvOgy)C>BZ=X`5b6TuAf>iLcA*41}D}A z@-|K>uCtF#kL_nEL( zwdZz@Htrth%Q+zMduB^7<)`zuVty6p$U&LqA%VmH1=pKCf2K0%&ifw%y?#UY#^~N} zxSIN8bG0of+P%0M_?S__$_LVXT-ZVk1y&>v<~D`@C1GyUU}RW1fvq0lZYVK3xD(G! zU~Pjs_DS8WL-V$I!18|bg#mwCco6<}8r{*6gJ*apHl2&58mR<(CB`5C+~EyDVPdbw zvFp5=uq^UtM6YwWsd<(ozBv{}hPc*9wBbX|vzRx=v!)n08iA(&_hd_TIci+!R)55V zksqBb@QPOAryRd~P@2cbM&wbj8g++n{}U!$xEXI2?bbiCS$&rB#AqYNhzi+n@8C6y z2~jzN3r$q+x3m>w`+ZmF3n&IONe8a=QrAvqX^1rB)DKXr5{cc5E$q3c?Vs7J8I(7~ zy3@+j!fzfJ>z_98>M!}G!JLWo%f$TzA(Lz&qJ(`E6az+SzjF+|y6{NuFdRBPj7wWHK)G4LJn<{B*1Kb>5brKOkaKw& zPmY6=Oz>sikXasJ6(I=$e2K7c@ct%~6(*;cNM+yP)udQE8#M{We#Y56iksswrrpI6 z2s5n@@}JScvB9^!B>{DNn1kDwMX z3@(sfpx;K}9cLV^M%|h{f$DkoLbN&sDuF4&@Gb=CeFd>MTKxzud7tv#N2})qt(gC7 zL#sKo_3Xt5M=EOZM?tFl@h!T^V|hy?`0mf@{^JtkGoEZyaYWa_Vwl;0Ryk*q}%K zHO2W5lL3Pf^1_G&KhNBaFNb$G_kK+1$Ea+PmG!aZdBGOu?~N_TaJLu0kdLKg9|FGN z=10T^KAPgk!^Y4@Q@Lb6+V&RKsMdxcW!Fc2RAJ#gbY2E%f`}frMBHYYxs7OFCEXeU z)`D0&eq$po(tMmnTpdqH9s*ihQ(W+^;qy>_%Y``m=pyeloHY%5cjSBYeCUe2{Hg9|gA5@@&IvH@s?4@!}Yn{O$^K zW+-v=WLeo7B2OYOA}Ssp`W|O@9v2Q4jz>b|zt}uEVN&||GjZp|>=3lFQ<|;`aeIb8 zpJgK4Z-rO4kYHzED`NCi4c3=q-V`RNzm!yaGQpSl!5;C#$Z;uo;IcZl zW6fB$S(2q?XJ(TIcU(jDPUGFRL!H;IZk7NZyyt>%|KZ2eNZ9%7Ul;?yjCE2z@U{{g-=w$~7 zpGjcjxp4f%5#h$hb1wU0w@%0Rw7mR;&%ez?U~ z3rzsYA9EInmfBl<7M@}+DmBwH&x#Zx}0 zPaM60Nj{U;J8k(1p?rHIU&hFGTG_WHi+Qs1Ktj@r4zBtU69T-nYG~y8z>(x%(~CH- vgLdi9@xEjit_F??fmo)AFsk2;6blyTO{s&E@MJ&|@>d6KSeLdAD#=u|- zXAVp`#+ujzTjGqgHE{-xgw52NvW1 zcVm0p!<$WDI#CyxZjIRjW+UnW)2lHTfZ2?;fZ3`s+rV6iwt?BMF&BZk80`SFQ)6}_ zc1P=f=?PA!?e&>(_ku|reKI{vWAr)?;$(DC`dQ8=)O{x?!-G9^u0~}COV|W#4O$Ym0k;N@gdM=3kc3^poq;RiF5rzpC+=dM z9^lQvhQx2((Bhst;|UvV%ClS)_RTySCZj#!-p|8e-!GC`oM_meFrsAgq>KFtC-c+b zDCR|?>FBm5!9zcdvnS|i!p;059!=t`ERuhrd9BP3)1WAX`3axKqIJmgNnXZJ{+33Y z2Usz#Ugg1pRs>uKP1xYdxG)4A-$(} zniAag{d?0ORd2YBHH495rP%cSAj|SHD3d%Ze4n@QinpmD?czjnzCq0<8etSA=NAa- zNUBca@*s~m5lWtNsd<5L7Mc!fv4h5CUHa*=m5|Qlh>+oJQwW}Oz(c2kAnrh2i3gQd z>Ze9zoUs?RGxof;qk#|R?2Qj6kb2^}JcpMLz49RTMsYgzg0M{PEA()hp_6Cc{*m{= zbTk6Vz4wAJ9;t8t@0&kSz|Y>SG1^;wQ&?rr!vkT&d??IGP#y@gxW`K<6?=H@yuzeU zzKB(74)Iq2hicVD3%Qh~_-#BIefv}!LcAjVku$&nBNNud{8TfwfrTSlqJ8t$hqvEJ z^YCs+ENTkD#4TE~_$M@_R)R_=`l)`VF|9PlX2mL_G9y+&k!j)2T9xKo@{LjHPrb2A z@buxyZ8JQk|TYeWKH_iy^MZ}-h9w-Kp z6i%4%aGD1Ey?#q{5@_ZDgb$J~y7{zR^g4HA4z$#MXlxKg;FNHn4U&(dTcb$4V`1(m zLD6SVNO1a$@6uyaSjF@(=Dd#{sRZ<^gc|5sCT-09Ymue+}Q4XXQ5tdW)JrqUKMic>_(~aeaW zMJYsQr4YHPd|_3-O)e}XL);WHwR2Uj&{j3-q07`CKpUDS@<$~x{;c4<{k1$<>zZq+%lN9+%=?!A{vdkm%j zAt|jRmDb$jq)`0#sUZ{IZ}V3PzD*60Qdm65MzJv9d?j)CHwf$AiSJEeiIX6O!b8RU zEH3+9$!cN4-;!dq=-H?SCwY~kB5dW^1xv(6IOQDcVEEy*9Q2wlW>J5u={HtkusO%n z3M`qqs$1W}mueyEF_UdUbR0aXUas{^GG@Jgk=G@L(|!wbkBxra(fIr5!uouS$76jV z*XN{Jir)`ty7b_++^SF3a?Ae|ua~ARM3N}~O`2PaWLNSc1ZJhs4yhDhNAr2!MMba! zOL`fb47KRpq*y5#CNcjY4x+GGi`FB!u`mVswN?3q(I;1pnX8|*Kn4|P)K!x)$Fm^^Q$LW zFqEXaDFoAzZlcN~RoYk^voXaOaK&dj1B-z%moZX345c2Uzux(e_y2G9=*_;x{}PJ` zTWwvKQBVd5J`r#Ssn4(D9i-!JJeF814b1?!160XXjcv2pi|4k!Op+CQPaL5tx8DQM z$Rq5%-q4-Njn$-7ww4{8-vn936t&)j%Vy7=+e9s$FzoBBDjh<~Hvcn%q|%);sKoev z`UEiv2Gw#*PPhCG1;5$F z1%cM9#EOkm3D6j^!npyL!mg>LL~wf57AWY z6juY4POYvTHz5!OLnhQj1vwVnpmhlj5Ez|eUUIRBm~GPj&vc2m6uvS}ZOS`NTW8uS z;^J`|qYW567o$zY$OcA}5w(;EJZ37DA`>H>Xul(@A#y|{dxRoY;4h;=P7wzg4+(5f zvjWZO0$W_ffN3kzx5H(nL?2$@HW08vnyy*_Z02-&Z)o%#`QN7p&|mvc0h zx_Jz|a(?=@jG!E}J=Euj(A&#|Rw36jlzmL=lf`(pvXkK7qtDgK-h#3dj!wB!jJj96 z2B2|Pbnzh5A>=Dahs?-~ETluoF`^cpcH~5DJX_LnIiR^KcHV!-SFXfgoFn)-q}laQ zK0!8EC)b`#u6Mn6l`eWo;RRk%2H{;w8=FDa-mO?_EQXBId_PFN;C{@5QS2S& zImkZxx;M+5YQtZxZX#ZAX*4 z#*_V6<~_|?G?~hh@-MHT&?`$s_1^2vsOg?p-;1}d3MCte{#w=nh)(v8XkY*yyc zZIxja5wVPAKm`ASWKU! zhkl)&d9O%is_EAtal5l0>?dhb9)U6ImXO5w;)~RBArtV$ehv|!!oz#7;98{K3F$2H zP<^nJ)SPLTsvmW=6;`5~tGD@=x6kg0%7sK~P^lSTm@{|ht-@0IBFopbJN;l+4F zqow1L7fowIx_`xpI;3;;EBfrq>vVkK#mI9tgc61XbH7^8 z-N4+HRn7*=<9rf2i>(rxX5smBFSK%vM`njmZWg)IVds| zn-7Egh}ZpTmO$huKL0Le%ZQVIkJhnfG0%O`ez&nG{QHFNG?yswAfhc(kUdO?lS+^j zV-pmJ2wUACo^vNipqGQvf5XV)IW!vXDz>Ck=^~JI*^a!sn0;kEG!N}oCe%!!zs0_yX7Zo-T@2>VYAS%eucuiauDaMUwaV(%3Rdk<@=fa4fa zhEM`=(wTmYt1{e+P{^@@g&ZPNQgS56KoV(8MVHFv-eW^$-jsB!^jqfOchT+JoT7Ba z{TA-GCR0RP^6VrsoB|(C;qAGzOK<&Oz!#JyXw1PKEZxSE^PgQ=fAl&8XrGJsMa%c2 zJjAVu)czCt*duB_KOlI3A^(Uf(`KaGy{_YIASfVwh`LgU_f%v`_UnRzdG>0H?T_f* z{MF`0ovwstFH`Hu7t$j%FTK~jkAmVZ{PyjSp1M(;n@X}~YfPo}w5gOS0^yM|z@Q>$ zq`ysIT#+1r-70Zt467v-A2}rP_{lF(z}Yn>bHOp|eKd>6&qVRfy{VdATSR0!bAV zLZ+qo--Q1i<|AdE3p4+FAe11>Q`U)pyN;IPdx9OTlj;A)w4#fqmT6oGcJd44tc?2z^!Fz!dFbx3QnuEPn}{}TTlTH+!aVZ;3+D&%Axl$JX3--F#Olg4|{ zYr4DBGD!dpye-4C=li&{=b-gA%Ev3vdl$NR^cd$g5OY~MI oRnCoxocaa>s(1;FgKMs?taPF57+YJ8wS8^-!uIC&Qgx5kG5Ckhx9&m70mK=c?DU62cUVH6%tb1ma zcr`vbl6~R8KUg3APxuXe<>X(GD2G5*&+K{)C;=%w>YDDVepY=|)j^}-5-1zLuk&Z~ zg#3Y>*t;-Y-^Dg>z-b<4bK3c9+|P_S+VWe72ok{YOL$I-w|do z^C@8_*SF^mbgvE^i`l@jD^3GA4yysDR&kn4+aPWC6!s@=Dw`XRgXC`ac9-XmQ?b)+ zYtp>Gc6;M{Ij?>XMUZ5@RK#3_IAylzRVVut8k~5P5l>^((+^3HGL7j^9k_%#qGuiw zgLoEfZOP_?l=0}xb-{BH^ph7ow*{p-NJ|M@uw`p9p!Id@L6Gfm)>?a@hW^;kw_0h! zr*n7`!jhsDB>de^2ig(B&AUQc8V|q*|$+fYb8)LK3j_ENih;J2S z+hTN)6vib2eLHu?wE`sD_OS+iH=iruzik&rVHQ?l7fw-Q+CdFubu|`cRE=#g0y@yU zrxf-)ML4vZ(#(bfE~Ml8{iL7!zBJ+>-;zei`;pX=fi$9Y^9xyzgC~CCN5LK!nQZd- z7tUBe+4Qsih{L!ZMU8u(14gLIh7hC0l&YYzlZ04&l#dO~H*_roa2GT%QTFYb53 zXqe?(bhgr%ck(pZ8+LxU=TGbmSLswJ=)?$-tOM&QjNmO5Yk9CI&~6Q=*$p6sx^#ig zQ-f;Mcxf3L{-Mr%&#cR%#q*e3o~SxR{Y(mi0qY&8*=IlsIwr?tA51-_U_iPLW*%z@ z4As)B7I@SN?Y0eR_M$Y%#YLDZE+Kgv$cZj4!%x;KUGrT@x5PCRD{>u$HUR!{AlT>7 z$?$@NI@Hzh8(sd}*{h)^HXYYc9ai5D)2nQ1mZx=ENoE2sJ9d^gzmY5`oi*mgMeFJ>fuKBp1?i+;hi;L;^LVH zdXcB7fLr~f(afe8s6l`Nvxa&dV40jAOfd&zt_+@sWM?^(*3vvZS{;KCduXacd zX-_LNVD3U@&<^Nba!_WIIm|6bpno;H#GuhVMcsuI$fO{wFc2WdlyhxD?s?SOr#U$y zN04iCI02(eurBNbPlaEt-?Uu=vWxLh;A)#tvD|``tFwMAu0TgxVKB^s$ZPBlx!Ch} z_?{0kInSRAAk+8L#P0=Ro(k_m6bzyuP@nCfuBpEu3RnwkNg0hR@`Ne3ZgvdFgOyRqHGfKlC%2uHo2R9__fdJfNzZ5M&}Xkp)k_FQ3JA%%Vbqj%R~e7(@s zOdTKtK@hPA`bB@2CIBGZl2%Y^3N;}ug%Ac02YhW9;gAJoT=)j00)tfb+m<*g=>ES9PPUm literal 0 HcmV?d00001 diff --git a/models/cdilDNA.py b/models/cdilDNA.py new file mode 100644 index 0000000..c69b6c8 --- /dev/null +++ b/models/cdilDNA.py @@ -0,0 +1,208 @@ +import math +import json +from typing import NamedTuple +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm +from torch.nn import BatchNorm1d +import numpy as np + + +class Config(NamedTuple): + vocab_size: int = None # Size of Vocabulary + dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder + n_layers: int = 12 # Numher of Hidden Layers + #activ_fn: str = "gelu" # Non-linear Activation Function Type in Hidden Layers + max_len: int = 512 # Maximum Length for Positional Embeddings + n_segments: int = 2 # Number of Sentence Segments + n_class: int = 919 # Number of classes + promote: str = "True" #if use special tokens at the beginning + hdim: int = 128 + + @classmethod + def from_json(cls, file): + return cls(**json.load(open(file, "r"))) + + +def gelu(x): + "Implementation of the gelu activation function by Hugging Face" + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class CDILBlock(nn.Module): + def __init__(self, c_in, c_out, hdim, ks, dil, dropout): + super(CDILBlock, self).__init__() + self.conv1 = nn.Conv1d(in_channels=c_in, out_channels=hdim, kernel_size=ks, padding='same', dilation=dil, padding_mode='circular', bias = False) + self.conv2 = nn.Conv1d(in_channels=hdim, out_channels=c_out, kernel_size=ks, padding='same', dilation=dil, padding_mode='circular', bias = False) + self.dropout = nn.Dropout(dropout) + self.batch_norm1 = nn.BatchNorm1d(hdim) + self.batch_norm2 = nn.BatchNorm1d(c_out) + self.res = nn.Conv1d(c_in, c_out, kernel_size=(1,)) if c_in != c_out else None + self.nonlinear = nn.ReLU() + + def forward(self, x): + out = self.conv1(x) + out = self.dropout(out) + out = self.batch_norm1(out) + out = self.nonlinear(out) + + out = self.conv2(out) + out = self.dropout(out) + out = self.batch_norm2(out) + out = self.nonlinear(out) + res = x if self.res is None else self.res(x) + return self.nonlinear(out) + res + +# class CDILBlock2(nn.Module): +# def __init__(self, c_in, c_out, hdim, ks, dil, dropout): + +# super().__init__() +# self.conv = nn.Conv1d(in_channels=c_in, out_channels=hdim, kernel_size=ks, padding='same', dilation=dil, padding_mode='circular') + +# self.layer_norm1 = nn.LayerNorm(hdim) +# self.nonlinear1 = nn.ReLU() +# self.dropout = nn.Dropout(dropout) +# self.layer_norm2 = nn.LayerNorm(hdim) +# self.conv21 = nn.Conv1d(in_channels=hdim, out_channels=hdim*2, kernel_size=1) +# self.nonlinear2 = nn.ReLU() +# self.conv22 = nn.Conv1d(in_channels=hdim*2, out_channels=c_out, kernel_size=1) +# self.dropout2 = nn.Dropout(dropout) + +# def forward(self, x): +# x = self.layer_norm1(x.permute(0, 2, 1)).permute(0, 2, 1) +# out = self.conv(x) +# out = self.dropout(self.nonlinear1(out)) +# x2 = out + x +# x2 = self.layer_norm2(x2.permute(0, 2, 1)).permute(0, 2, 1) +# out2 = self.dropout2(self.conv22(self.nonlinear2(self.conv21(x2)))) +# return out2 + x2 + + +class CDILLayer(nn.Module): + def __init__(self, dim_in, dim_out, hdim, ks, dropout): + super(CDILLayer, self).__init__() + layers = [] + for i in range(len(dim_out)): + current_input = dim_in if i == 0 else dim_out[i - 1] + current_output = dim_out[i] + hdim = hdim + current_dilation = 2 ** i + current_dropout = dropout + layers += [CDILBlock(current_input, current_output, hdim, ks, current_dilation, current_dropout)] + self.conv_net = nn.Sequential(*layers) + + def forward(self, x): + return self.conv_net(x) + + +class ClassifierHead(nn.Module): + def __init__(self, dim_hidden, out): + super(ClassifierHead, self).__init__() + self.linear = nn.Linear(dim_hidden, out) + self.init_weights() + + def init_weights(self): + self.linear.weight.data.normal_(0, 0.01) + self.linear.bias.data.normal_(0, 0.01) + + def forward(self, x): + y = self.linear(x) + return y + + +class Classifier(nn.Module): + def __init__(self, dim_in, dim_out, clf_dim, layers, ks, output_size, max_len, dropout): + super(Classifier, self).__init__() + self.encoder = CDILLayer(dim_in, [dim_out]*layers, dim_out*2, ks, dropout) + self.revoClf = CDILLayer(dim_out, [clf_dim]*layers, clf_dim*2, ks, dropout) + self.classifier = ClassifierHead(clf_dim, output_size) + # self.freeze_cdilNet()dim_in: 5 + + def freeze_cdilNet(self): + for param in self.cdilNet.parameters(): + param.requires_grad = False + + def forward(self, x1, x2, idx_linear): + # print(x1.shape, x2.shape) + x1, x2 = x1.float(), x2.float() + y1 = self.encoder(x1) + y2 = self.encoder(x2) + y = y1 - y2 + y = self.revoClf(y) + y = torch.mean(y, dim=2) + y = self.classifier(y) + idx_linear = idx_linear.unsqueeze(0).t().type(torch.int64) + y = torch.gather(y, 1, idx_linear) + return y + + +class GB_Linear_Classifier(nn.Module): + """ + The SwanDNA model. Encoder is a stack of SwanDNA blocks. Decoder a global average pooling, followed by a linear layer. + + Args: + input_size (int): The input size of the embedding layer. + output_size (int): The output size of the decoder layer. + max_len (int): The maximum sequence length in the data. + group_size (int): The size of groups to be shifted. + hidden_size (int): The hidden layer size for the MLPs. + mlp_dropout (float): The dropout probability for the MLPs. + layer_dropout (float): The dropout probability for the SwanDNABlock. + prenorm (str): The type of normalization for the pre-normalization step. + norm (str): The type of normalization for the post-normalization step. + """ + def __init__(self, dim_in, dim_out, layers, ks, output_size, dropout, max_len): + super().__init__() + + self.encoder = CDILLayer(dim_in, [dim_out]*layers, dim_out*2, ks, dropout) + + self.decoder = nn.Linear(dim_out, output_size) + # self.freeze_encoder() + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=1.0) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + + def forward(self, x): + x = x.float() + # x = self.embedding(x) + x = torch.permute(x, (0, 2, 1)) + x = self.encoder(x) + x = torch.permute(x, (0, 2, 1)) + x = torch.mean(x, dim=1) + x = self.decoder(x) + return x + +class Model4PretrainCDIL(nn.Module): + "CDIL Model for Pretrain : Masked LM" + def __init__(self, dim, hdim1, hdim2, kernel_size, n_layers, dropout): + super().__init__() + self.encoder = CDILLayer(dim, [hdim1]*n_layers, hdim1*2, kernel_size, dropout) + self.hidden_list = [hdim2]*n_layers + self.hidden_list[-1] = dim + self.decoder = CDILLayer(hdim1, self.hidden_list, hdim2*2, kernel_size, dropout) + # self.sigmoid = nn.Sigmoid() + + def forward(self, input_seq): + input_seq = input_seq.float() + # encoder + h = torch.permute(input_seq, (0, 2, 1)) + h = self.encoder(h) + # decoder + h = self.decoder(h) + h = torch.permute(h, (0, 2, 1)) + + return h \ No newline at end of file diff --git a/models/deeperdeepsea.py b/models/deeperdeepsea.py new file mode 100644 index 0000000..e1b75ec --- /dev/null +++ b/models/deeperdeepsea.py @@ -0,0 +1,55 @@ +import numpy as np +import torch.nn as nn + + +class DeeperDeepSEA(nn.Module): + def __init__(self, sequence_length, n_targets): + super(DeeperDeepSEA, self).__init__() + conv_kernel_size = 8 + pool_kernel_size = 4 + + self.conv_net = nn.Sequential( + nn.Conv1d(4, 320, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.Conv1d(320, 320, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=pool_kernel_size, stride=pool_kernel_size), + nn.BatchNorm1d(320), + + nn.Conv1d(320, 480, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.Conv1d(480, 480, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=pool_kernel_size, stride=pool_kernel_size), + nn.BatchNorm1d(480), + nn.Dropout(p=0.2), + + nn.Conv1d(480, 960, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.Conv1d(960, 960, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.BatchNorm1d(960), + nn.Dropout(p=0.2)) + + reduce_by = 2 * (conv_kernel_size - 1) + pool_kernel_size = float(pool_kernel_size) + self._n_channels = int( + np.floor( + (np.floor( + (sequence_length - reduce_by) / pool_kernel_size) + - reduce_by) / pool_kernel_size) + - reduce_by) + self.classifier = nn.Sequential( + nn.Linear(960 * self._n_channels, n_targets), + nn.ReLU(inplace=True), + nn.BatchNorm1d(n_targets), + nn.Linear(n_targets, n_targets)) + + def forward(self, x): + x = x.permute(0, 2, 1) + out = self.conv_net(x) + reshape_out = out.view(out.size(0), 960 * self._n_channels) + predict = self.classifier(reshape_out) + return predict diff --git a/models/pretraining_model.py b/models/pretraining_model.py new file mode 100644 index 0000000..63e99cd --- /dev/null +++ b/models/pretraining_model.py @@ -0,0 +1,134 @@ +from models.SwanDNA import SwanDNANetwork +import torch.nn as nn +import math +from flash_pytorch import FLASH, FLASHTransformer +import numpy as np +import torch + +class Model4Pretrain(nn.Module): + """ + SwanDNA Model for Pretrain : Masked LM + With one SwanDNA encoder and one SwanDNA decoder. + """ + def __init__(self, input_size, max_len, embedding_size, group_size, hidden_size, mlp_dropout, layer_dropout, prenorm, norm): + super().__init__() + self.max_n_layers = math.ceil(np.log2(max_len)) + self.embedding_size = (self.max_n_layers+1) * group_size + self.embedding = nn.Linear( + input_size, + self.embedding_size + ) + self.encoder = SwanDNANetwork( + max_len, + self.embedding_size, + group_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm, + 4 + ) + # self.decoder = SwanDNAEncoder( + # max_len, + # self.embedding_size, + # group_size, + # hidden_size, + # mlp_dropout, + # layer_dropout, + # prenorm, + # norm + # ) + + self.linear = nn.Linear(self.embedding_size, input_size) + + def forward(self, input_seq): + input_seq = input_seq.float() + h = self.embedding(input_seq) + # encoder + h = self.encoder(h) + # decoder + h = self.linear(h) + + return h + + +class Model4TSNE(nn.Module): + """ + SwanDNA Model for Pretrain : Masked LM + With one SwanDNA encoder and one SwanDNA decoder. + """ + def __init__(self, input_size, max_len, embedding_size, track_size, hidden_size, mlp_dropout, layer_dropout, prenorm, norm): + super().__init__() + self.max_n_layers = math.ceil(np.log2(max_len)) + self.embedding_size = (self.max_n_layers+1) * track_size + self.embedding = nn.Linear( + input_size, + self.embedding_size + ) + self.encoder = SwanDNANetwork( + max_len, + self.embedding_size, + track_size, + hidden_size, + mlp_dropout, + layer_dropout, + prenorm, + norm + ) + + def forward(self, input_seq): + input_seq = input_seq.float() + h = self.embedding(input_seq) + # encoder + h = self.encoder(h) + + return h + + +class Model4PretrainFlash(nn.Module): + """ + SwanDNA Model for Pretrain : Masked LM + With one SwanDNA encoder and one SwanDNA decoder. + """ + def __init__(self, input_size, embedding_size, group_size, max_len): + super().__init__() + self.max_n_layers = 8 + self.max_len = max_len + self.embedding = nn.Linear( + input_size, + embedding_size + ) + + self.pos_enc = nn.Embedding(max_len, embedding_size) + + self.encoder = nn.ModuleList( + [ + FLASH( + dim = embedding_size, + group_size = group_size, # group size + causal = True, # autoregressive or not + query_key_dim = int(embedding_size/4), # query / key dimension + expansion_factor = 2., # hidden dimension = dim * expansion_factor + laplace_attn_fn = True # new Mega paper claims this is more stable than relu squared as attention function + ) + for _ in range(self.max_n_layers) + ] + ) + + self.linear = nn.Linear(embedding_size, input_size) + + def forward(self, input_seq): + input_seq = input_seq.float() + positions = torch.arange(0, self.max_len).expand(input_seq.size(0), self.max_len).cuda() + h = self.embedding(input_seq) + pos_enc = self.pos_enc(positions) + h = pos_enc + h + # input_seq = torch.permute(input_seq, (0, 2, 1)) + for layer in range(self.max_n_layers): + h = self.encoder[layer](h) + # h = torch.permute(h, (0, 2, 1)) + # decoder + h = self.linear(h) + + return h \ No newline at end of file diff --git a/models/x_formers.py b/models/x_formers.py new file mode 100644 index 0000000..0b134e2 --- /dev/null +++ b/models/x_formers.py @@ -0,0 +1,165 @@ +import math +import torch +import torch.nn as nn +from linformer import Linformer +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from nystrom_attention import Nystromformer +from flash_attn.flash_attention import FlashMHA +from mega_pytorch import MegaLayer +from Other_models.S4_model import S4Model + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, max_len, dropout: float = 0.1): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:x.size(0)] # [seq_len, batch_size, dim] + return self.dropout(x) + + +class XFormer(nn.Module): + def __init__(self, model, use_pos, input_size, dim, depth, heads, seq_len): + super(XFormer, self).__init__() + self.model = model + self.use_pos = use_pos + self.seq_len = seq_len + self.pos_enc = nn.Embedding(seq_len, dim) + # self.pos_encoding = PositionalEncoding(dim, seq_len) + self.linear = nn.Linear(input_size, dim) + + if model == 'transformer': + encoder_layers = TransformerEncoderLayer(dim, heads, dim) + self.former = TransformerEncoder(encoder_layers, depth) + elif model == 'nystromer': + self.former = Nystromformer( + dim=dim, + dim_head=int(dim/heads), + heads=heads, + depth=depth, + num_landmarks=256, # number of landmarks + pinv_iterations=6 + ) + elif model == 'linformer': + self.former = Linformer( + dim = dim, + seq_len = seq_len, + depth =depth, + heads = heads, + k = dim, + one_kv_head = True, + share_kv = True + ) + elif model == 'flash': + self.former = FlashMHA( + embed_dim=dim, # total channels (= num_heads * head_dim) + num_heads=heads, # number of heads + dtype=torch.float16, + ) + elif model == 'mega': + self.former = nn.Sequential( + *[MegaLayer( + dim = dim, # model dimensions + ema_heads = heads, # number of EMA heads + attn_dim_qk = dim, # dimension of queries / keys in attention + attn_dim_value = dim*2, # dimension of values in attention + laplacian_attn_fn = False, # whether to use softmax (false) or laplacian attention activation fn (true) + ) for _ in range(depth)] + ) + elif model == 's4': + self.former = S4Model( + d_input=input_size, + d_model=dim, + n_layers=depth, + dropout=0.2, + prenorm=True, + ) + print(self.former) + + def forward(self, x): + # x = x.float() + positions = torch.arange(0, self.seq_len).expand(x.size(0), self.seq_len).cuda() + x = self.linear(x) + # x = x.to(torch.float16) + if self.use_pos and self.model!="transformer": + pos_enc = self.pos_enc(positions) + x = pos_enc + x + + if self.use_pos and self.model=="transformer": + x = self.pos_encoding(x) + + if self.model == 'transformer': + x = x.permute(1, 0, 2) + x = self.former(x) + x = x.permute(1, 0, 2) + else: + x = self.former(x) + return x + + +class Model4Pretrain(nn.Module): + def __init__(self, model, depth, heads, input_size, hdim1, hdim2): + super().__init__() + + self.encoder = XFormer(model, False, input_size, hdim1, depth=depth, heads=heads) + self.decoder = XFormer(model, False, hdim1, hdim2, depth=depth, heads=heads) + self.linear = nn.Linear(hdim2, input_size) + self.sig = nn.Sigmoid() + + def forward(self, x): + x = x.float() + h = self.encoder(x) + h = self.decoder(h) + logits_lm = self.linear(h) + return self.sig(logits_lm) + + +class FormerClassifier(nn.Module): + def __init__(self, name, layers, heads, dim_in, dim_out, clf_dim, output_size, max_len): + super(FormerClassifier, self).__init__() + + self.encoder = XFormer(name, False, dim_in, dim_out, depth=layers, heads=heads, seq_len=max_len) + self.Net2 = XFormer(name, False, dim_out, clf_dim, depth=layers, heads=heads, seq_len=max_len) + self.classifier = nn.Linear(clf_dim, output_size) + # self.sig = nn.Sigmoid() + + def freeze_cdilNet(self): + for param in self.cdilNet.parameters(): + param.requires_grad = False + + def forward(self, x1, x2, idx_linear): + x1, x2 = x1.float(), x2.float() + idx_linear = idx_linear.to(torch.int64) + x1, x2 = x1.permute(0, 2, 1), x2.permute(0, 2, 1) + # x1, x2 = F.pad(x1, (0, 0, 0, 24, 0, 0)), F.pad(x2, (0, 0, 0, 24, 0, 0)) + y1 = self.encoder(x1) + y2 = self.encoder(x2) + y_vcf = y1 - y2 + y_vcf = self.Net2(y_vcf) + y_class = torch.mean(y_vcf, dim=1) + y = self.classifier(y_class) + idx_linear = idx_linear.unsqueeze(0).t() + y = torch.gather(y, 1, idx_linear) + return y + + +class Plant_FormerClassifier(nn.Module): + def __init__(self, name, layers, heads, dim_in, dim_out, clf_dim, max_seq_len, output_size): + super().__init__() + + self.encoder = XFormer(name, True, dim_in, dim_out, depth=layers, heads=heads, seq_len=max_seq_len) + self.Net2 = XFormer(name, False, dim_out, clf_dim, depth=layers, heads=heads, seq_len=max_seq_len) + self.linear = nn.Linear(clf_dim, output_size).half() + + def forward(self, seq): + en = self.encoder(seq) + de = self.Net2(en) + y_center = de[:, 400:600, :] + y = self.linear(torch.mean(y_center, dim=1)) + return y