Train SegFeat on TIMIT
Note: a full training run exceeds the Kaggle session runtime limit, so this notebook may need to be resumed/re-run.
!git clone https://github.com/felixkreuk/SegFeat
%cd /kaggle/working
%%writefile segfeat.patch
diff --git a/dataloader.py b/dataloader.py
index 44d6fba..f141162 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -87,14 +87,14 @@ def extract_features(wav_file, hparams):
# extract mel-spectrogram
if hparams.feats == 'mel':
- spect = librosa.feature.melspectrogram(wav,
+ spect = librosa.feature.melspectrogram(y=wav,
sr=sr,
n_fft=hparams.n_fft,
hop_length=hparams.hop_length,
n_mels=hparams.rnn_input_size)
# extract mfcc
elif hparams.feats == 'mfcc':
- spect = librosa.feature.mfcc(wav,
+ spect = librosa.feature.mfcc(y=wav,
sr=sr,
n_fft=hparams.n_fft,
hop_length=hparams.hop_length,
@@ -208,7 +208,7 @@ class WavPhnDataset(Dataset):
raise NotImplementedError
def process_file(self, wav_path):
- phn_path = wav_path.replace("wav", "phn")
+ phn_path = wav_path.replace("WAV", "PHN")
# load audio
spect = extract_features(wav_path, self.hparams)
@@ -235,7 +235,7 @@ class WavPhnDataset(Dataset):
def _make_dataset(self):
files = []
- wavs = list(iter_find_files(self.wav_path, "*.wav"))
+ wavs = list(iter_find_files(self.wav_path, "*.WAV"))
if self.hparams.devrun:
wavs = wavs[:self.hparams.devrun_size]
@@ -266,9 +266,9 @@ class TimitDataset(WavPhnDataset):
@staticmethod
def get_datasets(hparams):
- train_dataset = TimitDataset(join(hparams.wav_path, 'train'),
+ train_dataset = TimitDataset(join(hparams.wav_path, 'TRAIN'),
hparams)
- test_dataset = TimitDataset(join(hparams.wav_path, 'test'),
+ test_dataset = TimitDataset(join(hparams.wav_path, 'TEST'),
hparams)
train_len = len(train_dataset)
%%writefile lightning.patch
diff --git a/main.py b/main.py
index 62cbb2c..43845e4 100644
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
from loguru import logger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.logging import TestTubeLogger
+from pytorch_lightning.loggers import TensorBoardLogger
from torch.backends import cudnn
from torch.utils.data import DataLoader, Dataset
@@ -44,15 +44,13 @@ def main(hparams):
mode='min'
)
- tt_logger = TestTubeLogger(
+ tb_logger = TensorBoardLogger(
save_dir=hparams.run_dir,
name="lightning_logs",
- debug=False,
- create_git_tag=False
)
checkpoint = ModelCheckpoint(
- filepath=model_save_path,
+ dirpath=model_save_path,
save_top_k=1,
verbose=True,
monitor='val_f1_at_2',
@@ -60,19 +58,17 @@ def main(hparams):
)
trainer = Trainer(
- logger=tt_logger,
- overfit_pct=hparams.overfit,
+ logger=tb_logger,
check_val_every_n_epoch=1,
min_epochs=1,
max_epochs=hparams.epochs,
- nb_sanity_val_steps=4,
- checkpoint_callback=None,
- val_percent_check=hparams.val_percent_check,
+ num_sanity_val_steps=4,
+ callbacks=[early_stop, checkpoint],
+ limit_val_batches=hparams.val_percent_check,
val_check_interval=hparams.val_check_interval,
- early_stop_callback=None,
- gpus=hparams.gpus,
- show_progress_bar=False,
- distributed_backend=None,
+ devices="auto",
+ accelerator="auto",
+ enable_progress_bar=True,
)
if not hparams.test:
diff --git a/model.py b/model.py
index 12c3542..b0a8956 100644
--- a/model.py
+++ b/model.py
@@ -138,7 +138,7 @@ class Segmentor(nn.Module):
results = {}
# feed through rnn
- x = nn.utils.rnn.pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)
+ x = nn.utils.rnn.pack_padded_sequence(x, length.cpu(), batch_first=True, enforce_sorted=False)
rnn_out, _ = self.rnn(x)
rnn_out, _ = nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
rnn_cum = torch.cumsum(rnn_out, dim=1)
diff --git a/solver.py b/solver.py
index 46672db..5cf3c28 100644
--- a/solver.py
+++ b/solver.py
@@ -19,7 +19,7 @@ from utils import PrecisionRecallMetricMultiple, StatsMeter
class Solver(LightningModule):
def __init__(self, config):
super(Solver, self).__init__()
- self.hparams = config
+ self.save_hyperparameters(config)
if config.dataset == "timit":
self.datasetClass = TimitDataset
@@ -46,23 +46,23 @@ class Solver(LightningModule):
'test': StatsMeter()}
self._device = 'cuda' if config.cuda else 'cpu'
+ self.validation_step_outputs = []
+ self.test_step_outputs = []
+
self.build_model()
logger.info(f"running on {self._device}")
logger.info(f"rnn input size: {config.rnn_input_size}")
logger.info(f"{self.segmentor}")
- @pl.data_loader
def train_dataloader(self):
self.train_loader = DataLoader(self.train_dataset,
batch_size=self.config.batch_size,
shuffle=True,
collate_fn=collate_fn_padd,
num_workers=6)
- logger.info(f"input shape: {self.train_dataset[0][0].shape}")
logger.info(f"training set length {len(self.train_dataset)}")
return self.train_loader
- @pl.data_loader
def val_dataloader(self):
self.valid_loader = DataLoader(self.valid_dataset,
batch_size=self.config.batch_size,
@@ -72,7 +72,6 @@ class Solver(LightningModule):
logger.info(f"validation set length {len(self.valid_dataset)}")
return self.valid_loader
- @pl.data_loader
def test_dataloader(self):
self.test_loader = DataLoader(self.test_dataset,
batch_size=self.config.batch_size,
@@ -200,8 +199,6 @@ class Solver(LightningModule):
for output in outputs:
loss = output[f'{prefix}_loss']
- if self.trainer.use_dp:
- loss = torch.mean(loss)
loss_mean += loss
loss_mean /= len(outputs)
@@ -243,19 +240,28 @@ class Solver(LightningModule):
logger.info(f"\nEVAL {prefix} STATS:\n{json.dumps(metrics, sort_keys=True, indent=4)}\n")
- return metrics
+ for k, v in metrics.items():
+ self.log(k, v, prog_bar=(k == f'{prefix}_f1_at_2'))
def validation_step(self, data_batch, batch_i):
- return self.generic_eval_step(data_batch, batch_i, 'val')
+ out = self.generic_eval_step(data_batch, batch_i, 'val')
+ self.validation_step_outputs.append(out)
+ return out
- def validation_epoch_end(self, outputs):
- return self.generic_eval_end(outputs, 'val')
+ def on_validation_epoch_end(self):
+ outputs = self.validation_step_outputs
+ self.generic_eval_end(outputs, 'val')
+ self.validation_step_outputs.clear()
def test_step(self, data_batch, batch_i):
- return self.generic_eval_step(data_batch, batch_i, 'test')
-
- def test_epoch_end(self, outputs):
- return self.generic_eval_end(outputs, 'test')
+ out = self.generic_eval_step(data_batch, batch_i, 'test')
+ self.test_step_outputs.append(out)
+ return out
+
+ def on_test_epoch_end(self):
+ outputs = self.test_step_outputs
+ self.generic_eval_end(outputs, 'test')
+ self.test_step_outputs.clear()
def configure_optimizers(self):
optimizer = {'adam': torch.optim.Adam(self.segmentor.parameters(), lr=self.config.lr),
diff --git a/utils.py b/utils.py
index 599c482..c13444f 100644
--- a/utils.py
+++ b/utils.py
@@ -37,12 +37,13 @@ class PrecisionRecallMetric:
for (y, yhat) in zip(batch_y, batch_yhat):
y, yhat = np.array(y), np.array(yhat)
y, yhat = y[1:-1], yhat[1:-1]
- for yhat_i in yhat:
- min_dist = np.abs(y - yhat_i).min()
- precision_counter += (min_dist <= self.tolerance)
- for y_i in y:
- min_dist = np.abs(yhat - y_i).min()
- recall_counter += (min_dist <= self.tolerance)
+ if len(yhat) > 0 and len(y) > 0:
+ for yhat_i in yhat:
+ min_dist = np.abs(y - yhat_i).min()
+ precision_counter += (min_dist <= self.tolerance)
+ for y_i in y:
+ min_dist = np.abs(yhat - y_i).min()
+ recall_counter += (min_dist <= self.tolerance)
pred_counter += len(yhat)
gt_counter += len(y)
%cd SegFeat
!git apply ../segfeat.patch
!git apply ../lightning.patch
%%writefile requirements.txt
torch
torchaudio
torchvision
pytorch-lightning
boltons
loguru
librosa
numpy
pandas
soundfile
tqdm
# Install the dependencies required by the patched SegFeat code.
!pip install -r requirements.txt
# Launch training on the Kaggle TIMIT mirror (wav_path points at the dataset's
# root, which contains the TRAIN/ and TEST/ splits the patched dataloader expects).
# NOTE(review): --delta_feats / --dist_feats presumably enable extra feature
# streams defined in SegFeat's argument parser — confirm against main.py.
!python main.py --wav_path /kaggle/input/darpa-timit-acousticphonetic-continuous-speech/data --dataset timit --delta_feats --dist_feats