Implement BERT for Sentiment Analysis in PyTorch
This walkthrough shows how to build a simplified BERT-style encoder from scratch in PyTorch for sentiment classification, using the Hugging Face Transformers library only for tokenization and the datasets library to load the IMDB data.
Required Libraries and Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from datasets import load_dataset
# Load the pretrained BERT tokenizer (used only for tokenization, not for model weights)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
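As a quick, optional sanity check (not part of the walkthrough itself), you can tokenize a short string and confirm that the first token id is the [CLS] token, since the classifier defined later reads position 0 of the sequence.

# Optional sanity check: position 0 should be the [CLS] token
sample = tokenizer("A quick test sentence.")
print(sample["input_ids"][:3])
print(sample["input_ids"][0] == tokenizer.cls_token_id)  # expected: True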
The Self-Attention Mechanism
The core of the Transformer architecture is the multi-head self-attention mechanism, which allows each token to attend to every other position in the input sequence.
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Single projection producing queries, keys, and values in one pass
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        # Project and split into per-head queries, keys, and values
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q = qkv[:, :, 0].transpose(1, 2)  # (B, num_heads, T, head_dim)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)
        # Scaled dot-product attention (no padding mask, a simplification)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        # Merge heads back into the embedding dimension
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)
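As a quick sanity check (an addition, not part of the original walkthrough), the attention layer can be run on a random batch to confirm it preserves the input shape; the dimensions below are purely illustrative.

# Illustrative shape check: 2 sequences, 16 tokens, 256-dim embeddings
attn_layer = SelfAttention(embed_dim=256, num_heads=8)
dummy = torch.randn(2, 16, 256)
print(attn_layer(dummy).shape)  # expected: torch.Size([2, 16, 256])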
Transformer Encoder Block
The encoder block consists of layer normalization, self-attention, and a feed-forward network, with residual connections around both sub-layers.
class EncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        # Pre-norm layout: LayerNorm before each sub-layer, residual added after
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = SelfAttention(embed_dim, num_heads)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),  # ReLU for simplicity; original BERT uses GELU
            nn.Linear(ff_dim, embed_dim)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # attention sub-layer with residual connection
        x = x + self.ff(self.ln2(x))    # feed-forward sub-layer with residual connection
        return x
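A similar optional check for the encoder block, again with illustrative dimensions: because both sub-layers are residual, the output shape should match the input shape.

# Illustrative check: the residual block keeps the tensor shape unchanged
block = EncoderBlock(embed_dim=256, num_heads=8, ff_dim=1024)
dummy = torch.randn(2, 16, 256)
print(block(dummy).shape)  # expected: torch.Size([2, 16, 256])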
BERT Model Architecture
The BERT-style model combines token embeddings, learned positional embeddings, and a stack of encoder blocks, then classifies from the final hidden state of the first ([CLS]) token.
class BERT(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, num_heads=8, num_layers=4,
                 ff_dim=1024, max_len=128, num_classes=2):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings, as in the original BERT
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        self.layers = nn.Sequential(*[
            EncoderBlock(embed_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])
        self.ln = nn.LayerNorm(embed_dim)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B, T = x.shape
        # Position indices 0..T-1, broadcast across the batch
        pos = torch.arange(0, T, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(pos)
        x = self.layers(x)
        x = self.ln(x)
        # Classify from the hidden state of the first token ([CLS])
        cls_token = x[:, 0, :]
        logits = self.classifier(cls_token)
        return logits
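Before moving on to data, a forward pass on random token ids (an illustrative addition, not part of the training pipeline) confirms the model emits one logit per class per example. The vocabulary size 30522 matches bert-base-uncased, but any positive value works for a pure shape check.

# Illustrative forward pass with random token ids
toy_model = BERT(vocab_size=30522)
dummy_ids = torch.randint(0, 30522, (2, 128))
print(toy_model(dummy_ids).shape)  # expected: torch.Size([2, 2]) -> (batch, num_classes)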
Data Preparation and Tokenization
We use the IMDB dataset and tokenize each review to a fixed length of 128 tokens to prepare it for training.
dataset = load_dataset("imdb")

def tokenize(example):
    # Pad or truncate every review to exactly 128 tokens
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

dataset = dataset.map(tokenize, batched=True)
# Return PyTorch tensors and keep only the columns the model needs
dataset.set_format(type="torch", columns=["input_ids", "label"])
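A brief, optional look at one processed example (assuming the map and set_format calls above have run) confirms each record is a fixed-length tensor of token ids plus an integer label.

# Optional inspection of a single processed training example
example = dataset["train"][0]
print(example["input_ids"].shape)  # expected: torch.Size([128])
print(example["label"])            # 0 = negative, 1 = positive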
Training and Evaluation
The model is trained with the AdamW optimizer on the training split and then used to predict the sentiment of a sample review.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BERT(vocab_size=tokenizer.vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

train_loader = torch.utils.data.DataLoader(
    dataset["train"],
    batch_size=16,
    shuffle=True
)

for epoch in range(2):
    model.train()
    total_loss = 0
    for batch in train_loader:
        x = batch["input_ids"].to(device)
        y = batch["label"].to(device)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
model.eval()
text = "This movie was amazing and very interesting!"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
input_ids = inputs["input_ids"].to(device)
with torch.no_grad():
    logits = model(input_ids)
    pred = torch.argmax(logits, dim=-1)
# IMDB label convention: 1 = positive, 0 = negative
print("Prediction:", "Positive" if pred.item() == 1 else "Negative")