MLOps study - Raviraja Week 1: W&B
13 Oct 2022
Week 1 is about how to use Weights & Biases (W&B, or WandB). I first picked up these tools when I attended an ICRA workshop and when collaborating with friends from Google Brain, and they are definitely very convenient. Studying deep learning theory and writing papers is important, but I have come to think that quickly learning good tools and making them your own matters just as much. This post covers three main WandB features: logging, plotting, and keeping an eye on data samples. Before using them, let's start with configuration.
WandB configuring
WandB seems to integrate really well with PyTorch Lightning. The plain PyTorch integration is good too, but there you have to call the wandb API and log manually, whereas with PyTorch Lightning you get many features just by declaring the logger once. The arguments you can pass when declaring it are as follows.
< Arguments >
- name: Display name for the run.
- save_dir: Path where data is saved.
- offline: Run offline (data can be streamed later to wandb servers).
- id: Sets the version, mainly used to resume a previous run.
- version: Same as id.
- anonymous: Enables or explicitly disables anonymous logging.
- project: The name of the project to which this run will belong.
- log_model: Log checkpoints created by ModelCheckpoint (pytorch_lightning.callbacks.ModelCheckpoint) as W&B artifacts. The latest and best aliases are set automatically.
  - log_model='all': log every checkpoint as it is saved during training
  - log_model=True: log checkpoints once, at the end of training
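import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Declaring the logger once is enough; Lightning routes self.log calls to it.
wandb_logger = WandbLogger(project="MLOps Basics")
trainer = pl.Trainer(
    max_epochs=3,
    logger=wandb_logger,
    callbacks=[],
)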
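As a rough sketch of how the arguments listed above combine, the helper below collects them into keyword arguments and only imports WandbLogger when a logger is actually built. The helper names, the run name, and the save directory are made up for illustration and are not part of either library.

```python
def wandb_logger_kwargs(project, name=None, save_dir="logs", offline=False, log_model=False):
    """Collect WandbLogger keyword arguments (hypothetical helper, for illustration only)."""
    kwargs = {
        "project": project,      # W&B project the run belongs to
        "save_dir": save_dir,    # local directory where run data is saved
        "offline": offline,      # True: keep data local and stream it later
        "log_model": log_model,  # "all", True, or False
    }
    if name is not None:
        kwargs["name"] = name    # display name for the run
    return kwargs


def build_logger(**kwargs):
    # Imported lazily so the helper above works without pytorch_lightning installed.
    from pytorch_lightning.loggers import WandbLogger

    return WandbLogger(**kwargs)


kwargs = wandb_logger_kwargs("MLOps Basics", name="bert-tiny-cola", log_model="all")
print(kwargs)
# wandb_logger = build_logger(**kwargs)  # needs pytorch_lightning and wandb installed
```

Keeping the arguments in a plain dictionary makes it easy to switch a run to offline mode or change the checkpoint policy from one place.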
WandB logging
If you use torchmetrics, you can conveniently compute various metrics and send them to the wandb logger. Being able to log with a simple call to self.log is a real advantage of PyTorch Lightning!!
When logging, you can use various options.
- prog_bar=True : also shows the value in the progress bar
- on_epoch=True : accumulates the value over the batches in an epoch and logs the aggregate (the mean by default)
- on_step=True : logs the value at every batch; it's common to set this to False during validation.
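To make the difference between on_step and on_epoch concrete, here is a toy stand-in for what self.log records over one epoch. It is not Lightning's real implementation: simulate_log is a hypothetical function, the loss values are invented, and a plain mean stands in for the epoch aggregation. The _step and _epoch suffixes mirror how Lightning names a metric when both flags are on.

```python
from statistics import mean


def simulate_log(name, batch_values, on_step=True, on_epoch=False):
    """Toy model of self.log over one epoch (hypothetical, for illustration only)."""
    logged = {}
    both = on_step and on_epoch
    if on_step:
        # One entry per batch.
        logged[f"{name}_step" if both else name] = list(batch_values)
    if on_epoch:
        # One aggregated entry per epoch (a plain mean here).
        logged[f"{name}_epoch" if both else name] = mean(batch_values)
    return logged


losses = [1.0, 0.75, 0.5]  # invented per-batch losses
train_logged = simulate_log("train/loss", losses, on_step=True, on_epoch=True)
valid_logged = simulate_log("valid/loss", losses, on_step=False, on_epoch=True)
print(train_logged)
print(valid_logged)
```

With both flags on you get a per-batch curve and a per-epoch curve for the same metric; with only on_epoch you get a single point per epoch.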
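How often a value is logged by default depends on the LightningModule method that calls self.log.
- training_step: called for each batch; logs per step by default
- training_epoch_end: called once per epoch; logs per epoch
- validation_step: called for each batch; logs the epoch aggregate by default (on_step=False, on_epoch=True)
- validation_epoch_end: called once per epoch; logs per epoch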
The epoch_end methods aggregate the results of the individual steps at the epoch level. They receive a list of whatever each step returned, so return the values you need from the step as a dictionary.
import pytorch_lightning as pl
import torch
import torchmetrics
from sklearn.metrics import confusion_matrix


class ColaModel(pl.LightningModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=3e-5):
        super().__init__()
        # Model construction, forward() and configure_optimizers() are omitted here.
        # Newer torchmetrics releases require a task argument,
        # e.g. torchmetrics.Accuracy(task="multiclass", num_classes=2).
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        preds = torch.argmax(outputs.logits, 1)
        train_acc = self.train_accuracy_metric(preds, batch["label"])
        # You can reach the wandb logger via self.log. No need to declare it separately.
        self.log("train/loss", outputs.loss, prog_bar=True, on_epoch=True)
        self.log("train/acc", train_acc, prog_bar=True, on_epoch=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        preds = torch.argmax(outputs.logits, 1)
        valid_acc = self.val_accuracy_metric(preds, labels)
        self.log("valid/loss", outputs.loss, prog_bar=True, on_step=True)
        # Log the accuracy too; it was computed but never logged.
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True)
        return {"labels": labels, "logits": outputs.logits}

    def validation_epoch_end(self, outputs):
        # Recompute per epoch from the results of validation_step
        labels = torch.cat([x["labels"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, 1)
        # Move to CPU first: .numpy() fails on GPU tensors
        data = confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())
        return data
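The aggregation done in validation_epoch_end can be tried without Lightning or torch. The sketch below uses plain lists in place of tensors and a hypothetical confusion_counts function in place of sklearn's confusion_matrix; the two batches of labels and logits are invented.

```python
def confusion_counts(step_outputs, num_classes=2):
    """Merge per-batch outputs into a confusion matrix (rows: actual, columns: predicted)."""
    labels, preds = [], []
    for out in step_outputs:
        labels.extend(out["labels"])
        # argmax over the class dimension, written out for plain lists
        preds.extend(max(range(len(row)), key=row.__getitem__) for row in out["logits"])
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(labels, preds):
        matrix[actual][predicted] += 1
    return matrix


# What validation_epoch_end would receive: one dictionary per validation batch.
step_outputs = [
    {"labels": [1, 0], "logits": [[0.2, 0.8], [0.6, 0.4]]},
    {"labels": [1, 1], "logits": [[0.7, 0.3], [0.1, 0.9]]},
]
matrix = confusion_counts(step_outputs)
print(matrix)  # [[1, 0], [1, 2]]
```

Each step only returns its own batch, so the epoch-level view exists only after the per-batch dictionaries are concatenated.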
WandB plotting
Plotting is possible in several ways, but my favorite is the matplotlib integration. You draw a plot with plt (or seaborn) as you normally would and hand it to wandb, and it shows up in the dashboard.
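import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import wandb
from sklearn.metrics import confusion_matrix

class ColaModel(pl.LightningModule):
    # __init__, training_step and validation_step as before
    def validation_epoch_end(self, outputs):
        # Recompute per epoch from the results of validation_step
        labels = torch.cat([x["labels"] for x in outputs]).cpu().numpy()
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, 1).cpu().numpy()
        data = confusion_matrix(labels, preds)
        # Use the union of labels and predictions so the axes match the matrix shape
        classes = np.unique(np.concatenate([labels, preds]))
        df_cm = pd.DataFrame(data, columns=classes, index=classes)
        df_cm.index.name = "Actual"
        df_cm.columns.name = "Predicted"
        fig = plt.figure(figsize=(7, 4))
        plot = sns.heatmap(df_cm, cmap="Blues", annot=True, annot_kws={"size": 16})  # font size
        self.logger.experiment.log({"Confusion Matrix": wandb.Image(plot)})
        plt.close(fig)  # release the figure so they don't pile up across epochs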
All you need to do is wrap the figure drawn with plt in wandb.Image and pass it to self.logger.experiment.log!! What an incredible advance in technology!!
Keep watching data samples
The thing I'm most curious about while training a model is which samples it handles successfully and which ones it gets wrong. You need to know where it fails and where it works in order to tell whether it is overfitting and to find ways to improve the model. There's a callback for this too!! (Truly amazing…)
If you log the samples as a wandb.Table in the on_validation_end method, they are much easier to inspect. Of course, you do have to write a bit of code!!
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb


class SamplesVisualisationLogger(pl.Callback):
    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule

    def on_validation_end(self, trainer, pl_module):
        # can be done on complete dataset also
        val_batch = next(iter(self.datamodule.val_dataloader()))
        sentences = val_batch["sentence"]

        # get the predictions (inputs are moved to the model's device first)
        outputs = pl_module(
            val_batch["input_ids"].to(pl_module.device),
            val_batch["attention_mask"].to(pl_module.device),
        )
        preds = torch.argmax(outputs.logits, 1).cpu()
        labels = val_batch["label"]

        # predicted and labelled data
        df = pd.DataFrame(
            {"Sentence": sentences, "Label": labels.numpy(), "Predicted": preds.numpy()}
        )
        # wrongly predicted data
        wrong_df = df[df["Label"] != df["Predicted"]]

        # Logging wrongly predicted dataframe as a table
        trainer.logger.experiment.log(
            {
                "examples": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),
                "global_step": trainer.global_step,
            }
        )
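The filtering step at the heart of this callback can be sketched without pandas or wandb. The function name misclassified_rows and the example sentences are made up for illustration; the output is shaped as a column list plus row data, which is one of the forms wandb.Table accepts.

```python
def misclassified_rows(sentences, labels, preds):
    """Return (columns, rows) for the samples whose prediction differs from the label."""
    columns = ["Sentence", "Label", "Predicted"]
    rows = [
        [sentence, label, pred]
        for sentence, label, pred in zip(sentences, labels, preds)
        if label != pred
    ]
    return columns, rows


columns, rows = misclassified_rows(
    ["The book was read.", "Book the read was."],  # invented sentences
    [1, 0],  # labels
    [1, 1],  # predictions
)
print(rows)
# With wandb installed, the same rows could be logged as:
# wandb.Table(columns=columns, data=rows)
```

For the callback to run at all, it has to be registered with the Trainer, for example by passing SamplesVisualisationLogger(datamodule) in the callbacks list, where datamodule is whatever data module you train with.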