Weak Supervision with Snorkel: Image Classification Example
Introduction:
In the world of machine learning, data is often hailed as the crown jewel that powers models and drives innovation. Yet, obtaining high-quality, labeled data remains a significant challenge, often demanding painstaking manual efforts from human annotators. This is where the concept of weak supervision emerges as a beacon of hope for machine learning engineers and practitioners.
Weak supervision is the art of leveraging various sources of noisy or imprecise supervision to label a large amount of data efficiently. It takes the burden off exhaustive manual labeling and opens the door to scaling up projects that might otherwise have been too resource-intensive. In this post, we explore Snorkel, a powerful tool that lets us automate much of the labeling process. Instead of annotating every example by hand, we write labeling functions and let Snorkel combine their noisy votes into training labels, which saves considerable time and effort.
In this tutorial, tailored for machine learning engineers and enthusiasts alike, we’ll unveil the advantages of weak supervision using a practical example: Image Classification. By the end of this guide, you’ll have a basic understanding of how to harness the potential of Snorkel to streamline your image classification pipelines and achieve impressive results with reduced labeling efforts.
Whether you’re a seasoned practitioner seeking to optimize your workflow or a newcomer eager to unlock the potential of weak supervision, this tutorial will equip you with the knowledge and skills needed to elevate your machine learning projects. So, let’s dive into the world of weak supervision and see how Snorkel can revolutionize the way we approach labeling and, ultimately, supercharge our machine learning models.
Are you ready to embark on this exciting journey? Let’s begin!
Data Download: Exploring the Open Images Dataset V7
Before we dive into the exciting world of weak supervision and Snorkel for image classification, we need to set the stage by obtaining the necessary data. In this tutorial, we’ll be using the Open Images Dataset V7, a rich collection of images spanning a wide array of categories. This dataset is a treasure trove for machine learning tasks, providing a diverse range of visuals that will help us showcase the power of weak supervision.
To get started, we’ll perform a series of commands to download the essential files from the Open Images Dataset V7. These files contain crucial information about class labels, class descriptions, and annotations. Below is the code snippet you’ll need to execute to gather these files:
```shell
mkdir oiv7
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-classes-trainable.txt
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-class-descriptions.csv
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-train-annotations-human-imagelabels.csv
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-val-annotations-human-imagelabels.csv
wget -P oiv7/ https://storage.googleapis.com/openimages/v7/oidv7-test-annotations-human-imagelabels.csv
```
In this set of commands, we create a directory named oiv7 to neatly organize the downloaded files. The downloaded files include:
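- oidv7-classes-trainable.txt: The label IDs of the trainable classes, i.e. the classes with enough human-verified annotations to be used for training.
- oidv7-class-descriptions.csv: A CSV file that maps each label ID (LabelName) to its human-readable name (DisplayName).
- oidv7-train-annotations-human-imagelabels.csv: Human-verified image-level labels for the training images. Each row holds an ImageID, a Source, a LabelName, and a Confidence, where 1 marks a positive label and 0 a negative one.
- oidv7-val-annotations-human-imagelabels.csv: Human-verified image-level labels for the validation images, in the same format.
- oidv7-test-annotations-human-imagelabels.csv: Human-verified image-level labels for the test images, in the same format.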
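If wget is not available (for example, on Windows), Python's standard library can fetch the same five files. Below is a minimal sketch. It assumes only that the URLs above remain valid. The function names metadata_urls and download_metadata are my own and do not come from any library:

```python
import os
import urllib.request

# Base URL and file names copied from the wget commands above.
BASE_URL = "https://storage.googleapis.com/openimages/v7/"
FILES = [
    "oidv7-classes-trainable.txt",
    "oidv7-class-descriptions.csv",
    "oidv7-train-annotations-human-imagelabels.csv",
    "oidv7-val-annotations-human-imagelabels.csv",
    "oidv7-test-annotations-human-imagelabels.csv",
]


def metadata_urls() -> list[str]:
    """Return the full download URL of every metadata file."""
    return [BASE_URL + name for name in FILES]


def download_metadata(target_dir: str = "oiv7") -> None:
    """Download each metadata file into target_dir, skipping files that already exist."""
    os.makedirs(target_dir, exist_ok=True)
    for url in metadata_urls():
        destination = os.path.join(target_dir, url.rsplit("/", 1)[-1])
        if not os.path.exists(destination):
            urllib.request.urlretrieve(url, destination)
```

Calling download_metadata() reproduces the wget commands. Unlike a plain wget -P, it skips files that already exist instead of saving numbered copies such as file.1.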
Now that we have the essential data files in place, it’s time to turn our attention to the actual image files. In this section, we’ll walk through the process of downloading labeled images that are trainable according to the Open Images Dataset V7.
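The Python code below streamlines the image download process. It fetches images anonymously from the public open-images-dataset S3 bucket. It only retrieves images that carry a human-verified positive label for one of the requested trainable classes.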
```python
import os
from concurrent import futures

import boto3
import botocore.config
import botocore.exceptions
import numpy as np
import pandas as pd
from tqdm import tqdm

BUCKET_NAME = 'open-images-dataset'
REGEX = r'(test|train|validation|challenge2018)/([a-fA-F0-9]*)'
# Anonymous (unsigned) access to the public Open Images bucket
BUCKET = boto3.resource(
    's3', config=botocore.config.Config(
        signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)


def download_one_image(bucket, split: str, path: str, image_id: str, progress_bar: tqdm):
    """
    Download a single image from the specified split.
    :param bucket: S3 bucket resource.
    :param split: Split folder in the bucket ('train', 'validation', 'test').
    :param path: Path to store the downloaded image.
    :param image_id: Image ID.
    :param progress_bar: TQDM progress bar.
    """
    try:
        filename = os.path.join(path, f'{image_id}.jpg')
        # skip images that were already downloaded
        if not os.path.isfile(filename):
            bucket.download_file(f'{split}/{image_id}.jpg', filename)
    except botocore.exceptions.ClientError as exception:
        # report the failure without breaking the progress bar, then continue
        tqdm.write(f'ERROR when downloading image {split}/{image_id}: {exception}')
    progress_bar.update(1)


def get_class_label(requested_classnames: list[str]) -> dict[str, str]:
    """
    Retrieve class labels for requested class names.
    :param requested_classnames: List of requested class names.
    :returns: Dictionary mapping class labels to class names.
    """
    classnames = pd.read_csv('oiv7/oidv7-class-descriptions.csv')
    with open('oiv7/oidv7-classes-trainable.txt') as f:
        classes_trainable = set(line.strip() for line in f)
    requested_labels = dict()
    for cname in requested_classnames:
        matches = classnames[classnames['DisplayName'].str.lower() == cname.lower()]
        if matches.empty:
            raise ValueError(f'Class {cname} not found in the class descriptions!')
        label = matches.iloc[0, 0]  # the first column holds the label ID
        # check if the class is trainable
        if label not in classes_trainable:
            raise ValueError(f'Class {cname} is not trainable!')
        requested_labels[label] = cname.lower()
    return requested_labels


def get_image_ids(splits: list[str], requested_labels: dict[str, str], num_images: int) -> dict[str, list]:
    """
    Retrieve image IDs for requested splits and labels.
    :param splits: List of dataset splits.
    :param requested_labels: Dictionary mapping class labels to class names.
    :param num_images: Approximate number of images to retrieve per class.
    :returns: Dictionary mapping splits to lists of annotation rows (ImageID, Source, LabelName, Confidence).
    """
    image_ids_per_split = dict()
    for split in splits:
        filename = f'oiv7/oidv7-{split}-annotations-human-imagelabels.csv'
        example_per_label = np.zeros(len(requested_labels))
        image_ids = []
        with pd.read_csv(filename, chunksize=10 ** 6) as pd_reader:
            for chunk in pd_reader:
                # stop reading once every class has enough examples
                if np.all(example_per_label > num_images):
                    break
                for i, label in enumerate(requested_labels):
                    if example_per_label[i] > num_images:
                        continue
                    # human-verified positives only; `&` binds tighter than `==`, hence the parentheses
                    verified_chunk = chunk[(chunk['LabelName'] == label) & (chunk['Confidence'] == 1.0)]
                    example_per_label[i] += len(verified_chunk)
                    image_ids.extend(verified_chunk.values.tolist())
        image_ids_per_split[split] = image_ids
    return image_ids_per_split


def download_images(requested_labels: dict[str, str], image_ids_dict: dict[str, list], path: str):
    """
    Download images based on requested labels and splits.
    :param requested_labels: Dictionary mapping class labels to class names.
    :param image_ids_dict: Dictionary mapping splits to lists of annotation rows.
    :param path: Path to store the downloaded images.
    """
    os.makedirs(path, exist_ok=True)
    for split, image_ids in image_ids_dict.items():
        # the bucket names its validation folder 'validation'
        if split == 'val':
            split = 'validation'
        os.makedirs(f'{path}/{split}', exist_ok=True)
        for classname in requested_labels.values():
            os.makedirs(f'{path}/{split}/{classname}', exist_ok=True)
        progress_bar = tqdm(total=len(image_ids), desc=f'Downloading {split} images', leave=True)
        with futures.ThreadPoolExecutor(5) as executor:
            # image_property[0] is the ImageID, image_property[2] the LabelName
            all_futures = [
                executor.submit(download_one_image,
                                BUCKET, split,
                                f'{path}/{split}/{requested_labels[image_property[2]]}',
                                image_property[0], progress_bar)
                for image_property in image_ids
            ]
            for future in futures.as_completed(all_futures):
                future.result()
        progress_bar.close()


if __name__ == '__main__':
    req_lab = get_class_label(['sea', 'Jungle'])
    img_ids = get_image_ids(['train', 'val', 'test'], req_lab, 10000)
    download_images(requested_labels=req_lab, image_ids_dict=img_ids, path='dataset')
```
Let’s break down the key components of the code:
download_one_image: This function downloads a single image from the specified split (train, validation, or test, as the folders are named in the S3 bucket) and saves it to the given path. It skips images that already exist locally. The bucket is passed in as the BUCKET resource, which is created with boto3 for unsigned (anonymous) access.
get_class_label: This function maps the requested display names (e.g. “sea”) to Open Images label IDs. It raises an error if a requested class is not in the list of trainable classes.
get_image_ids: This function reads the annotation CSV of each split in chunks of one million rows. It collects the rows whose LabelName matches a requested label and whose Confidence is 1 (a human-verified positive), and it stops reading once every class has more than num_images examples.
download_images: This function creates one folder per split and class. It then downloads the selected images in parallel with a pool of five threads, calling download_one_image for each image.
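The Confidence filter in get_image_ids is easy to get wrong. In Python, & binds more tightly than ==, so each comparison needs its own parentheses. The sketch below runs the same filter on a few made-up rows that follow the column layout of the image-level label CSVs (ImageID, Source, LabelName, Confidence). The IDs, source strings, and label names are placeholders, not real Open Images values, and verified_positives is a hypothetical helper:

```python
import pandas as pd

# Toy rows in the column order of the image-level label CSVs; all values are made up.
annotations = pd.DataFrame({
    "ImageID": ["img1", "img2", "img3", "img4"],
    "Source": ["verification"] * 4,
    "LabelName": ["/m/sea", "/m/sea", "/m/jungle", "/m/jungle"],
    "Confidence": [1.0, 0.0, 1.0, 1.0],
})


def verified_positives(df: pd.DataFrame, label: str) -> pd.DataFrame:
    """Rows with the given label and Confidence 1 (a human-verified positive)."""
    # Each comparison is parenthesized because `&` binds tighter than `==`.
    mask = (df["LabelName"] == label) & (df["Confidence"] == 1.0)
    return df[mask]


sea_rows = verified_positives(annotations, "/m/sea")
print(sea_rows["ImageID"].tolist())  # ['img1']
first = sea_rows.values.tolist()[0]  # ['img1', 'verification', '/m/sea', 1.0]
```

The positional layout of first also explains why download_images reads image_property[0] as the image ID and image_property[2] as the label.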
By combining these components, the code provides a way to download labeled images from the Open Images Dataset V7. Next, we’ll use the code to download labeled images, specifically focusing on the classes “sea” and “Jungle.” These images will serve as our starting point for weak supervision, demonstrating how we can leverage Snorkel to automatically label and train an image classifier.
Organizing and Splitting the Dataset
In this section, we’ll walk through a code snippet that splits and organizes the data. It pools all downloaded images, regardless of which Open Images split they came from. It then re-splits them per class into train, val, and test subsets for training and evaluation.
```python
"""
This module performs data splitting and organizes images into different splits based on classes.
"""
import os
import shutil
from collections import defaultdict
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm


def stat(root_path: str) -> dict:
    """
    Analyze and compute statistics on the data by counting images for each class.
    :param root_path: Root directory containing the images.
    :returns: Dictionary mapping class names to lists of image paths.
    """
    root = Path(root_path)
    all_images = list(root.glob('**/*.jpg'))
    data = defaultdict(list)
    for image_path in all_images:
        # images are stored as root/split/class_name/image.jpg
        class_name = image_path.parent.name
        data[class_name].append(image_path)
    print('Data Stat:')
    for class_name in data:
        print(f'{class_name}: {len(data[class_name])}')
    return data


def split_manual(data: dict[str, list], save_dir: str, splits: list[float] | None = None,
                 random_seed: int = 56) -> None:
    """
    Split the data into different subsets (train, val, test) and organize images accordingly.
    :param data: Dictionary mapping class names to lists of image paths.
    :param save_dir: Directory to save the split data.
    :param splits: List of ratios for train, val, and test splits.
    :param random_seed: Random seed for reproducibility.
    """
    x = []
    y = []
    X = dict()
    Y = dict()
    for class_name, image_list in data.items():
        x.extend(image_list)
        y.extend([class_name] * len(image_list))
    if splits is None:
        splits = [0.7, 0.2, 0.1]
    splits = np.divide(splits, np.sum(splits))
    # first hold out the test set, then split the remainder into train and val
    x_train_val, X['test'], y_train_val, Y['test'] = train_test_split(x, y,
                                                                      test_size=splits[2],
                                                                      random_state=random_seed,
                                                                      stratify=y)
    X['train'], X['val'], Y['train'], Y['val'] = train_test_split(x_train_val, y_train_val,
                                                                  test_size=splits[1] / (splits[1] + splits[0]),
                                                                  random_state=random_seed,
                                                                  stratify=y_train_val)
    os.makedirs(save_dir, exist_ok=True)
    save_path = Path(save_dir)
    for split in X:
        os.makedirs(save_path / split, exist_ok=True)
        for class_name in data:
            os.makedirs(save_path / split / class_name, exist_ok=True)
        for img_path, class_name in tqdm(zip(X[split], Y[split]), total=len(X[split]), desc=split):
            shutil.copyfile(img_path, save_path / split / class_name / Path(img_path).name)


if __name__ == '__main__':
    data = stat('dataset')
    split_manual(data, save_dir='data_split_manual')
```
Let’s delve into the key components of the code:
stat: This function computes statistics on the data by counting the number of images for each class. It returns a dictionary mapping class names to lists of image paths.
split_manual: This function splits the data into train, val, and test subsets with scikit-learn’s train_test_split. It passes stratify so that each subset keeps the same class proportions as the full dataset. By default, the split ratios are 0.7, 0.2, and 0.1, and the images are copied into save_dir/split/class_name.
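split_manual calls train_test_split twice. The second test_size therefore has to be expressed as a fraction of what remains after the test set is removed. The hypothetical helper below is not part of the script. It reproduces that arithmetic for the default ratios:

```python
def two_stage_test_sizes(train: float, val: float, test: float) -> tuple[float, float]:
    """Return the test_size values for the two train_test_split calls in split_manual."""
    total = train + val + test
    # same normalization as np.divide(splits, np.sum(splits))
    train, val, test = train / total, val / total, test / total
    first = test                  # share of all images held out as the test set
    second = val / (val + train)  # share of the remaining train+val pool used for validation
    return first, second


first, second = two_stage_test_sizes(0.7, 0.2, 0.1)
val_share = (1 - first) * second  # validation share of the whole dataset
print(round(first, 3), round(second, 3), round(val_share, 3))  # 0.1 0.222 0.2
```

Holding out 10% first and then 2/9 of the remaining 90% leaves exactly 20% of all images for validation.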
Labeling Functions
The heart of Snorkel lies in creating labeling functions that generate noisy labels for our data. In this section, we’ll go through a script that defines a set of labeling functions, each contributing to the creation of our weakly labeled dataset. Save it as labeling_funcs.py, since the next script imports LABELING_FUNCS from that module.
```python
"""
This script defines labeling functions using various techniques for weak supervision with Snorkel.
"""
import os
from collections.abc import Callable

import cv2
import fasttext.util
import numpy as np
import torch
from scipy.stats import mode
from snorkel.labeling import labeling_function
from torchvision.io import read_image
from torchvision.io.image import ImageReadMode
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

JUNGLE = 0
SEA = 1
ABSTAIN = -1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT).to(device)
model.eval()
preprocess = EfficientNet_B0_Weights.DEFAULT.transforms(antialias=True)

# English FastText vectors (cc.en.300.bin), downloaded on first use
fasttext.util.download_model('en', if_exists='ignore')
ft = fasttext.load_model('cc.en.300.bin')
# keyword embeddings for each class
v_sea = np.array([ft.get_word_vector(word) for word in ['sea', 'ocean', 'boat', 'fish', 'beach', 'blue']])
v_jungle = np.array([ft.get_word_vector(word) for word in ['forest', 'jungle', 'wood', 'bush', 'green']])

LABELING_FUNCS = []


def add_func(f: Callable):
    """
    Decorator to add a labeling function to the list.
    :param f: The labeling function to add.
    :returns: The input function.
    """
    LABELING_FUNCS.append(f)
    return f


@add_func
@labeling_function()
def check_color(x: str | os.PathLike) -> int:
    """
    Labeling function to classify images based on dominant color.
    :param x: Path to the image file.
    :returns: Label indicating SEA, JUNGLE, or ABSTAIN.
    """
    # read image into BGR format; imread returns None if the file cannot be decoded
    img = cv2.imread(str(x))
    if img is None or img.ndim < 3:
        return ABSTAIN
    # channel with the highest mean intensity: 0 = blue, 1 = green, 2 = red
    color_max = np.argmax(img.mean(axis=(0, 1)))
    match color_max:
        case 0:
            return SEA
        case 1:
            return JUNGLE
        case _:
            return ABSTAIN


@add_func
@labeling_function()
def check_pixel_color(x: str | os.PathLike) -> int:
    """
    Labeling function to classify images based on mode of pixel colors.
    :param x: Path to the image file.
    :returns: Label indicating SEA, JUNGLE, or ABSTAIN.
    """
    img = cv2.imread(str(x))
    if img is None or img.ndim < 3:
        return ABSTAIN
    # most frequent dominant channel over all pixels
    pixel_max_color = mode(np.argmax(img, axis=2), axis=None).mode
    match pixel_max_color:
        case 0:
            return SEA
        case 1:
            return JUNGLE
        case _:
            return ABSTAIN


@add_func
@labeling_function()
def check_with_efficientNet(x: str | os.PathLike) -> int:
    """
    Labeling function to classify images using EfficientNet predictions and FastText embeddings.
    :param x: Path to the image file.
    :returns: Label indicating SEA or JUNGLE.
    """
    img = read_image(str(x), mode=ImageReadMode.RGB)
    batch = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    category_name = EfficientNet_B0_Weights.DEFAULT.meta["categories"][class_id]
    # compare the predicted category name with the keywords in FastText space
    v = ft.get_sentence_vector(category_name)
    dist_sea = min(np.sum(np.power(v_sea - v, 2), axis=1))
    dist_jungle = min(np.sum(np.power(v_jungle - v, 2), axis=1))
    return SEA if dist_sea < dist_jungle else JUNGLE
```
Here’s an overview of the key components of the provided code:
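check_color: This function computes the mean of each color channel over the whole image and labels the image by the strongest channel. Blue means SEA, green means JUNGLE, and red means ABSTAIN. OpenCV loads images in BGR order, so channel 0 is blue.
check_pixel_color: This function first finds the dominant channel of every pixel. It then takes the most frequent one (the mode) and applies the same mapping as check_color.
check_with_efficientNet: This function combines EfficientNet and FastText to label an image as “SEA” or “JUNGLE.” First, an EfficientNet-B0 pretrained on ImageNet classifies the image. Then, the FastText embedding of the predicted ImageNet category name is compared with the embeddings of several sea-related and jungle-related words. If the category name is closer to some sea word than to every jungle word, the image is labeled SEA; otherwise it is labeled JUNGLE. Unlike the color functions, it never abstains.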
Adding labeling functions: The script uses the add_func decorator to add each labeling function to the LABELING_FUNCS list, which will be used later.
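The two color heuristics can be tried without any image files. The sketch below reimplements their decision rules on in-memory NumPy arrays in BGR channel order, which is the order cv2.imread returns. To count the per-pixel dominant channel, it swaps scipy.stats.mode for np.bincount(...).argmax(). The function names are illustrative only:

```python
import numpy as np

JUNGLE, SEA, ABSTAIN = 0, 1, -1  # same values as in labeling_funcs.py


def label_from_channel(channel: int) -> int:
    """Map a BGR channel index to a label: blue -> SEA, green -> JUNGLE, red -> ABSTAIN."""
    return {0: SEA, 1: JUNGLE}.get(int(channel), ABSTAIN)


def mean_color_label(img: np.ndarray) -> int:
    """Rule of check_color: the channel with the highest mean intensity decides."""
    return label_from_channel(np.argmax(img.mean(axis=(0, 1))))


def pixel_mode_label(img: np.ndarray) -> int:
    """Rule of check_pixel_color: the most frequent per-pixel dominant channel decides."""
    dominant = np.argmax(img, axis=2).ravel()  # dominant channel of every pixel
    return label_from_channel(np.bincount(dominant, minlength=3).argmax())


# A 1x4 BGR image: three mildly green pixels and one saturated blue pixel.
mixed = np.array([[[50, 60, 0], [50, 60, 0], [50, 60, 0], [255, 0, 0]]], dtype=np.uint8)
print(mean_color_label(mixed))  # 1 (SEA): blue has the highest mean
print(pixel_mode_label(mixed))  # 0 (JUNGLE): most pixels are green-dominant
```

The mixed example shows why keeping both functions can be useful. A few very bright blue pixels can dominate the channel means even when most pixels are green.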
By combining these labeling functions, we generate a set of noisy labels for every image. These labels are the cornerstone of our weak supervision approach. Snorkel’s label model resolves their overlaps and conflicts into a single probabilistic label per image.
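The final step of check_with_efficientNet is a nearest-neighbour decision in embedding space. The sketch below isolates it from EfficientNet and FastText. The 3-dimensional vectors are made up for illustration, whereas the script uses 300-dimensional FastText vectors:

```python
import numpy as np

JUNGLE, SEA = 0, 1  # same values as in labeling_funcs.py


def nearest_keyword_label(v: np.ndarray, v_sea: np.ndarray, v_jungle: np.ndarray) -> int:
    """Return SEA if the closest keyword vector (squared Euclidean distance) is a sea word, else JUNGLE."""
    dist_sea = np.min(np.sum((v_sea - v) ** 2, axis=1))        # distance to the nearest sea keyword
    dist_jungle = np.min(np.sum((v_jungle - v) ** 2, axis=1))  # distance to the nearest jungle keyword
    return SEA if dist_sea < dist_jungle else JUNGLE


# Made-up 3-d "embeddings", one row per keyword.
v_sea = np.array([[1.0, 0.0, 0.0], [0.9, 0.1, 0.0]])
v_jungle = np.array([[0.0, 1.0, 0.0], [0.0, 0.8, 0.2]])

print(nearest_keyword_label(np.array([0.8, 0.2, 0.0]), v_sea, v_jungle))  # 1 (SEA)
print(nearest_keyword_label(np.array([0.1, 0.9, 0.0]), v_sea, v_jungle))  # 0 (JUNGLE)
```

Because the comparison is a strict <, an exact tie in distance resolves to JUNGLE.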
Weak Supervision Labeling with Snorkel
The power of weak supervision comes to life when we leverage labeling functions to create noisy labels for our dataset. Now, we’ll explore a script that performs weak supervision labeling with Snorkel.
```python
"""
This script performs weak supervision labeling using labeling functions with Snorkel.
"""
import os
import pickle
from pathlib import Path

from snorkel.labeling import LFApplier, LFAnalysis
from snorkel.labeling.model import LabelModel

from labeling_funcs import LABELING_FUNCS


def weak_supervision_labeling(root: str = 'data_split_manual', splits: list[str] | None = None,
                              root_save: str = 'data_snorkel'):
    """
    Perform weak supervision labeling using labeling functions with snorkel and save label data to pickle files.
    :param root: Root directory containing the input images (root/split/class_name).
    :param splits: List of data splits to process (e.g. ['train', 'val']).
    :param root_save: Root directory to save the labeled data.
    """
    if splits is None:
        splits = ['train', 'val']
    root_new = Path(root_save)
    os.makedirs(root_new, exist_ok=True)
    for split in splits:
        root_split = Path(root) / split
        all_images = list(root_split.glob('**/*.jpg'))
        print(f'{len(all_images)} images found in {split}')
        # label matrix: one row per image, one column per labeling function
        applier = LFApplier(LABELING_FUNCS)
        L_train = applier.apply(all_images)
        print(LFAnalysis(L=L_train, lfs=LABELING_FUNCS).lf_summary())
        # combine the noisy votes into probabilistic labels
        label_model = LabelModel(cardinality=2, verbose=True)
        label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=123)
        label_split = label_model.predict_proba(L_train)
        os.makedirs(root_new / split, exist_ok=True)
        with open(root_new / split / 'data.pkl', 'wb') as f:
            pickle.dump({"labels": label_split, "images": all_images}, f)


if __name__ == '__main__':
    weak_supervision_labeling()
```
Here’s a breakdown of the key elements in the provided code:
Data Preparation: The function takes three arguments: the root directory of the split images (data_split_manual by default), the splits to process (train and val by default), and the directory where the labeled data is saved (data_snorkel by default). The test split is not relabeled; it keeps its original labels for evaluation.
LFApplier: The labeling functions (LABELING_FUNCS) from the previous script, saved as labeling_funcs.py, are applied to all images in the split. The result is a label matrix (L_train) in which each row corresponds to an image and each column to a labeling function. Each entry is JUNGLE (0), SEA (1), or ABSTAIN (-1).
LFAnalysis: lf_summary reports several statistics for each labeling function. These are its polarity (the labels it emits), its coverage (the fraction of images on which it does not abstain), and how often it overlaps and conflicts with the other labeling functions.
label_model: A LabelModel with cardinality 2 is fitted on the label matrix for 500 epochs, separately for each split. It never sees ground-truth labels. Instead, it estimates how accurate each labeling function is from their agreements and disagreements.
Label Prediction: predict_proba returns the probability of each class for every image, with one column per label (JUNGLE, then SEA).
Saving Labeled Data: The predicted probabilities and the image paths are saved to data_snorkel/split/data.pkl. These files serve as the input for training on the weak labels.
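To get a feel for the label matrix L_train, here is a toy matrix with the same layout: one row per image, one column per labeling function, and -1 for abstain. The sketch runs a plain majority vote over it. This is a NumPy illustration of the majority-vote baseline, not Snorkel's code. Snorkel provides its own MajorityLabelVoter in snorkel.labeling.model. The LabelModel used above goes further by weighting each labeling function according to accuracies it estimates from their agreements and disagreements:

```python
import numpy as np

ABSTAIN = -1


def majority_vote(L: np.ndarray, cardinality: int = 2) -> np.ndarray:
    """Majority vote over a label matrix L (rows = images, columns = labeling functions).

    Abstentions are ignored; rows where every LF abstains or the top classes tie stay ABSTAIN.
    """
    preds = np.full(L.shape[0], ABSTAIN)
    for i, row in enumerate(L):
        votes = np.bincount(row[row != ABSTAIN], minlength=cardinality)  # votes per class
        top = np.flatnonzero(votes == votes.max())
        if votes.max() > 0 and len(top) == 1:
            preds[i] = top[0]
    return preds


# Columns: check_color, check_pixel_color, check_with_efficientNet (JUNGLE = 0, SEA = 1)
L = np.array([
    [1, 1, 1],    # all three LFs vote SEA
    [-1, 0, 0],   # one abstains, two vote JUNGLE
    [1, 0, -1],   # one SEA, one JUNGLE: a tie
    [-1, -1, 0],  # only one LF votes
])
print(majority_vote(L).tolist())  # [1, 0, -1, 0]
```

The third row shows the main weakness of majority voting. Conflicting votes cannot be resolved without knowing which labeling function is more reliable, and that is exactly what the label model estimates.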
By executing this script, we perform the crucial step of labeling our data using weak supervision techniques. Snorkel helps us manage the uncertainty introduced by the labeling functions, creating a labeled dataset that reflects the inherent noise in the weakly supervised data.
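One detail is worth checking. The Snorkel labels are integers fixed by the constants in labeling_funcs.py (JUNGLE = 0, SEA = 1). The supervised data loaders and the test script, however, use ImageFolder, which numbers classes by sorted subfolder name. The sketch below reproduces that sorting rule with the standard library to confirm that the two numberings agree for the folders jungle and sea. The helper folder_class_to_idx is hypothetical. With torchvision installed, ImageFolder(path).class_to_idx shows the same mapping directly:

```python
import os
import tempfile

JUNGLE, SEA = 0, 1  # constants from labeling_funcs.py


def folder_class_to_idx(root: str) -> dict[str, int]:
    """Number subfolders in sorted name order, the rule ImageFolder uses for class indices."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    for name in ["sea", "jungle"]:  # creation order does not matter
        os.makedirs(os.path.join(root, name))
    mapping = folder_class_to_idx(root)

print(mapping)  # {'jungle': 0, 'sea': 1}
```

If the classes were renamed (for example, to “ocean” and “forest”), the sorted order would flip. The constants in labeling_funcs.py would then have to be swapped to match.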
Dataloaders
Next, we implement our dataloaders for the supervised and the weakly supervised procedures.
```python
"""
Data module
This module contains utility functions and classes for training a Convolutional Neural Network (CNN) using weak
supervision techniques. It includes functions to create data loaders for different dataset splits, as well as a custom
dataset class for weakly supervised learning with Snorkel labels.
Classes:
    - SnorkelDataset: Custom dataset class for weakly supervised learning using Snorkel labels.
Functions:
    - get_transforms: Get data transformation pipelines for different dataset splits.
    - get_data_loader: Get data loaders for the specified dataset splits.
    - get_data_loader_snorkel: Get data loaders for weakly supervised learning using Snorkel labels.
Note:
    This module assumes that the data is organized into different splits (e.g., 'train', 'val', 'test')
    within the root directory.
"""
import os
import pickle

import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder


def get_data_loader(root: str, splits: list[str]) -> dict[str, DataLoader]:
    """
    Creates and returns data loaders for different dataset splits.
    :param root: Root directory containing the dataset.
    :param splits: List of dataset splits (e.g., ['train', 'val', 'test']).
    :return: Dataloaders, a dictionary containing data loaders for each split.
    """
    dataloaders = dict()
    data_transforms = get_transforms()
    for split in splits:
        dataset = ImageFolder(os.path.join(root, split), transform=data_transforms[split])
        dataloaders[split] = DataLoader(dataset=dataset,
                                        shuffle=(split == 'train'),
                                        num_workers=6,
                                        batch_size=128,
                                        pin_memory=True)
    return dataloaders


def get_data_loader_snorkel(root: str, splits: list[str], label_type: str) -> dict[str, DataLoader]:
    """
    Get data loaders for the specified dataset splits for snorkel generated dataset.
    :param root: Root directory of the dataset that contains "data.pkl" files in each split.
    :param splits: List of dataset splits (e.g., ['train', 'val', 'test']).
    :param label_type: Type of labels, either 'hard' or 'soft'.
    :return: Dictionary of data loaders for each split.
    """
    dataloaders = dict()
    data_transforms = get_transforms()
    for split in splits:
        dataset = SnorkelDataset(os.path.join(root, split, 'data.pkl'),
                                 transforms=data_transforms[split],
                                 label_type=label_type)
        dataloaders[split] = DataLoader(dataset=dataset,
                                        shuffle=(split == 'train'),
                                        num_workers=6,
                                        batch_size=64,
                                        pin_memory=True)
    return dataloaders


class SnorkelDataset(Dataset):
    """
    Custom dataset for weakly supervised learning using Snorkel labels.
    """

    def __init__(self, data_path: str, label_type: str, transforms: transforms.Compose | None = None):
        """
        :param data_path: Path to the pickled data file.
        :param label_type: Type of labels, either 'hard' or 'soft'.
        :param transforms: Data transformations to apply to images.
        """
        with open(data_path, 'rb') as f:
            data = pickle.load(f)
        # float32 class probabilities (columns: JUNGLE, SEA), matching the model's output dtype
        self.labels = np.asarray(data['labels'], dtype=np.float32)
        if label_type.lower() == 'hard':
            self.labels = np.argmax(self.labels, axis=1)  # most probable class index
        self.images = data['images']
        self.transforms = transforms

    def __len__(self) -> int:
        """
        Get the length of the dataset.
        :return: Number of images in the dataset.
        """
        return len(self.images)

    def __getitem__(self, idx: int):
        """
        Get an item from the dataset.
        :param idx: Index of the item to retrieve.
        :return: Tuple containing the image and its corresponding label.
        """
        image = Image.open(str(self.images[idx])).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, self.labels[idx]


def get_transforms() -> dict[str, transforms.Compose]:
    """
    Get data transformation pipelines for different dataset splits.
    :return: Dictionary of data transformations for 'train', 'val', and 'test' splits.
    """
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    return data_transforms
```
Here’s a breakdown of the key elements in the provided code:
get_transforms(): This function provides data transformation pipelines for the ‘train’, ‘val’, and ‘test’ splits. All pipelines resize images so the shorter side is 256 pixels, crop them to 224x224, convert them to tensors, and normalize them with the ImageNet mean and standard deviation. For augmentation, the training pipeline uses random crops, horizontal flips, and small rotations of up to 5 degrees. The val and test pipelines use a deterministic center crop.
SnorkelDataset: This custom dataset class is designed for weakly supervised learning using Snorkel labels. It takes a path to a pickled data file and a label type (‘hard’ or ‘soft’) as inputs. In weakly supervised learning, “hard labels” are discrete labels that assign each data point to exactly one class (e.g. ‘SEA’ or ‘JUNGLE’); here they are obtained by taking the argmax of the label model’s probabilities. “Soft labels” keep the full probability vector, which reflects the uncertainty of the label model. The class loads images and labels from the data file and applies the given transformations.
get_data_loader(): This function creates data loaders for the requested dataset splits. It uses torchvision’s ImageFolder dataset, which expects one subfolder per class and assigns class indices in sorted folder-name order. The loaders use a batch size of 128 and six worker processes, and they shuffle only the training split.
get_data_loader_snorkel(): This function creates data loaders for the requested splits from the Snorkel-generated labels. It uses the SnorkelDataset class to load images and labels from the pickled data files. Apart from a batch size of 64, it is configured like get_data_loader().
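The difference between the two label types shows up in the loss. In recent PyTorch versions, nn.CrossEntropyLoss accepts either class indices (hard labels) or class probabilities (soft labels) as the target. The NumPy sketch below computes both variants on made-up label-model probabilities and logits. It illustrates the formulas and does not call PyTorch. A soft target keeps the uncertainty of an ambiguous image, and it reduces to the hard-label loss when the target is one-hot:

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def cross_entropy_hard(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy with integer class labels (label_type='hard')."""
    return float(-log_softmax(logits)[np.arange(len(labels)), labels].mean())


def cross_entropy_soft(logits: np.ndarray, probs: np.ndarray) -> float:
    """Mean cross-entropy with class-probability targets (label_type='soft')."""
    return float(-(probs * log_softmax(logits)).sum(axis=1).mean())


# Made-up label-model output for three images, columns = P(JUNGLE), P(SEA).
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.55, 0.45]])
hard = probs.argmax(axis=1)  # what SnorkelDataset keeps for 'hard' labels: [0, 1, 0]
logits = np.array([[2.0, -1.0], [0.5, 1.5], [0.0, 0.0]])  # made-up model outputs

print(cross_entropy_hard(logits, hard))
print(cross_entropy_soft(logits, probs))
# With one-hot targets, the soft loss equals the hard one.
one_hot = np.eye(2)[hard]
print(np.isclose(cross_entropy_soft(logits, one_hot), cross_entropy_hard(logits, hard)))  # True
```

For the third image, the hard label claims JUNGLE with full certainty. The soft target only asks the model for a 55/45 split.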
By leveraging these utility functions and classes, we ensure that our data is well-prepared and ready to be fed into our CNN model.
Training function
Here is the training function. I’ll skip a detailed walkthrough, since it follows the common pattern of a PyTorch training loop with a train and a val phase in every epoch.
```python
"""
This script defines the training function for our CNN.
"""
import copy
import os
import time

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import MulticlassPrecision
from tqdm import tqdm


def train(
        model: nn.Module,
        dataloaders: dict,
        optimizer,
        criterion,
        scheduler,
        device,
        run_name: str,
        writer: SummaryWriter | None = None,
        num_epochs: int = 10):
    """
    Train the CNN model.
    :param model: The neural network model to be trained.
    :param dataloaders: Dictionary containing dataloaders for training and validation.
    :param optimizer: The optimizer for updating model parameters.
    :param criterion: Loss criterion for training.
    :param scheduler: Learning rate scheduler stepped on the val loss (e.g. ReduceLROnPlateau).
    :param device: Device for computation ('cpu' or 'cuda').
    :param run_name: Name of the run for saving model checkpoints.
    :param writer: SummaryWriter for logging training progress (optional).
    :param num_epochs: Number of training epochs.
    :return: Best trained model.
    """
    os.makedirs('./weights', exist_ok=True)
    metric = MulticlassPrecision(num_classes=2, average=None)
    since = time.time()
    dataset_sizes = {phase: len(dl.dataset) for phase, dl in dataloaders.items()}
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    best_acc = 0.0
    model.to(device)
    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            running_samples = 0
            metric.reset()  # precision is tracked per phase and epoch
            # Iterate over data.
            pbar = tqdm(dataloaders[phase], desc=f'Epoch {epoch:3}/{num_epochs:3} - {phase:6}', unit=' batch')
            for inputs, labels in pbar:
                inputs = inputs.to(device)
                running_samples += inputs.size(0)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs).cpu()  # loss and metrics are computed on the CPU
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics: soft labels are reduced to hard labels for accuracy and precision
                if len(labels.shape) > 1:
                    _, labels = torch.max(labels, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
                metric.update(preds, labels)  # update with the current batch only
                m = metric.compute()
                pbar.set_postfix(loss=running_loss / running_samples,
                                 accuracy=running_corrects / running_samples * 100,
                                 Precision=[round(m[0].item(), 3), round(m[1].item(), 3)])
            pbar.close()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            if phase == 'val':
                scheduler.step(epoch_loss)
            if writer is not None:
                writer.add_scalar(f'LOSS/{phase}', epoch_loss, epoch)
                writer.add_scalar(f'ACC/{phase}', epoch_acc, epoch)
                writer.add_scalar('OPTIM/LR', optimizer.param_groups[0]['lr'], epoch)
            # deep copy the model whenever the validation loss improves
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': best_model_wts,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'accuracy': best_acc
                }, f'weights/{run_name}_best_model.pt')
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Val accuracy of the best model: {best_acc * 100:.2f}')
    model.load_state_dict(best_model_wts)
    return model
```
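NOTE: The check if len(labels.shape) > 1: handles soft labels. A batch of soft labels has shape (batch_size, num_classes), one probability vector per image, so after the loss has been computed on those vectors, they are collapsed to class indices with torch.max before accuracy and precision are updated.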
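The minimal sketch below illustrates, in dependency-free Python, the two things that check touches: a cross-entropy loss computed against a probability-vector target (PyTorch 1.10 and later let torch.nn.CrossEntropyLoss accept such targets directly), and the argmax that turns a soft label back into a class index for the accuracy and precision statistics. The logits and label values are hypothetical.

```python
import math

def softmax(logits):
    """Turn raw model scores into class probabilities."""
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Cross-entropy between a target distribution and the model's softmax.

    A hard label is just the one-hot special case of a soft label.
    """
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(target, probs))

def to_class_index(target):
    """Argmax over classes, like torch.max(labels, 1) does row by row."""
    return max(range(len(target)), key=lambda c: target[c])

logits = [2.0, 0.5]   # hypothetical model scores for one image
hard = [1.0, 0.0]     # one-hot target: class 0
soft = [0.8, 0.2]     # hypothetical soft label: 80% class 0, 20% class 1

print(cross_entropy(logits, hard))
print(cross_entropy(logits, soft))  # the 20% mass on class 1 adds to the loss
print(to_class_index(soft))         # 0, so it counts as class 0 for accuracy
```

Cross-entropy is minimized when the predicted distribution equals the target, so a soft target of [0.8, 0.2] no longer pushes the model toward full confidence in class 0; this is how the label uncertainty reaches training.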
Inference
With the following script, we can evaluate every checkpoint saved in weight_folder (files ending in _best_model.pt) on the test split.

```python
"""
This script tests trained CNN models using various metrics on the test dataset.
We assume that there is a test split in the data root.
"""
import os
from pathlib import Path

import torch
from torcheval.metrics import BinaryAccuracy, MulticlassPrecision
from tqdm import tqdm

from data import get_data_loader
from model import CNNModel


def test_model(weight_folder: str, data_root: str, device='cpu') -> None:
    """
    Test the trained CNN models using specified metrics.
    :param weight_folder: Path to the folder containing model weight files (e.g. original_best_model.pt).
    :param data_root: Root directory of the dataset.
    :param device: Device to run the evaluation on (default is 'cpu').
    """
    model_weights = [os.path.join(weight_folder, x) for x in os.listdir(weight_folder)
                     if x.endswith('_best_model.pt')]
    print(f'Found {len(model_weights)} Models.')
    dataloaders = get_data_loader(root=data_root, splits=['test'])
    results = []
    for model_weight in model_weights:
        # e.g. 'snorkel_soft_best_model.pt' -> 'snorkel soft'
        run_label = " ".join(Path(model_weight).name.split("_")[:-2])
        # Fresh metrics per checkpoint: torcheval metrics accumulate state
        # across update() calls, so shared metrics would mix models' predictions.
        precision = MulticlassPrecision(num_classes=2, average=None)
        accuracy = BinaryAccuracy()
        model = CNNModel()
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine
        checkpoint = torch.load(model_weight, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        predictions, labels = [], []
        pbar = tqdm(dataloaders['test'], desc=f'{run_label:15}', unit=' batch')
        for inputs, label in pbar:
            inputs = inputs.to(device)
            with torch.no_grad():
                outputs = model(inputs).cpu()
            _, preds = torch.max(outputs, 1)
            predictions.extend(preds.tolist())
            labels.extend(label.tolist())
        preds_t, labels_t = torch.tensor(predictions), torch.tensor(labels)
        precision.update(preds_t, labels_t)
        accuracy.update(preds_t, labels_t)
        results.append(f'{run_label:15} '
                       f'Precision = {precision.compute().numpy()} '
                       f'ACC = {accuracy.compute():.4f}')
    for res in results:
        print(res)
```
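To make the two reported metrics concrete, here is a dependency-free sketch of what the evaluation computes from the collected predictions and labels: per-class precision (what MulticlassPrecision(average=None) reports) and accuracy. It is an illustration on toy data, not the torcheval implementation; in particular, returning 0.0 for a class that is never predicted is this sketch's own choice.

```python
def per_class_precision(preds, labels, num_classes=2):
    """For each class c: of the images predicted as c, the fraction that truly are c."""
    precisions = []
    for c in range(num_classes):
        true_of_predicted_c = [y for p, y in zip(preds, labels) if p == c]
        if true_of_predicted_c:
            hits = sum(y == c for y in true_of_predicted_c)
            precisions.append(hits / len(true_of_predicted_c))
        else:
            precisions.append(0.0)  # sketch's choice when class c is never predicted
    return precisions

def accuracy(preds, labels):
    """Fraction of images whose predicted class matches the label."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# Toy predictions and ground truth (class indices 0 and 1)
preds = [0, 0, 1, 1, 1, 1]
labels = [0, 0, 1, 1, 0, 1]
print(per_class_precision(preds, labels))  # [1.0, 0.75]
print(accuracy(preds, labels))
```

Precision answers "when the model says class c, how often is it right?", which is why one class can have high precision even when the other does not.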
Results
Now let’s look at the results. The following script trains the model three times, once on the original labels (standard supervised training) and twice on Snorkel’s weak labels (hard and soft), and then evaluates every saved checkpoint on the test split.

```python
"""
This script provides functions to train and test a CNN model with/without the Snorkel framework.
"""
import torch

from data import get_data_loader_snorkel, get_data_loader
from inference import test_model
from model import CNNModel
from train import train

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def Train():
    """
    Train a CNN model with/without the Snorkel framework.
    """
    dataloaders = {
        'original': get_data_loader('data_split_manual', splits=['train', 'val']),
        'snorkel_hard': get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='hard'),
        'snorkel_soft': get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='soft')
    }
    criterion = torch.nn.CrossEntropyLoss()
    for run_name, dataloader in dataloaders.items():
        # Fresh model, optimizer and scheduler for every run, so a run does not
        # continue training the model left over from the previous run.
        model = CNNModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-4)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=5)
        train(
            model=model, dataloaders=dataloader, optimizer=optimizer, criterion=criterion,
            scheduler=lr_scheduler, device=device, run_name=run_name, num_epochs=20
        )


def Test():
    """
    Test a trained CNN model using the test dataset.
    """
    test_model(weight_folder='weights', data_root='data_split_manual', device=device)


if __name__ == '__main__':
    Train()
    Test()
```
The output is:

```
original Precision = [0.94285715 0.921875 ] ACC = 0.9259
snorkel soft Precision = [0.93 0.90949225] ACC = 0.9132
snorkel hard Precision = [0.88235295 0.9226687 ] ACC = 0.9144
```
Original Labels:
- Precision: With the original labels, the model achieves high precision for both classes (‘SEA’ and ‘JUNGLE’), meaning that when it predicts a class, it is usually correct. Precision values of about 0.943 for ‘SEA’ and 0.922 for ‘JUNGLE’ mean that roughly 94% and 92% of the images assigned to each class actually belong to it.
- Accuracy (ACC): The overall accuracy of 0.9259 means the model correctly classifies about 92.6% of the test images.
Snorkel Soft Labels:
- Precision: With soft labels, each training image carries a probability for each class (combined from the labeling functions’ votes) instead of a single class. Precision stays high but drops slightly compared to the original labels, to about 0.930 for ‘SEA’ and 0.909 for ‘JUNGLE’.
- Accuracy (ACC): The accuracy of 0.9132 is 1.27 percentage points below the original labels but still a strong result: about 91.3% of the test images are classified correctly.
Snorkel Hard Labels:
- Precision: With hard labels (a single, definite class per image) from Snorkel, precision for ‘SEA’ drops more noticeably, to about 0.882, while precision for ‘JUNGLE’ stays high at about 0.923, on par with the original labels.
- Accuracy (ACC): The overall accuracy of 0.9144 is 1.15 percentage points below the original labels, showing that the model maintains strong classification performance with Snorkel hard labels.
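The comparisons above come down to a little arithmetic on the reported test numbers. The short script below recomputes, for each run, the accuracy drop relative to the original labels and the gap between the two classes’ precision; summarize is a hypothetical convenience helper, and the class order (‘SEA’, ‘JUNGLE’) follows the text.

```python
# Test-set results as printed above; precision tuples are ('SEA', 'JUNGLE').
results = {
    'original':     {'precision': (0.94285715, 0.921875),  'acc': 0.9259},
    'snorkel soft': {'precision': (0.93, 0.90949225),      'acc': 0.9132},
    'snorkel hard': {'precision': (0.88235295, 0.9226687), 'acc': 0.9144},
}

def summarize(results, baseline='original'):
    """Accuracy drop vs. the baseline (percentage points) and the
    absolute gap between the two classes' precision for each run."""
    base_acc = results[baseline]['acc']
    summary = {}
    for name, r in results.items():
        sea, jungle = r['precision']
        summary[name] = {
            'acc_drop_pp': (base_acc - r['acc']) * 100,
            'precision_gap': abs(sea - jungle),
        }
    return summary

for name, s in summarize(results).items():
    print(f"{name:13} accuracy drop = {s['acc_drop_pp']:.2f} pp, "
          f"precision gap = {s['precision_gap']:.3f}")
```

It gives accuracy drops of about 1.27 (soft) and 1.15 (hard) percentage points, and a precision gap of about 0.021 for both the original and soft-label runs versus about 0.040 for hard labels.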
Conclusion
In this tutorial, we used labeling functions in the Snorkel framework to turn our prior knowledge about the data into training labels for image classification, instead of labeling every training image by hand.

The results also hint at an advantage of soft labels when the classes are imbalanced. In our dataset, the ‘SEA’ and ‘JUNGLE’ classes have different numbers of images. With soft labels, the precision of the two classes stayed close (a gap of about 0.02, the same as with the original labels), whereas with hard labels the gap roughly doubled (about 0.04). A plausible explanation is that, in an imbalanced dataset, the errors introduced by labeling functions can disproportionately affect the minority class. Soft labels represent class assignments probabilistically instead of forcing each image into a single class, which may let the model balance precision between the classes more effectively. This balance matters in applications where misclassifying the minority class carries significant consequences.

While our labeling functions were relatively straightforward to implement, many real-world scenarios are harder: designing accurate labeling functions can be intricate and requires domain expertise and careful consideration. Even so, our experiment shows that noisy labels, combined through weak supervision, can train a model whose test accuracy is within about 1.3 percentage points of one trained on the original labels.

Embracing the inherent noise in weakly supervised data and leveraging tools like Snorkel not only expands the scope of feasible applications but also highlights how well modern machine learning models can learn from imperfect labels.
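The difference between soft and hard labels can be illustrated with a small, self-contained sketch. It is a simplified stand-in, not Snorkel’s LabelModel: it assumes each labeling function’s accuracy is already known and that labeling functions vote independently given the true class, whereas Snorkel estimates labeling-function accuracies from their agreements and disagreements on unlabeled data. The function names, accuracies, and votes below are hypothetical; only the ABSTAIN = -1 convention comes from Snorkel.

```python
import math

ABSTAIN = -1  # Snorkel's convention for a labeling function that does not vote

def combine_votes(votes, accuracies, prior=None, num_classes=2):
    """Combine one image's labeling-function votes into a soft label.

    Simplified naive-Bayes model (NOT Snorkel's LabelModel): labeling function j
    is assumed right with probability accuracies[j] (0 < acc < 1), to pick one
    of the other classes uniformly when wrong, and to be independent of the
    other labeling functions given the true class.
    """
    if prior is None:
        prior = [1.0 / num_classes] * num_classes  # balanced classes by default
    log_post = [math.log(p) for p in prior]
    for vote, acc in zip(votes, accuracies):
        if vote == ABSTAIN:
            continue  # an abstention carries no evidence
        for c in range(num_classes):
            likelihood = acc if vote == c else (1 - acc) / (num_classes - 1)
            log_post[c] += math.log(likelihood)
    # normalize in log space for numerical stability
    top = max(log_post)
    weights = [math.exp(lp - top) for lp in log_post]
    total = sum(weights)
    return [w / total for w in weights]

def to_hard_label(soft_label):
    """A hard label keeps only the most probable class."""
    return max(range(len(soft_label)), key=lambda c: soft_label[c])

# Three hypothetical labeling functions: two disagree, the third abstains.
accuracies = [0.9, 0.7, 0.6]
soft = combine_votes([0, 1, ABSTAIN], accuracies)
print([round(p, 3) for p in soft])
print(to_hard_label(soft))
```

Here the soft label is roughly [0.794, 0.206]: the disagreement survives as uncertainty, while the hard label reduces it to class 0 and discards it. The prior argument is where a known class imbalance could enter this toy model.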
GitHub
You can find the code at the project GitHub repository.