Jae-Kyung Cho Being unique is better than being perfect

MLOps study - Raviraja Week 1: W&B

Week 1 covers how to use Weights & Biases (W&B, also written WandB). I first picked up this tool at an ICRA workshop and while collaborating with friends from Google Brain, and it is definitely convenient to use. Studying deep learning theory and writing papers is important, but I have come to think that quickly learning good tools and making them your own is important too. This post covers three main WandB features: logging, plotting, and watching data samples. Before getting to them, let's start with configuration.




WandB configuring

WandB works really well with PyTorch Lightning. The plain PyTorch integration is good too, but there you have to call the API and log manually, whereas with PyTorch Lightning you get many features just by declaring the logger once. The arguments you can pass when declaring it are as follows.

< Arguments >

  • name: Display name for the run.
  • save_dir: Path where data is saved.
  • offline: Run offline (data can be streamed later to wandb servers).
  • id: Sets the version, mainly used to resume a previous run.
  • version: Same as id.
  • anonymous: Enables or explicitly disables anonymous logging.
  • project: The name of the project to which this run will belong.
  • log_model: Log checkpoints created by pytorch_lightning.callbacks.ModelCheckpoint as W&B artifacts. The latest and best aliases are set automatically.
    • log_model='all': save all checkpoints
    • log_model=True: save last checkpoint

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Declare the logger once and hand it to the Trainer
wandb_logger = WandbLogger(project="MLOps Basics")

trainer = pl.Trainer(
        max_epochs=3,
        logger=wandb_logger,
        callbacks=[],
)




WandB logging

If you use torchmetrics, you can conveniently compute various metrics and send them to the wandb logger. Being able to log simply by calling self.log is a real advantage of PyTorch Lightning!!

When logging, you can use various options.

  • prog_bar=True: also shows the value in the progress bar
  • on_epoch=True: logs the value aggregated over the batches of an epoch (the mean by default)
  • on_step=True: logs the value for every batch; it's common to set this to False during validation
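
To make the difference between on_step and on_epoch concrete, here is a minimal sketch in plain Python. It does not call Lightning at all: simulate_log is a hypothetical helper, the loss values are made up, and it assumes equally sized batches so a plain mean is enough. The _step and _epoch suffixes follow the naming Lightning uses when both flags are set; the helper always adds them for simplicity.

```python
from statistics import fmean


def simulate_log(name, batch_values, on_step=True, on_epoch=True):
    """Hypothetical helper: mimic which values get logged during one epoch."""
    logged = {}
    if on_step:
        # One entry per batch
        logged[f"{name}_step"] = list(batch_values)
    if on_epoch:
        # A single entry: the mean over all batches of the epoch
        logged[f"{name}_epoch"] = fmean(batch_values)
    return logged


losses = [0.9, 0.7, 0.5, 0.3]  # made-up per-batch losses

# Training: keep both the per-batch curve and the epoch average
train_log = simulate_log("train/loss", losses, on_step=True, on_epoch=True)

# Validation: usually only the epoch average is of interest
valid_log = simulate_log("valid/loss", losses, on_step=False, on_epoch=True)

print(train_log)
print(valid_log)
```

With on_step=False the per-batch values are simply never recorded, which keeps validation charts to one point per epoch.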

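How often a value is logged depends on the LightningModule method it is logged from.

  • training_step: runs for each batch (self.log defaults to on_step=True here)
  • training_epoch_end: runs once per epoch
  • validation_step: runs for each batch (self.log defaults to on_epoch=True here)
  • validation_epoch_end: runs once per epoch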

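The epoch_end methods aggregate the results of each step at the epoch level, so each step must return its results as a dictionary. Note that these *_epoch_end hooks belong to PyTorch Lightning 1.x; version 2.0 removed them in favor of on_train_epoch_end and on_validation_epoch_end.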

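import torch
import torchmetrics
import pytorch_lightning as pl
from sklearn.metrics import confusion_matrix


class ColaModel(pl.LightningModule):
    def __init__(self, model_name="google/bert_uncased_L-2_H-128_A-2", lr=3e-5):
        super().__init__()
        # (model construction and forward() are omitted in this excerpt)
        # torchmetrics 0.11+ requires a task argument, e.g. Accuracy(task="binary")
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()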

    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        preds = torch.argmax(outputs.logits, 1)
        train_acc = self.train_accuracy_metric(preds, batch["label"])

        # You can access the wandb logger via self.log. No need to declare it separately.
        self.log("train/loss", outputs.loss, prog_bar=True, on_epoch=True)
        self.log("train/acc", train_acc, prog_bar=True, on_epoch=True)
        return outputs.loss
    
    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        outputs = self.forward(
            batch["input_ids"], batch["attention_mask"], labels=batch["label"]
        )
        preds = torch.argmax(outputs.logits, 1)
        valid_acc = self.val_accuracy_metric(preds, labels)

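        self.log("valid/loss", outputs.loss, prog_bar=True, on_step=True)
        # Log the accuracy computed above, averaged over the epoch
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True)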
        return {"labels": labels, "logits": outputs.logits}

    def validation_epoch_end(self, outputs):
        # Recomputing per epoch using the results of validation_step
        labels = torch.cat([x["labels"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, 1)

        # .cpu() keeps this working when validation runs on a GPU
        data = confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())
        return data
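
The aggregation that validation_epoch_end performs can be sketched without torch or scikit-learn. This is only an illustration under assumptions: the step outputs are made-up lists rather than tensors, each dictionary already holds predictions instead of logits, and aggregate_epoch and confusion_counts are hypothetical helper names.

```python
def aggregate_epoch(step_outputs):
    """Concatenate the per-batch dictionaries into epoch-level lists."""
    labels = [y for out in step_outputs for y in out["labels"]]
    preds = [p for out in step_outputs for p in out["preds"]]
    return labels, preds


def confusion_counts(labels, preds, num_classes=2):
    """Count (actual, predicted) pairs; rows are actual, columns predicted."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(labels, preds):
        matrix[actual][predicted] += 1
    return matrix


# Made-up return values of two validation_step calls
step_outputs = [
    {"labels": [0, 1, 1], "preds": [0, 1, 0]},
    {"labels": [1, 0], "preds": [1, 1]},
]

labels, preds = aggregate_epoch(step_outputs)
matrix = confusion_counts(labels, preds)
print(matrix)
```

The row and column order (actual by predicted) matches the convention of sklearn's confusion_matrix, which the method above relies on.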




WandB plotting

There are several ways to plot, but my favorite is the matplotlib integration. You build a plot with plt exactly as you normally would, hand it to wandb, and it gets visualized.

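# Imports this method relies on (in addition to torch)
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import wandb
from sklearn.metrics import confusion_matrix

# Method of ColaModel, shown on its own here
def validation_epoch_end(self, outputs):
    # Recompute epoch-level results from the validation_step outputs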
    labels = torch.cat([x["labels"] for x in outputs])
    logits = torch.cat([x["logits"] for x in outputs])
    preds = torch.argmax(logits, 1)

    # .cpu() keeps this working when validation runs on a GPU
    data = confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())
    df_cm = pd.DataFrame(data, columns=np.unique(labels), index=np.unique(labels))
    df_cm.index.name = "Actual"
    df_cm.columns.name = "Predicted"
    plt.figure(figsize=(7, 4))
    plot = sns.heatmap(
        df_cm, cmap="Blues", annot=True, annot_kws={"size": 16}
    )  # font size
    self.logger.experiment.log({"Confusion Matrix": wandb.Image(plot)})

All you need to do is wrap the figure drawn with plt in a wandb.Image and pass it to self.logger.experiment.log!! What an incredible advance in technology!




Keep watching data samples

What I'm most curious about while training a model is which samples it handles well and which samples it gets wrong. You need to know where it fails and where it works in order to spot overfitting and find ways to improve the model. A callback can log this too!! (Truly amazing…)

If you log the samples as a Table in the on_validation_end method, they become much easier to inspect. Of course, you do have to do a bit of coding!!

import pandas as pd
import pytorch_lightning as pl
import torch
import wandb


class SamplesVisualisationLogger(pl.Callback):
    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule

    def on_validation_end(self, trainer, pl_module):
        # can be done on complete dataset also
        val_batch = next(iter(self.datamodule.val_dataloader()))
        sentences = val_batch["sentence"]

        # get the predictions
        outputs = pl_module(val_batch["input_ids"], val_batch["attention_mask"])
        preds = torch.argmax(outputs.logits, 1)
        labels = val_batch["label"]

        # predicted and labelled data
        df = pd.DataFrame(
            {"Sentence": sentences, "Label": labels.numpy(), "Predicted": preds.numpy()}
        )

        # wrongly predicted data
        wrong_df = df[df["Label"] != df["Predicted"]]

        # Logging wrongly predicted dataframe as a table
        trainer.logger.experiment.log(
            {
                "examples": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),
                "global_step": trainer.global_step,
            }
        )
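
The filtering step of the callback can also be tried in isolation. The sketch below uses plain lists instead of a DataFrame and made-up sentences; misclassified_rows is a hypothetical helper. It produces the columns and rows that, with wandb installed, could be passed on as wandb.Table(columns=columns, data=data).

```python
def misclassified_rows(sentences, labels, preds):
    """Keep only the samples whose prediction differs from the label."""
    columns = ["Sentence", "Label", "Predicted"]
    data = [
        [sentence, label, pred]
        for sentence, label, pred in zip(sentences, labels, preds)
        if label != pred
    ]
    return columns, data


# Made-up validation batch
sentences = ["The book was read.", "Book the read was.", "She sings well."]
labels = [1, 0, 1]
preds = [1, 1, 0]

columns, data = misclassified_rows(sentences, labels, preds)
for row in data:
    print(dict(zip(columns, row)))
```

Logging only the wrong predictions keeps the table small, so the failure cases are the first thing you see.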
