# Clone the SegFeat repo (Kreuk et al., phoneme boundary segmentation) and make
# sure we are in Kaggle's writable working directory, so the patch files
# written by the cells below land next to the checkout.
!git clone https://github.com/felixkreuk/SegFeat
%cd /kaggle/working
%%writefile segfeat.patch
diff --git a/dataloader.py b/dataloader.py
index 44d6fba..f141162 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -87,14 +87,14 @@ def extract_features(wav_file, hparams):
 
     # extract mel-spectrogram
     if hparams.feats == 'mel':
-        spect = librosa.feature.melspectrogram(wav,
+        spect = librosa.feature.melspectrogram(y=wav,
                                                sr=sr,
                                                n_fft=hparams.n_fft,
                                                hop_length=hparams.hop_length,
                                                n_mels=hparams.rnn_input_size)
     # extract mfcc
     elif hparams.feats == 'mfcc':
-        spect = librosa.feature.mfcc(wav,
+        spect = librosa.feature.mfcc(y=wav,
                                      sr=sr,
                                      n_fft=hparams.n_fft,
                                      hop_length=hparams.hop_length,
@@ -208,7 +208,7 @@ class WavPhnDataset(Dataset):
         raise NotImplementedError
 
     def process_file(self, wav_path):
-        phn_path = wav_path.replace("wav", "phn")
+        phn_path = wav_path.replace("WAV", "PHN")
 
         # load audio
         spect = extract_features(wav_path, self.hparams)
@@ -235,7 +235,7 @@ class WavPhnDataset(Dataset):
 
     def _make_dataset(self):
         files = []
-        wavs = list(iter_find_files(self.wav_path, "*.wav"))
+        wavs = list(iter_find_files(self.wav_path, "*.WAV"))
         if self.hparams.devrun:
             wavs = wavs[:self.hparams.devrun_size]
 
@@ -266,9 +266,9 @@ class TimitDataset(WavPhnDataset):
 
     @staticmethod
     def get_datasets(hparams):
-        train_dataset = TimitDataset(join(hparams.wav_path, 'train'),
+        train_dataset = TimitDataset(join(hparams.wav_path, 'TRAIN'),
                                      hparams)
-        test_dataset  = TimitDataset(join(hparams.wav_path, 'test'),
+        test_dataset  = TimitDataset(join(hparams.wav_path, 'TEST'),
                                      hparams)
 
         train_len   = len(train_dataset)
%%writefile lightning.patch
diff --git a/main.py b/main.py
index 62cbb2c..43845e4 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from loguru import logger
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.logging import TestTubeLogger
+from pytorch_lightning.loggers import TensorBoardLogger
 from torch.backends import cudnn
 from torch.utils.data import DataLoader, Dataset
 
@@ -44,15 +44,13 @@ def main(hparams):
         mode='min'
     )
 
-    tt_logger = TestTubeLogger(
+    tb_logger = TensorBoardLogger(
         save_dir=hparams.run_dir,
         name="lightning_logs",
-        debug=False,
-        create_git_tag=False
     )
 
     checkpoint = ModelCheckpoint(
-        filepath=model_save_path,
+        dirpath=model_save_path,
         save_top_k=1,
         verbose=True,
         monitor='val_f1_at_2',
@@ -60,19 +58,17 @@ def main(hparams):
     )
 
     trainer = Trainer(
-            logger=tt_logger,
-            overfit_pct=hparams.overfit,
+            logger=tb_logger,
             check_val_every_n_epoch=1,
             min_epochs=1,
             max_epochs=hparams.epochs,
-            nb_sanity_val_steps=4,
-            checkpoint_callback=None,
-            val_percent_check=hparams.val_percent_check,
+            num_sanity_val_steps=4,
+            callbacks=[early_stop, checkpoint],
+            limit_val_batches=hparams.val_percent_check,
             val_check_interval=hparams.val_check_interval,
-            early_stop_callback=None,
-            gpus=hparams.gpus,
-            show_progress_bar=False,
-            distributed_backend=None,
+            devices="auto",
+            accelerator="auto",
+            enable_progress_bar=True,
             )
 
     if not hparams.test:
diff --git a/model.py b/model.py
index 12c3542..b0a8956 100644
--- a/model.py
+++ b/model.py
@@ -138,7 +138,7 @@ class Segmentor(nn.Module):
         results = {}
 
         # feed through rnn
-        x = nn.utils.rnn.pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)
+        x = nn.utils.rnn.pack_padded_sequence(x, length.cpu(), batch_first=True, enforce_sorted=False)
         rnn_out, _ = self.rnn(x)
         rnn_out, _ = nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
         rnn_cum = torch.cumsum(rnn_out, dim=1)
diff --git a/solver.py b/solver.py
index 46672db..5cf3c28 100644
--- a/solver.py
+++ b/solver.py
@@ -19,7 +19,7 @@ from utils import PrecisionRecallMetricMultiple, StatsMeter
 class Solver(LightningModule):
     def __init__(self, config):
         super(Solver, self).__init__()
-        self.hparams = config
+        self.save_hyperparameters(config)
 
         if config.dataset == "timit":
             self.datasetClass = TimitDataset
@@ -46,23 +46,23 @@ class Solver(LightningModule):
                         'test':  StatsMeter()}
         self._device = 'cuda' if config.cuda else 'cpu'
 
+        self.validation_step_outputs = []
+        self.test_step_outputs = []
+
         self.build_model()
         logger.info(f"running on {self._device}")
         logger.info(f"rnn input size: {config.rnn_input_size}")
         logger.info(f"{self.segmentor}")
 
-    @pl.data_loader
     def train_dataloader(self):
         self.train_loader = DataLoader(self.train_dataset,
                                        batch_size=self.config.batch_size,
                                        shuffle=True,
                                        collate_fn=collate_fn_padd,
                                        num_workers=6)
-        logger.info(f"input shape: {self.train_dataset[0][0].shape}")
         logger.info(f"training set length {len(self.train_dataset)}")
         return self.train_loader
 
-    @pl.data_loader
     def val_dataloader(self):
         self.valid_loader = DataLoader(self.valid_dataset,
                                        batch_size=self.config.batch_size,
@@ -72,7 +72,6 @@ class Solver(LightningModule):
         logger.info(f"validation set length {len(self.valid_dataset)}")
         return self.valid_loader
 
-    @pl.data_loader
     def test_dataloader(self):
         self.test_loader  = DataLoader(self.test_dataset,
                                        batch_size=self.config.batch_size,
@@ -200,8 +199,6 @@ class Solver(LightningModule):
 
         for output in outputs:
             loss = output[f'{prefix}_loss']
-            if self.trainer.use_dp:
-                loss = torch.mean(loss)
             loss_mean += loss
 
         loss_mean /= len(outputs)
@@ -243,19 +240,28 @@ class Solver(LightningModule):
 
         logger.info(f"\nEVAL {prefix} STATS:\n{json.dumps(metrics, sort_keys=True, indent=4)}\n")
 
-        return metrics
+        for k, v in metrics.items():
+            self.log(k, v, prog_bar=(k == f'{prefix}_f1_at_2'))
 
     def validation_step(self, data_batch, batch_i):
-        return self.generic_eval_step(data_batch, batch_i, 'val')
+        out = self.generic_eval_step(data_batch, batch_i, 'val')
+        self.validation_step_outputs.append(out)
+        return out
 
-    def validation_epoch_end(self, outputs):
-        return self.generic_eval_end(outputs, 'val')
+    def on_validation_epoch_end(self):
+        outputs = self.validation_step_outputs
+        self.generic_eval_end(outputs, 'val')
+        self.validation_step_outputs.clear()
 
     def test_step(self, data_batch, batch_i):
-        return self.generic_eval_step(data_batch, batch_i, 'test')
-
-    def test_epoch_end(self, outputs):
-        return self.generic_eval_end(outputs, 'test')
+        out = self.generic_eval_step(data_batch, batch_i, 'test')
+        self.test_step_outputs.append(out)
+        return out
+
+    def on_test_epoch_end(self):
+        outputs = self.test_step_outputs
+        self.generic_eval_end(outputs, 'test')
+        self.test_step_outputs.clear()
 
     def configure_optimizers(self):
         optimizer = {'adam':     torch.optim.Adam(self.segmentor.parameters(), lr=self.config.lr),
diff --git a/utils.py b/utils.py
index 599c482..c13444f 100644
--- a/utils.py
+++ b/utils.py
@@ -37,12 +37,13 @@ class PrecisionRecallMetric:
         for (y, yhat) in zip(batch_y, batch_yhat):
             y, yhat = np.array(y), np.array(yhat)
             y, yhat = y[1:-1], yhat[1:-1]
-            for yhat_i in yhat:
-                min_dist = np.abs(y - yhat_i).min()
-                precision_counter += (min_dist <= self.tolerance)
-            for y_i in y:
-                min_dist = np.abs(yhat - y_i).min()
-                recall_counter += (min_dist <= self.tolerance)
+            if len(yhat) > 0 and len(y) > 0:
+                for yhat_i in yhat:
+                    min_dist = np.abs(y - yhat_i).min()
+                    precision_counter += (min_dist <= self.tolerance)
+                for y_i in y:
+                    min_dist = np.abs(yhat - y_i).min()
+                    recall_counter += (min_dist <= self.tolerance)
             pred_counter += len(yhat)
             gt_counter += len(y)
 
# Move into the checkout and apply both patch files written above; they were
# saved one directory up, in /kaggle/working, hence the ../ paths.
%cd SegFeat
!git apply ../segfeat.patch
!git apply ../lightning.patch
%%writefile requirements.txt
# Runtime dependencies for SegFeat. Intentionally unpinned: the patches above
# port the code to current pytorch-lightning / librosa APIs, so recent
# versions (as shipped in the Kaggle image) are what we want.
torch
torchaudio
torchvision
pytorch-lightning
boltons
loguru
librosa
numpy
pandas
soundfile
tqdm
# Install dependencies, then launch training on the Kaggle TIMIT mirror.
# NOTE(review): that mirror uses upper-case TRAIN/TEST directories and
# *.WAV/*.PHN filenames — segfeat.patch adjusts the dataloader accordingly.
!pip install -r requirements.txt
!python main.py --wav_path /kaggle/input/darpa-timit-acousticphonetic-continuous-speech/data --dataset timit --delta_feats --dist_feats