train with pytorch lightning #262

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion .gitignore
@@ -139,4 +139,4 @@ pix2tex/model/checkpoints/**
.vscode
.DS_Store
test/*

*.prf
112 changes: 112 additions & 0 deletions pix2tex/eval.py
@@ -27,6 +27,118 @@ def detokenize(tokens, tokenizer):
return toks


def evaluate_step(model: Model, dataset_tokenizer, data_batch, args: Munch, name: str = 'test'):
    """Run one evaluation step on a single batch.

    Args:
        model (Model): the model
        dataset_tokenizer: tokenizer used to decode predictions and ground truth
        data_batch (tuple): one test data batch (seq, im)
        args (Munch): arguments
        name (str, optional): prefix for the logged metrics. Defaults to 'test'.

    Returns:
        Tuple[float, float, float]: BLEU score of the batch, normed edit distance, token accuracy
    """
(seq, im) = data_batch
edit_dists = []
log = {}
bleu_score, edit_distance, token_accuracy = 0, 1, 0

dec = model.generate(im, temperature=args.get('temperature', .2))
pred = detokenize(dec, dataset_tokenizer)
truth = detokenize(seq['input_ids'], dataset_tokenizer)

    # BLEU score
bleu_score = metrics.bleu_score(pred, [alternatives(x) for x in truth])

# edit distance
for predi, truthi in zip(token2str(dec, dataset_tokenizer), token2str(seq['input_ids'], dataset_tokenizer)):
ts = post_process(truthi)
if len(ts) > 0:
edit_dists.append(distance(post_process(predi), ts)/len(ts))
edit_distance = np.mean(edit_dists) if len(edit_dists) > 0 else 1

    # token accuracy: pad the shorter of prediction/target to a common length,
    # then compare tokens wherever either side is not padding
    tgt_seq = seq['input_ids'][:, 1:]
    shape_diff = dec.shape[1]-tgt_seq.shape[1]
    if shape_diff < 0:
        dec = torch.nn.functional.pad(dec, (0, -shape_diff), "constant", args.pad_token)
    elif shape_diff > 0:
        tgt_seq = torch.nn.functional.pad(tgt_seq, (0, shape_diff), "constant", args.pad_token)
    mask = torch.logical_or(tgt_seq != args.pad_token, dec != args.pad_token)
    token_accuracy = (dec == tgt_seq)[mask].float().mean().item()  # .item() already yields a scalar

log[name+'/bleu'] = bleu_score
log[name+'/token_acc'] = token_accuracy
log[name+'/edit_distance'] = edit_distance

if args.wandb:
pred = token2str(dec, dataset_tokenizer)
truth = token2str(seq['input_ids'], dataset_tokenizer)
table = wandb.Table(columns=["Truth", "Prediction"])
for k in range(min([len(pred), args.test_samples])):
table.add_data(post_process(truth[k]), post_process(pred[k]))
log[name+'/examples'] = table
wandb.log(log)
return bleu_score, edit_distance, token_accuracy


def evaluate_step__(model: Model, dataset_tokenizer, data_batch, args: Munch, name: str = 'test'):
    """Run one evaluation step on a single batch (list-accumulating variant of evaluate_step).

    Args:
        model (Model): the model
        dataset_tokenizer: tokenizer used to decode predictions and ground truth
        data_batch (tuple): one test data batch (seq, im)
        args (Munch): arguments
        name (str, optional): prefix for the logged metrics. Defaults to 'test'.

    Returns:
        Tuple[float, float, float]: BLEU score of the batch, normed edit distance, token accuracy
    """
(seq, im) = data_batch
bleus, edit_dists, token_acc = [], [], []
bleu_score, edit_distance, token_accuracy = 0, 1, 0
log = {}

# loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
dec = model.generate(im, temperature=args.get('temperature', .2))
pred = detokenize(dec, dataset_tokenizer)
truth = detokenize(seq['input_ids'], dataset_tokenizer)
bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
for predi, truthi in zip(token2str(dec, dataset_tokenizer), token2str(seq['input_ids'], dataset_tokenizer)):
ts = post_process(truthi)
if len(ts) > 0:
edit_dists.append(distance(post_process(predi), ts)/len(ts))
# dec = dec.cpu()
tgt_seq = seq['input_ids'][:, 1:]
shape_diff = dec.shape[1]-tgt_seq.shape[1]
if shape_diff < 0:
dec = torch.nn.functional.pad(dec, (0, -shape_diff), "constant", args.pad_token)
elif shape_diff > 0:
tgt_seq = torch.nn.functional.pad(tgt_seq, (0, shape_diff), "constant", args.pad_token)
mask = torch.logical_or(tgt_seq != args.pad_token, dec != args.pad_token)
tok_acc = (dec == tgt_seq)[mask].float().mean().item()
token_acc.append(tok_acc)

if len(bleus) > 0:
bleu_score = np.mean(bleus)
log[name+'/bleu'] = bleu_score
if len(edit_dists) > 0:
edit_distance = np.mean(edit_dists)
log[name+'/edit_distance'] = edit_distance
if len(token_acc) > 0:
token_accuracy = np.mean(token_acc)
log[name+'/token_acc'] = token_accuracy
if args.wandb:
pred = token2str(dec, dataset_tokenizer)
truth = token2str(seq['input_ids'], dataset_tokenizer)
table = wandb.Table(columns=["Truth", "Prediction"])
for k in range(min([len(pred), args.test_samples])):
table.add_data(post_process(truth[k]), post_process(pred[k]))
log[name+'/examples'] = table
wandb.log(log)
return bleu_score, edit_distance, token_accuracy


@torch.no_grad()
def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
"""evaluates the model. Returns bleu score on the dataset
...
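For context, a minimal sketch of driving evaluate_step outside the Lightning loop. The quick_eval helper is illustrative only, not part of this diff; it assumes the Im2LatexDataset iterator yields (seq, im) batches and exposes .tokenizer, as the code above does:

import numpy as np
import torch

@torch.no_grad()
def quick_eval(model, dataset, args, num_batches=4):
    # average the per-batch metrics returned by evaluate_step
    scores = []
    for i, batch in enumerate(iter(dataset)):
        if i >= num_batches:
            break
        scores.append(evaluate_step(model, dataset.tokenizer, batch, args, name='test'))
    return np.mean(scores, axis=0)  # mean (BLEU, edit distance, token accuracy)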
132 changes: 130 additions & 2 deletions pix2tex/train.py
@@ -1,5 +1,6 @@
from pix2tex.dataset.dataset import Im2LatexDataset
import os
import sys
import argparse
import logging
import yaml
@@ -9,10 +10,14 @@
from tqdm.auto import tqdm
import wandb
import torch.nn as nn
-from pix2tex.eval import evaluate
+from pix2tex.eval import evaluate, evaluate_step
from pix2tex.models import get_model
# from pix2tex.utils import *
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, OnExceptionCheckpoint
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.loggers import CSVLogger


def train(args):
@@ -79,6 +84,126 @@ def save_models(e, step=0):
save_models(e, step=len(dataloader))


class DataModule(pl.LightningDataModule):
    """Wraps the Im2LatexDataset train/val iterators for Lightning."""

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args

        train_dataloader = Im2LatexDataset().load(args.data)
        train_dataloader.update(**args, test=False)
        val_dataloader = Im2LatexDataset().load(args.valdata)
        # validation uses its own batch size and keeps smaller batches
        val_args = args.copy()
        val_args.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
        val_dataloader.update(**val_args)

        self.dataset_tokenizer = val_dataloader.tokenizer
        self.train_data = train_dataloader
        self.valid_data = val_dataloader

def train_dataloader(self):
return self.train_data

def val_dataloader(self):
return self.valid_data


class OCR_Model(pl.LightningModule):
def __init__(self, args, dataset_tokenizer, **kwargs):
super().__init__()
self.args = args
self.dataset_tokenizer = dataset_tokenizer

model = get_model(args)
if args.load_chkpt is not None:
model.load_state_dict(torch.load(args.load_chkpt))
self.model = model
if torch.cuda.is_available() and not args.no_cuda:
gpu_memory_check(model, args)

microbatch = args.get('micro_batchsize', -1)
if microbatch == -1:
microbatch = args.batchsize
self.microbatch = microbatch

def forward(self, x):
return self.model(x)

def configure_optimizers(self):
args = self.args
opt = get_optimizer(args.optimizer)(self.model.parameters(), args.lr, betas=args.betas)
scheduler = get_scheduler(args.scheduler)(opt, step_size=args.lr_step, gamma=args.gamma)
return {
"optimizer": opt,
"lr_scheduler": {
"scheduler": scheduler,
"monitor": "train_loss",
}
}

    def training_step(self, train_batch, batch_idx):
        args = self.args
        (seq, im) = train_batch
        if seq is None or im is None:
            return None  # skip empty batches; Lightning ignores a None training loss
        total_loss = 0
        # accumulate the (scaled) loss over microbatches so large batch sizes fit in memory
        for j in range(0, len(im), self.microbatch):
            tgt_seq, tgt_mask = seq['input_ids'][j:j+self.microbatch], seq['attention_mask'][j:j+self.microbatch].bool()
            loss = self.model.data_parallel(im[j:j+self.microbatch], device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*self.microbatch/args.batchsize
            total_loss += loss
        # NOTE: clip_grad_norm_ here would run before backward() and have no effect
        # under automatic optimization; clipping is set on the Trainer instead
        if args.wandb:
            wandb.log({'train/loss': total_loss})

        self.log('train_loss', total_loss, on_epoch=True, on_step=False, prog_bar=True)
        return total_loss

def validation_step(self, val_batch, batch_idx):
bleu_score, edit_distance, token_accuracy = evaluate_step(self.model, self.dataset_tokenizer, val_batch, self.args, name='val')
metric_dict = {'bleu_score': bleu_score, 'edit_distance': edit_distance, 'token_accuracy': token_accuracy}
self.log_dict(metric_dict, on_epoch=True, on_step=False, prog_bar=True)
return metric_dict

def on_train_epoch_end(self):
if self.args.wandb:
wandb.log({'train/epoch': self.current_epoch+1})


class OCR():
def __init__(self, args):
self.args = args
self.logger = CSVLogger(save_dir='pl_logs', name='')
self.out_path = os.path.join(args.model_path, args.name)
os.makedirs(self.out_path, exist_ok=True)
self.data_model_setup()
self.callbacks_setup()

def data_model_setup(self):
self.Data = DataModule(self.args)
dataset_tokenizer = self.Data.dataset_tokenizer
self.Model = OCR_Model(self.args, dataset_tokenizer)

def callbacks_setup(self):
        save_name = f'pl_{self.args.name}' + '_{epoch}_{step}'

# NOTE: currently lightning doesn't support multiple monitor metrics
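        # One workaround (sketch): log a single combined scalar in validation_step,
        # e.g. self.log('val_score', bleu_score - edit_distance), and set
        # monitor='val_score' here instead of 'bleu_score'.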
save_ckpt = ModelCheckpoint(monitor='bleu_score', mode='max', filename=save_name, dirpath=self.out_path,
every_n_epochs=self.args.save_freq, save_top_k=10, save_last=True)

        # NOTE: the f-string below is evaluated once, at construction time, when
        # current_epoch and global_step are still 0, so the on-exception checkpoint
        # is always named pl_pix2tex_0_0.ckpt; this is not a lightning bug
        exp_save_name = f'pl_pix2tex_{self.Model.current_epoch}_{self.Model.global_step}'
excpt = OnExceptionCheckpoint(dirpath=self.out_path, filename=exp_save_name)
bar = RichProgressBar(leave=True, theme=RichProgressBarTheme(
description='green_yellow', progress_bar='green1', progress_bar_finished='green1'))
self.callbacks = [save_ckpt, excpt, bar]

def fit(self):
args = self.args
accelerator = 'gpu' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
        trainer = pl.Trainer(accelerator=accelerator, callbacks=self.callbacks, logger=self.logger,
                             max_epochs=args.epochs, val_check_interval=args.sample_freq,
                             gradient_clip_val=1)  # clip gradients here (see NOTE in training_step)
trainer.fit(self.Model, self.Data)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train model')
parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
@@ -99,4 +224,7 @@ def save_models(e, step=0):
args.id = wandb.util.generate_id()
wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
args = Munch(wandb.config)
-    train(args)
+    # train(args)

ocr = OCR(args)
ocr.fit()
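
For reference, a hedged sketch of resuming after a crash from the checkpoint that OnExceptionCheckpoint writes above. The path and filename are assumptions (they mirror exp_save_name and the out_path construction in OCR.__init__); this is not part of the diff:

import os
import pytorch_lightning as pl

# Sketch only: resume training from the on-exception checkpoint.
ocr = OCR(args)
ckpt = os.path.join(args.model_path, args.name, 'pl_pix2tex_0_0.ckpt')  # assumed name, see NOTE above
trainer = pl.Trainer(accelerator='auto', callbacks=ocr.callbacks, logger=ocr.logger,
                     max_epochs=args.epochs)
trainer.fit(ocr.Model, ocr.Data, ckpt_path=ckpt)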
28 changes: 28 additions & 0 deletions run.sh
@@ -0,0 +1,28 @@
#!/bin/bash
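
# Usage: ./run.sh [setup|generate|train]
#   setup    - create the venv and install pix2tex[train], pytorch-lightning, rich, wandb
#   generate - build the train/val dataset pickles from the rendered images
#   train    - run pix2tex.train with the default config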

if [ "$1" == "setup" ]; then
    echo "Setting up python virtual environment"
    python3 -m venv venv  # create the venv before activating it
    echo "Entering virtual environment"
    source ./venv/bin/activate

pip3 install 'pix2tex[train]'
pip3 install pytorch-lightning rich

# install and login wandb
pip3 install wandb
wandb login

elif [ $1 == "generate" ]; then
echo "Generate images dataset"
# eg. python3 -m pix2tex.dataset.dataset --equations path_to_textfile --images path_to_images --out dataset.pkl
python3 -m pix2tex.dataset.dataset --equations pix2tex/dataset/data/math.txt --images pix2tex/dataset/data/train --out pix2tex/dataset/data/train.pkl
python3 -m pix2tex.dataset.dataset --equations pix2tex/dataset/data/math.txt --images pix2tex/dataset/data/val --out pix2tex/dataset/data/val.pkl

elif [ $1 == "train" ]; then
echo "Training model"
python3 -m pix2tex.train --config pix2tex/model/settings/config.yaml

else
    echo "Invalid argument"
    echo "Usage: ./run.sh [setup|generate|train]"
fi