from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import (
DataLoader,
RandomSampler,
SequentialSampler,
TensorDataset,
random_split,
)
"""
DATALOADER UTILITIES
"""
def get_bert_dataloader(df, tokenizer, max_seq_length=64, batch_size=16, mode="train"):
    """
    Build a PyTorch ``DataLoader`` over a BERT-formatted ``TensorDataset``.

    Args:
        df: Input data, passed unchanged to ``df_to_bert_dataset``
            (presumably a pandas DataFrame of texts/labels -- TODO confirm).
        tokenizer: Tokenizer passed through to ``df_to_bert_dataset``.
        max_seq_length: Maximum sequence length passed to
            ``df_to_bert_dataset``.
        batch_size: Number of samples per batch.
        mode: ``"train"`` uses a ``RandomSampler`` (shuffled each epoch);
            ``"validate"`` and ``"distill"`` use a ``SequentialSampler``
            (dataset order preserved).

    Returns:
        A ``torch.utils.data.DataLoader`` over the built dataset.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Check the mode before tokenizing: previously an unknown mode fell
    # through every branch and returned None after the expensive work.
    if mode not in ("train", "validate", "distill"):
        raise ValueError(
            f"mode must be one of 'train', 'validate', 'distill'; got {mode!r}"
        )
    dataset = df_to_bert_dataset(df, max_seq_length, tokenizer)
    # Only training shuffles; validation/distillation keep dataset order
    # (presumably so outputs line up with the input rows).
    if mode == "train":
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)
def df_to_bert_dataset(df, max_length, tokenizer):
    """
    Convert input data into a ``TensorDataset`` for BERT.

    Args:
        df: Input data, forwarded to ``df_to_bert_format``.
        max_length: Maximum sequence length, forwarded to
            ``df_to_bert_format``.
        tokenizer: Tokenizer, forwarded to ``df_to_bert_format``.

    Returns:
        ``TensorDataset(input_ids, attention_masks, labels)``. Each item is
        an ``(input_ids, attention_mask, label)`` tuple.
    """
    # NOTE(review): assumes df_to_bert_format returns three tensors sharing
    # the same first (sample) dimension; TensorDataset rejects mismatches.
    input_ids, attention_masks, labels = df_to_bert_format(df, max_length, tokenizer)
    dataset = TensorDataset(input_ids, attention_masks, labels)
    return dataset