Weak Supervision with Snorkel: Image Classification Example

Introduction

In the world of machine learning, data is often hailed as the crown jewel that powers models and drives innovation. Yet, obtaining high-quality, labeled data remains a significant challenge, often demanding painstaking manual efforts from human annotators. This is where the concept of weak supervision emerges as a beacon of hope for machine learning engineers and practitioners.

Weak supervision is the art of leveraging various sources of noisy or imprecise supervision to label a large amount of data efficiently. It lifts the burden of exhaustive manual labeling and opens the door to scaling up projects that would otherwise be resource-intensive. In this post, we embark on a journey to explore Snorkel, a powerful tool that lets us automate the labeling process, saving time and effort without compromising on results.

In this tutorial, tailored for machine learning engineers and enthusiasts alike, we’ll unveil the advantages of weak supervision using a practical example: Image Classification. By the end of this guide, you’ll have a basic understanding of how to harness the potential of Snorkel to streamline your image classification pipelines and achieve impressive results with reduced labeling efforts.

Whether you’re a seasoned practitioner seeking to optimize your workflow or a newcomer eager to unlock the potential of weak supervision, this tutorial will equip you with the knowledge and skills needed to elevate your machine learning projects. So, let’s dive into the world of weak supervision and see how Snorkel can revolutionize the way we approach labeling and, ultimately, supercharge our machine learning models.

Are you ready to embark on this exciting journey? Let’s begin!

Data Download: Exploring the Open Images Dataset V7

Before we dive into the exciting world of weak supervision and Snorkel for image classification, we need to set the stage by obtaining the necessary data. In this tutorial, we’ll be using the Open Images Dataset V7, a rich collection of images spanning a wide array of categories. This dataset is a treasure trove for machine learning tasks, providing a diverse range of visuals that will help us showcase the power of weak supervision.

To get started, we’ll perform a series of commands to download the essential files from the Open Images Dataset V7. These files contain crucial information about class labels, class descriptions, and annotations. Below is the code snippet you’ll need to execute to gather these files:

mkdir oiv7
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-classes-trainable.txt
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-class-descriptions.csv
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-train-annotations-human-imagelabels.csv
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-val-annotations-human-imagelabels.csv
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-test-annotations-human-imagelabels.csv

In this set of commands, we create a directory named oiv7 to neatly organize the downloaded files. The downloaded files include:

  1. oidv7-classes-trainable.txt: A list of trainable (verified) class labels.
  2. oidv7-class-descriptions.csv: A CSV file containing class descriptions.
  3. oidv7-train-annotations-human-imagelabels.csv: Annotations for training images.
  4. oidv7-val-annotations-human-imagelabels.csv: Annotations for validation images.
  5. oidv7-test-annotations-human-imagelabels.csv: Annotations for test images.
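
If you’d like to sanity-check these files before moving on, a quick peek at the class-descriptions CSV with pandas (assuming it landed in oiv7/ as above) shows the machine-label-to-display-name mapping that the rest of the code relies on:

import pandas as pd

# each row maps a machine label (LabelName) to a human-readable name (DisplayName)
classnames = pd.read_csv('oiv7/oidv7-class-descriptions.csv')
print(classnames.head())

# look up a class by its display name, e.g. 'Sea'
print(classnames[classnames['DisplayName'].str.lower() == 'sea'])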

Now that we have the essential data files in place, it’s time to turn our attention to the actual image files. In this section, we’ll walk through the process of downloading labeled images that are trainable according to the Open Images Dataset V7.

The provided Python code streamlines this image download process, ensuring that we only retrieve images that are relevant and fit for training.

import os
from concurrent import futures

import boto3
import botocore
import numpy as np
import pandas as pd
from tqdm import tqdm

BUCKET_NAME = 'open-images-dataset'
BUCKET = boto3.resource(
    's3', config=botocore.config.Config(
        signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)


def download_one_image(bucket, split: str, path: str, image_id: str, progress_bar: tqdm):
    """
    Download a single image from the specified split.

    :param bucket: S3 bucket resource.
    :param split: Dataset split ('train', 'validation', 'test').
    :param path: Path to store the downloaded image.
    :param image_id: Image ID.
    :param progress_bar: TQDM progress bar.
    """
    try:
        filename = os.path.join(path, f'{image_id}.jpg')
        if not os.path.isfile(filename):
            bucket.download_file(f'{split}/{image_id}.jpg', filename)
    except botocore.exceptions.ClientError as exception:
        # TODO: log
        # print(f'\nERROR when downloading image {split}/{image_id}: {str(exception)}\n')
        pass
    progress_bar.update(1)


def get_class_label(requested_classnames: list[str]) -> dict[str, str]:
    """
    Retrieve class labels for requested class names.

    :param requested_classnames: List of requested class names.
    :returns: Dictionary mapping class labels to class names.
    """
    classnames: pd.DataFrame = pd.read_csv('oiv7/oidv7-class-descriptions.csv')
    with open('oiv7/oidv7-classes-trainable.txt') as f:
        classes_trainable = set(line.strip() for line in f)
    requested_labels = dict()
    for cname in requested_classnames:
        ind = classnames.index[classnames['DisplayName'].str.lower() == cname.lower()].tolist()
        if not ind:
            raise ValueError(f"Class {cname} not found in the class descriptions!")
        label = classnames.loc[ind[0], 'LabelName']
        # check if the class is trainable
        if label not in classes_trainable:
            raise ValueError(f"Class {cname} is not trainable!")
        requested_labels[label] = cname.lower()
    return requested_labels


def get_image_ids(splits: list[str], requested_labels: dict[str, str], num_images: int) -> dict[str, list]:
    """
    Retrieve image IDs for requested splits and labels.

    :param splits: List of dataset splits.
    :param requested_labels: Dictionary mapping class labels to class names.
    :param num_images: Number of images to retrieve per class.
    :returns: Dictionary mapping splits to lists of image ids.
    """
    image_ids_per_split = dict()
    for split in splits:
        filename = f'oiv7/oidv7-{split}-annotations-human-imagelabels.csv'
        example_per_label = np.zeros(len(requested_labels))  # one counter per requested class
        image_ids = []
        chunk: pd.DataFrame
        # the annotation files are large, so read them in chunks
        with pd.read_csv(filename, chunksize=10 ** 6) as pd_reader:
            for chunk in pd_reader:
                if np.all(example_per_label > num_images):
                    break
                for i, label in enumerate(requested_labels):
                    if example_per_label[i] > num_images:
                        continue
                    # keep only human-verified positive labels (Confidence == 1.0)
                    verified_chunk = chunk[(chunk['LabelName'] == label) & (chunk['Confidence'] == 1.0)]
                    example_per_label[i] += len(verified_chunk)
                    image_ids.extend(verified_chunk.values.tolist())
        image_ids_per_split[split] = image_ids
    return image_ids_per_split


def download_images(requested_labels: dict[str, str], image_ids_dict: dict[str, list], path: str):
    """
    Download images based on requested labels and splits.

    :param requested_labels: Dictionary mapping class labels to class names.
    :param image_ids_dict: Dictionary mapping splits to lists of image ids.
    :param path: Path to store the downloaded images.
    """
    os.makedirs(path, exist_ok=True)
    for split, image_ids in image_ids_dict.items():
        if split == 'val':  # the S3 bucket names this split 'validation'
            split = 'validation'
        os.makedirs(f'{path}/{split}', exist_ok=True)
        for _, classname in requested_labels.items():
            os.makedirs(f'{path}/{split}/{classname}', exist_ok=True)
        progress_bar = tqdm(total=len(image_ids), desc=f'Downloading {split} images', leave=True)
        with futures.ThreadPoolExecutor(5) as executor:
            all_futures = [
                executor.submit(download_one_image,
                                BUCKET, split,
                                f'{path}/{split}/{requested_labels[image_property[2]]}',
                                image_property[0], progress_bar)
                for image_property in image_ids
            ]
            for future in futures.as_completed(all_futures):
                future.result()
        progress_bar.close()


if __name__ == '__main__':
    req_lab = get_class_label(['sea', 'Jungle'])
    img_ids = get_image_ids(['train', 'val', 'test'], req_lab, 10000)
    download_images(requested_labels=req_lab, image_ids_dict=img_ids, path='dataset')

Let’s break down the key components of the code:

  1. download_one_image: This function downloads a single image from the specified split (train, validation, or test) and saves it to the given path. The function uses the BUCKET resource from the boto3 library to interact with the S3 bucket.

  2. get_class_label: This function retrieves class labels for the requested class names. It ensures that the requested classes are trainable, as per the dataset specifications.

  3. get_image_ids: This function retrieves image IDs for the requested splits and labels. It identifies images that match the requested class labels and have a confidence of 1.0 (i.e., verified by a human annotator).

  4. download_images: This function orchestrates the image download process based on requested labels, splits, and paths. It uses concurrent futures to speed up the download process by using multiple threads.

By combining these components, the code gives us a way to download labeled images from the Open Images Dataset V7. We use it to fetch images for the classes “sea” and “Jungle,” which will serve as our starting point for weak supervision, demonstrating how we can leverage Snorkel to automatically label and train an image classifier.
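
Once the script finishes, the images should land in a class-per-folder layout like the sketch below (shown for the default path='dataset' from the __main__ block; the actual image IDs will differ):

dataset/
├── train/
│   ├── sea/          # <image_id>.jpg files
│   └── jungle/
├── validation/
│   ├── sea/
│   └── jungle/
└── test/
    ├── sea/
    └── jungle/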

Organizing and Splitting the Dataset

In this section, we’ll walk through a code snippet that performs data splitting and organization. This step is pivotal in setting the stage for robust model training and evaluation.

"""
This module performs data splitting and organizes images into different splits based on classes.
"""
import os
import shutil
from collections import defaultdict
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
def stat(root_path: str) -> dict:
"""
Analyze and compute statistics on the data by counting images for each class.
:param root_path: Root directory containing the images.
:returns: Dictionary mapping class names to lists of image paths.
"""
root = Path(root_path)
all_images = list(root.glob('**/*.jpg'))
data = defaultdict(lambda: [])
for image_path in all_images:
_, _, class_name, _ = str(image_path).split('/')
data[class_name].append(image_path)
print('Data Stat:')
for class_name in data:
print(f'{class_name}: {len(data[class_name])}')
return data
def split_manual(data: dict[str, list], save_dir: str, splits: list[float] | None = None,
random_seed: int = 56) -> None:
"""
Split the data into different subsets (train, val, test) and organize images accordingly.
:param data: Dictionary mapping class names to lists of image paths.
:param save_dir: Directory to save the split data.
:param splits: List of ratios for train, val, and test splits.
:param random_seed: Random seed for reproducibility.
"""
x = []
y = []
X = dict()
Y = dict()
for class_name, image_list in data.items():
x.extend(image_list)
y.extend([class_name] * len(image_list))
if splits is None:
splits = [0.7, 0.2, 0.1]
splits = np.divide(splits, np.sum(splits))
x_train_val, X['test'], y_train_val, Y['test'] = train_test_split(x, y,
test_size=splits[2],
random_state=random_seed,
stratify=y)
X['train'], X['val'], Y['train'], Y['val'] = train_test_split(x_train_val, y_train_val,
test_size=splits[1] / (splits[1] + splits[0]),
random_state=random_seed,
stratify=y_train_val)
os.makedirs(save_dir, exist_ok=True)
save_path = Path(save_dir)
for split in X:
os.makedirs(save_path / split, exist_ok=True)
for class_name in data:
os.makedirs(save_path / split / class_name, exist_ok=True)
for img_path, class_name in tqdm(zip(X[split], Y[split]), total=len(X[split]), desc=split):
shutil.copyfile(img_path, save_path / split / class_name / Path(img_path).name)
if __name__ == '__main__':
data = stat('dataset')
split_manual(data, save_dir='data_split_manual')

Let’s delve into the key components of the code:

  1. stat: This function computes statistics on the data by counting the number of images for each class. It returns a dictionary mapping class names to lists of image paths.

  2. split_manual: This function splits the data into different subsets (train, val, test) based on specified ratios. Stratified sampling ensures that each subset maintains a proportional representation of the classes. We split the data into train, val, and test sets with ratios of 0.7, 0.2, and 0.1, respectively. A quick verification sketch follows below.
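
As a sanity check that the stratified split worked, you can count the images per class in each split; this is a small sketch assuming the default save_dir='data_split_manual':

from pathlib import Path

root = Path('data_split_manual')
for split_dir in sorted(root.iterdir()):
    # count jpg files per class folder, e.g. train/sea and train/jungle
    counts = {c.name: len(list(c.glob('*.jpg'))) for c in sorted(split_dir.iterdir()) if c.is_dir()}
    print(split_dir.name, counts)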

Labeling Functions

The heart of Snorkel lies in creating labeling functions that generate noisy labels for our data. In this section, we’ll walk through a script that defines a set of labeling functions, each contributing to the creation of our weakly labeled dataset.

"""
This script defines labeling functions using various techniques for weak supervision with Snorkel.
"""
import os
import cv2
import fasttext.util
import numpy as np
import torch
from scipy.stats import mode
from snorkel.labeling import labeling_function
from torchvision.io import read_image
from torchvision.io.image import ImageReadMode
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from collections.abc import Callable
JUNGLE = 0
SEA = 1
ABSTAIN = -1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
model.eval()
preprocess = EfficientNet_B0_Weights.DEFAULT.transforms(antialias=True)
fasttext.util.download_model('en', if_exists='ignore')
ft = fasttext.load_model('cc.en.300.bin')
v_sea = []
for word in ['sea', 'ocean', 'boat', 'fish', 'beach', 'blue']:
v_sea.append(ft.get_word_vector(word))
v_sea = np.array(v_sea)
v_jungle = []
for word in ['forest', 'jungle', 'wood', 'bush', 'green']:
v_jungle.append(ft.get_word_vector(word))
v_jungle = np.array(v_jungle)
LABELING_FUNCS = []
def add_func(f: Callable):
"""
Decorator to add a labeling function to the list.
:param f: The labeling function to add.
:returns: The input function.
"""
LABELING_FUNCS.append(f)
return f
@add_func
@labeling_function()
def check_color(x: str | os.PathLike) -> int:
"""
Labeling function to classify images based on dominant color.
:param x: Path to the image file.
:returns: Label indicating SEA, JUNGLE, or ABSTAIN.
"""
# read image into BGR format
img = cv2.imread(str(x))
if len(img.shape) < 3:
return -1
color_max = np.argmax(img.mean(axis=(0, 1)))
match color_max:
case 0:
return SEA
case 1:
return JUNGLE
case 2:
return ABSTAIN
@add_func
@labeling_function()
def check_pixel_color(x: str | os.PathLike) -> int:
"""
Labeling function to classify images based on mode of pixel colors.
:param x: Path to the image file.
:returns: Label indicating SEA, JUNGLE, or ABSTAIN.
"""
img = cv2.imread(str(x))
if len(img.shape) < 3:
return -1
pixel_max_color = mode(np.argmax(img, axis=2), axis=None).mode
match pixel_max_color:
case 0:
return SEA
case 1:
return JUNGLE
case 2:
return ABSTAIN
@add_func
@labeling_function()
def check_with_efficientNet(x: str | os.PathLike) -> int:
"""
Labeling function to classify images using EfficientNet predictions and FastText embeddings.
:param x: Path to the image file.
:returns: Label indicating SEA or JUNGLE.
"""
img = read_image(str(x), mode=ImageReadMode.RGB)
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
category_name = EfficientNet_B0_Weights.DEFAULT.meta["categories"][class_id]
v = ft.get_sentence_vector(category_name)
dist_sea = min(np.sum(np.power(v_sea - v, 2), axis=1))
dist_jungle = min(np.sum(np.power(v_jungle - v, 2), axis=1))
return SEA if dist_sea < dist_jungle else JUNGLE

Here’s an overview of the key components of the provided code:

  1. check_color: This function classifies an image by which color channel (blue, green, or red) has the highest image-wide mean: blue votes SEA, green votes JUNGLE, and red abstains.

  2. check_pixel_color: This function classifies an image by the most frequent per-pixel dominant color channel, using the same blue/green/red voting scheme.

  3. check_with_efficientNet: This function leverages EfficientNet predictions and FastText embeddings to classify images as “SEA” or “JUNGLE.” First, we classify the image with an ImageNet-pretrained EfficientNet. Then, the FastText embedding of the predicted label is compared against embeddings of several words related to sea or jungle. If the label’s meaning is closer to the sea words, we label the image SEA; otherwise, JUNGLE.

  4. Adding labeling functions: The script uses the add_func decorator to add each labeling function to the LABELING_FUNCS list, which will be used later.

By combining these labeling functions, we will generate a set of noisy labels for our images. These labels are the cornerstone of our weak supervision approach, allowing us to utilize Snorkel’s capabilities.
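
Before wiring these into Snorkel, it helps to spot-check a labeling function on a single file. Snorkel’s LabelingFunction objects are callable on one data point, so a sketch like the following prints each function’s raw vote (it assumes the manual split from the previous section already exists):

from pathlib import Path

from labeling_funcs import LABELING_FUNCS

# grab any downloaded training image to spot-check
sample = next(Path('data_split_manual/train/sea').glob('*.jpg'))
for lf in LABELING_FUNCS:
    # each call returns SEA (1), JUNGLE (0), or ABSTAIN (-1)
    print(f'{lf.name}: {lf(sample)}')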

Weak Supervision Labeling with Snorkel

The power of weak supervision comes to life when we leverage labeling functions to create noisy labels for our dataset. Now, we’ll explore a script that performs weak supervision labeling with Snorkel.

"""
This script performs weak supervision labeling using labeling functions with Snorkel.
"""
import os
import pickle
from pathlib import Path
from snorkel.labeling import LFApplier, LFAnalysis
from snorkel.labeling.model import LabelModel
from labeling_funcs import LABELING_FUNCS
def weak_supervision_labeling(root: str = 'data_split_manual', splits=None, root_save: str = 'data_snorkel'):
"""
Perform weak supervision labeling using labeling functions with snorkel and save label data to pickle files.
:param root: Root directory containing the input images (root/split/class_name).
:param splits: List of data splits to process (e.g. ['train', 'val']).
:param root_save: Root directory to save the labeled data.
"""
if splits is None:
splits = ['train', 'val']
root_new = Path(root_save)
os.makedirs(root_new, exist_ok=True)
for split in splits:
root_split = Path(root / split)
all_images = list(root_split.glob('**/*.jpg'))
print(f'{len(all_images)} images found in {split}')
applier = LFApplier(LABELING_FUNCS)
L_train = applier.apply(all_images)
print(LFAnalysis(L=L_train, lfs=LABELING_FUNCS).lf_summary())
label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=123)
label_split = label_model.predict_proba(L_train)
os.makedirs(root_new / split, exist_ok=True)
with open(root_new / split / 'data.pkl', 'wb') as f:
pickle.dump({"labels": label_split, "images": all_images}, f)
if __name__ == '__main__':
weak_supervision_labeling()

Here’s a breakdown of the key elements in the provided code:

  1. Data preparation: The script starts by specifying the root directory containing the input images and the splits (e.g., ‘train’, ‘val’) to process. Additionally, it defines the root directory to save the labeled data.

  2. LFApplier: The labeling functions (LABELING_FUNCS) defined in the previous code script are applied to all images in the specified split. The result is a label matrix (L_train) where each row corresponds to an image and each column corresponds to a labeling function.

  3. LFAnalysis: This step provides an analysis of the labeling functions’ performance on the data. It generates a summary that indicates how well the labeling functions agree or disagree on assigning labels to images.

  4. label_model: A LabelModel is trained using the label matrix (L_train). This model learns to estimate the true underlying labels by accounting for the noise introduced by the labeling functions.

  5. Label Prediction: The label model predicts probabilities of labels for each image based on the noisy labels from the labeling functions.

  6. Saving Labeled Data: The labeled data, including the predicted labels and image paths, is saved to pickle files. This data will serve as the input for our model training process.

By executing this script, we perform the crucial step of labeling our data using weak supervision techniques. Snorkel helps us manage the uncertainty introduced by the labeling functions, creating a labeled dataset that reflects the inherent noise in the weakly supervised data.
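
If you want hard labels straight away, Snorkel also ships a small helper, probs_to_preds, for collapsing the probabilistic output. A short sketch for inspecting the saved labels (assuming the default data_snorkel output path):

import pickle

import numpy as np
from snorkel.utils import probs_to_preds

with open('data_snorkel/train/data.pkl', 'rb') as f:
    data = pickle.load(f)

probs = data['labels']  # shape (n_images, 2); columns follow JUNGLE = 0, SEA = 1
preds = probs_to_preds(probs)  # most likely class per image
print('class balance:', np.bincount(preds) / len(preds))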

Dataloaders

Next, we implement our dataloaders for the supervised and the weakly supervised procedures.

"""
Data module
This module contains utility functions and classes for training a Convolutional Neural Network (CNN) using weak
supervision techniques. It includes functions to create data loaders for different dataset splits, as well as a custom
dataset class for weakly supervised learning with Snorkel labels.
Classes:
- SnorkelDataset: Custom dataset class for weakly supervised learning using Snorkel labels.
Functions:
- get_transforms: Get data transformation pipelines for different dataset splits.
- get_data_loader: Get data loaders for the specified dataset splits.
- get_data_loader_snorkel: Get data loaders for weakly supervised learning using Snorkel labels.
Note:
This module assumes that the data is organized into different splits (e.g., 'train', 'val', 'test')
within the root directory.
"""
import os
import pickle
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
def get_data_loader(root: str, splits: list['str']) -> dict[str, DataLoader]:
"""
Creates and returns data loaders for different dataset splits.
:param root: Root directory containing the dataset.
:param splits: List of dataset splits (e.g., ['train', 'val', 'test']).
:return: Dataloaders, a dictionary containing data loaders for each split.
"""
dataloaders = dict()
transforms = get_transforms()
for split in splits:
dataset = ImageFolder(os.path.join(root, split), transform=transforms[split])
dataloaders[split] = DataLoader(dataset=dataset,
shuffle=True if split == 'train' else 0,
num_workers=6,
batch_size=128,
pin_memory=True)
return dataloaders
def get_data_loader_snorkel(root: str, splits: list['str'], label_type: str) -> dict[str, DataLoader]:
"""
Get data loaders for the specified dataset splits for snorkel generated dataset.
:param root: Root directory of the dataset that contains "data.pkl" files in each split.
:param splits: List of dataset splits (e.g., ['train', 'val', 'test']).
:param label_type: Type of labels, either 'hard' or 'soft'.
:return: Dictionary of data loaders for each split.
"""
dataloaders = dict()
transforms = get_transforms()
for split in splits:
dataset = SnorkelDataset(os.path.join(root, split, 'data.pkl'),
transforms=transforms[split],
label_type=label_type)
dataloaders[split] = DataLoader(dataset=dataset,
shuffle=True if split == 'train' else 0,
num_workers=6,
batch_size=64,
pin_memory=True)
return dataloaders
class SnorkelDataset(Dataset):
"""
Custom dataset for weakly supervised learning using Snorkel labels.
"""
def __init__(self, data_path: str, label_type: str, transforms: transforms.Compose | None = None):
"""
:param data_path: Path to the pickled data file.
:param label_type: Type of labels, either 'hard' or 'soft'.
:param transforms: Data transformations to apply to images.
"""
with open(data_path, 'rb') as f:
data = pickle.load(f)
self.labels = data['labels']
if label_type.lower() == 'hard':
self.labels = np.argmax(self.labels, axis=1)
self.images = data["images"]
self.transforms = transforms
def __len__(self) -> int:
"""
Get the length of the dataset.
:return: Number of images in the dataset.
"""
return len(self.images)
def __getitem__(self, idx: int):
"""
Get an item from the dataset.
:param idx: Index of the item to retrieve.
:return: Tuple containing the image and its corresponding label.
"""
image = Image.open(str(self.images[idx])).convert("RGB")
if transforms:
image = self.transforms(image)
return image, self.labels[idx]
def get_transforms() -> dict[str, transforms.Compose]:
"""
Get data transformation pipelines for different dataset splits.
:return: Dictionary of data transformations for 'train', 'val', and 'test' splits.
"""
data_transforms = {
'train': transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'test': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
return data_transforms

Here’s a breakdown of the key elements in the provided code:

  1. get_transforms(): This function provides data transformation pipelines tailored for different dataset splits: ‘train’, ‘val’, and ‘test’. These transformations include resizing, cropping, flipping, rotation, normalization, and tensor conversion.

  2. SnorkelDataset: This custom dataset class is designed for weakly supervised learning using Snorkel labels. It takes a path to a pickled data file and a label type (‘hard’ or ‘soft’) as inputs. In weakly supervised learning, “hard labels” refer to discrete, definite labels assigned to data points, indicating clear categorization (e.g., ‘SEA’ or ‘JUNGLE’). On the other hand, “soft labels” represent probabilistic or continuous assignments, reflecting the uncertainty or ambiguity in classification. The class loads images and corresponding labels from the data file and applies the specified transformations.

  3. get_data_loader(): This function creates and returns data loaders for different dataset splits. It utilizes the ImageFolder dataset from PyTorch, which organizes data into class folders. The dataloaders are configured with appropriate transformations and batch sizes for training, validation, and testing.

  4. get_data_loader_snorkel(): This function generates data loaders for the specified dataset splits using Snorkel-generated labels. It utilizes the SnorkelDataset class to load images and labels from pickled data files, enabling weakly supervised learning. The dataloaders are configured similarly to those in get_data_loader(), tailored for Snorkel-labeled data.

By leveraging these utility functions and classes, we ensure that our data is well-prepared and ready to be fed into our CNN model.
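
As a small usage sketch (assuming the default paths from the earlier steps), the three flavors of training loaders can be created side by side, and the label shapes make the hard/soft distinction visible:

from data import get_data_loader, get_data_loader_snorkel

# ground-truth labels from the manual split
supervised = get_data_loader('data_split_manual', splits=['train', 'val'])
# Snorkel labels collapsed to one class index per image
snorkel_hard = get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='hard')
# Snorkel labels kept as per-class probabilities
snorkel_soft = get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='soft')

images, labels = next(iter(snorkel_soft['train']))
print(images.shape, labels.shape)  # soft labels come out as (batch, 2) probability vectors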

Training function

Here is the training function. I skip a detailed description of this part, since it follows a common pattern of model training with PyTorch.

"""
This script defines the training function for our CNN.
"""
import copy
import os
import time
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import MulticlassPrecision
from tqdm import tqdm
def train(
model: nn.Module,
dataloaders: dict,
optimizer,
criterion,
scheduler,
device,
run_name: str,
writer: SummaryWriter | None = None,
num_epochs: int = 10):
"""
Train the CNN model.
:param model: The neural network model to be trained.
:param dataloaders: Dictionary containing dataloaders for training and validation.
:param optimizer: The optimizer for updating model parameters.
:param criterion: Loss criterion for training.
:param scheduler: Learning rate scheduler.
:param device: Device for computation ('cpu' or 'cuda').
:param run_name: Name of the run for saving model checkpoints.
:param writer: SummaryWriter for logging training progress (optional).
:param num_epochs: Number of training epochs.
:return: Best trained model.
"""
os.makedirs('./weights', exist_ok=True)
metric = MulticlassPrecision(num_classes=2, average=None)
since = time.time()
dataset_sizes = {phase: len(dl.dataset) for phase, dl in dataloaders.items()}
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = 1_000_000.
best_acc = 0
model.to(device)
for epoch in range(num_epochs):
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode
running_loss = 0.0
running_corrects = 0
running_samples = 0
running_pred_vec = []
running_target_vec = []
# Iterate over data.
pbar = tqdm(dataloaders[phase], desc=f'Epoch {epoch:3}/{num_epochs:3} - {phase:6}', unit=' batch')
for inputs, labels in pbar:
inputs = inputs.to(device)
# labels = labels.to(device)
running_samples += inputs.size(0)
# zero the parameter gradients
optimizer.zero_grad()
# forward
# track history if only in train
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
outputs = outputs.cpu()
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
# backward + optimize only if in training phase
if phase == 'train':
loss.backward()
optimizer.step()
# statistics
if len(labels.shape) > 1:
_, labels = torch.max(labels, 1)
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels)
running_pred_vec.extend(preds.numpy().tolist())
running_target_vec.extend(labels.numpy().tolist())
metric.update(torch.tensor(running_pred_vec), torch.tensor(running_target_vec))
m = metric.compute()
pbar.set_postfix(loss=running_loss / running_samples,
accuracy=running_corrects.item() / running_samples * 100,
Precision=[round(m[0].item(), 3), round(m[1].item(), 3)])
if phase == 'val':
scheduler.step(running_loss)
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
if writer is not None:
writer.add_scalar('LOSS/{}'.format(phase), epoch_loss, epoch)
writer.add_scalar('ACC/{}'.format(phase), epoch_acc, epoch)
writer.add_scalar('OPTIM/LR', scheduler.get_lr()[-1], epoch)
# deep copy the model
if phase == 'val' and epoch_loss < best_loss:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict())
torch.save({
'epoch': epoch,
'model_state_dict': best_model_wts,
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'accuracy': best_acc
}, f'weights/{run_name}_best_model.pt')
pbar.close()
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best val accuracy: {best_acc * 100}')
model.load_state_dict(best_model_wts)
return model

NOTE: The condition if len(labels.shape) > 1: deals with soft labels: when the targets are probability vectors, we collapse them to the most likely class before computing the accuracy and precision statistics (the loss itself consumes the soft targets directly).
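
On the loss side, no special handling is needed: since PyTorch 1.10, nn.CrossEntropyLoss accepts class probabilities as targets, so hard and soft labels flow through the same criterion. A minimal illustration:

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 2)                 # a batch of raw model outputs
hard_targets = torch.tensor([0, 1, 1, 0])  # class indices (hard labels)
soft_targets = torch.tensor([[0.9, 0.1],
                             [0.2, 0.8],
                             [0.4, 0.6],
                             [0.7, 0.3]])  # class probabilities (soft labels)
print(criterion(logits, hard_targets))     # both calls are valid
print(criterion(logits, soft_targets))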

Inference

With the following code, we can evaluate all weights saved in the weight_folder on the test data.

"""
This script tests trained CNN models using various metrics on the test dataset.
We assume that there is a test split in the data root
"""
import os
from pathlib import Path
import torch
from torcheval.metrics import MulticlassPrecision, BinaryAccuracy
from tqdm import tqdm
from data import get_data_loader
from model import CNNModel
metrics = [MulticlassPrecision(num_classes=2, average=None),
BinaryAccuracy()]
def test_model(weight_folder: str, data_root: str, device='cpu') -> None:
"""
Test the trained CNN models using specified metrics.
:param weight_folder: Path to the folder containing model weight files (eg. original_best_model.pt).
:param data_root: Root directory of the dataset.
:param device: Device to run the evaluation on (default is 'cpu').
"""
model_wights = [os.path.join(weight_folder, x) for x in os.listdir(weight_folder) if x.endswith('_best_model.pt')]
results = []
print(f'Found {len(model_wights)} Models.')
for model_weight in model_wights:
model = CNNModel()
model.load_state_dict(torch.load(model_weight)['model_state_dict'])
model.eval()
model.to(device)
dataloaders = get_data_loader(root=data_root, splits=['test'])
predictions = []
labels = []
pbar = tqdm(dataloaders['test'], desc=f'{" ".join(Path(model_weight).name.split("_")[:-2]):15}', unit=' batch')
for inputs, label in pbar:
inputs = inputs.to(device)
# labels = labels.to(device)
with torch.no_grad():
outputs = model(inputs)
outputs = outputs.cpu()
_, preds = torch.max(outputs, 1)
predictions.extend(preds.numpy().tolist())
labels.extend(label.numpy().tolist())
for metric in metrics:
metric.update(torch.tensor(predictions), torch.tensor(labels))
results.append(f'{" ".join(Path(model_weight).name.split("_")[:-2]):15} '
f'Precision = {metrics[0].compute().numpy()} '
f'ACC = {metrics[1].compute():.4f}')
for res in results:
print(res)
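
One note before the results: the scripts import CNNModel from model.py, which isn’t listed in the post. A minimal sketch of what such a model could look like, assuming 224x224 RGB inputs and two output classes (the architecture in the repository may differ):

import torch
from torch import nn

class CNNModel(nn.Module):
    """Small CNN for two-class classification of 224x224 RGB images (illustrative sketch)."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(x.flatten(1))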

Results

Now let’s see the results. With the following script, we can train the model with the original supervised labels or with the two weakly supervised variants (soft and hard Snorkel labels).

"""
This script provides functions to train and test a CNN model with/without the Snorkel framework.
"""
import torch
from data import get_data_loader_snorkel, get_data_loader
from inference import test_model
from model import CNNModel
from train import train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def Train():
"""
Train a CNN model with/without the Snorkel framework.
"""
model = CNNModel()
dataloaders = {
'original': get_data_loader('data_split_manual', splits=['train', 'val']),
'snorkel_hard': get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='hard'),
'snorkel_soft': get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='soft')
}
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=5)
for run_name, dataloader in dataloaders.items():
train(
model=model, dataloaders=dataloader, optimizer=optimizer, criterion=criterion, scheduler=lr_scheduler,
device=device,
run_name=run_name,
num_epochs=20
)
def Test():
"""
Test a trained CNN model using the test dataset.
"""
test_model(weight_folder='weights', data_root='data_split_manual', device=device)
if __name__ == '__main__':
Train()
Test()

The results are:

original Precision = [0.94285715 0.921875 ] ACC = 0.9259
snorkel soft Precision = [0.93 0.90949225] ACC = 0.9132
snorkel hard Precision = [0.88235295 0.9226687 ] ACC = 0.9144

Original Labels:

  • Precision: For the original labels, the model achieves high precision scores for both classes (‘JUNGLE’ and ‘SEA’), indicating that when it predicts a class, it’s usually correct. The two entries follow the class index order (jungle = 0, sea = 1), so the precision values of approximately 0.94 and 0.92 for ‘JUNGLE’ and ‘SEA’ respectively demonstrate the model’s accuracy in its predictions.
  • Accuracy (ACC): The overall accuracy of 0.926 suggests that the model is successful in correctly classifying approximately 92.6% of the images in the dataset.

Snorkel Soft Labels:

  • Precision: With soft labels, where the labeling functions provide probabilistic or continuous assignments, the precision scores remain relatively high but show a slight decrease compared to the original labels. The values of around 0.93 for ‘JUNGLE’ and 0.909 for ‘SEA’ indicate a minor decrease in precision.
  • Accuracy (ACC): The accuracy of 0.9132, while slightly lower than the original labels, still demonstrates a strong performance, capturing approximately 91.3% of the dataset correctly.

Snorkel Hard Labels:

  • Precision: When using hard labels (discrete, definite labels) derived from Snorkel, there is a more noticeable decrease in precision for the ‘JUNGLE’ class, dropping to approximately 0.882. However, the precision for ‘SEA’ remains high at around 0.923.
  • Accuracy (ACC): The overall accuracy of 0.9144, although slightly lower than the original labels, showcases the model’s ability to maintain a strong classification performance with Snorkel hard labels.

Conclusion

In the pursuit of harnessing the power of weak supervision, our journey has traversed a landscape where precision meets ambiguity and accuracy coexists with uncertainty. The application of labeling functions in the Snorkel framework has enabled us to encode our prior knowledge about the data, offering a nuanced perspective on image classification.

We also observed an advantage of using soft labels: their ability to mitigate the impact of class imbalances. In our dataset, where the ‘SEA’ and ‘JUNGLE’ classes had differing numbers of instances, soft labels allowed for a more nuanced representation of uncertainty, which kept the precision for the two classes relatively close compared to the more discrete hard labels. In an imbalanced dataset, the imprecision introduced by labeling functions might disproportionately affect the minority class. Soft labels, by representing class assignments probabilistically, give the model the flexibility to balance precision between the classes more effectively. This balancing act is crucial, especially in applications where misclassifying the minority class carries significant consequences.

It’s important to note that while our labeling functions in this example were relatively straightforward to implement, many real-world scenarios pose complex challenges. Designing accurate labeling functions can be intricate, requiring domain expertise and careful consideration. Despite these challenges, our study demonstrates that noisy labels, when harnessed intelligently through weak supervision techniques, can still offer valuable insights and contribute to robust model training.

Our exploration underscores the resilience of machine learning models in the face of noisy or uncertain labels. Even when labeling functions are not perfect, the intelligent integration of these noisy annotations can lead to significant advancements in model performance. Embracing the inherent noise in weakly supervised data and leveraging techniques like Snorkel not only expands the scope of feasible applications but also highlights the adaptability and learning potential of modern machine learning systems.

GitHub

You can find the code at the project GitHub repository.
