How to Optimize AI Models with Metric Learning

December 21, 2023 | 7 min read

Shai Kazaz

Senior Data Scientist

At BlueVoyant, we use advanced artificial intelligence (AI) and machine learning (ML) models to augment our technology and human expertise. One such method is called “metric learning,” which is an approach to train a distance function that aims to establish similarity or dissimilarity between data points. Metric learning aims to give certain aspects of our data a sense of distance.

Metric learning models, if trained well, can solve various problems that other training methods are not designed to solve. One of the most useful features is the ability to detect a new kind of object without any additional training.

This blog serves as an introduction to metric learning and explains why it can be useful in coding.

What is Metric Learning?

Metric learning is a method that mostly uses neural networks (NNs) to give data samples a “geographical” meaning. In other words, it optimizes the way we measure distance between data points in a predefined manner. For those who are familiar with classification tasks, the difference is that classifiers try to answer, “what is this data point?” or “what is in this data point?” while metric learners try to answer, “where is this data point?”

From now on, assume that we are dealing with Computer Vision (CV) for the purposes of this article, though we can use metric learning beyond the CV domain. In CV, the data point is an image or series of images. We use Convolutional Neural Networks (CNNs or ConvNets) as a function that processes the image and extracts features from that image.

What is a feature in an image? A feature is something like “this image contains a half circle,” “this image doesn’t contain any black squares,” or “there is a line connecting two circles in the upper left corner of the image.” CNNs process this data and use it to learn. The methods used by CNNs are among the main contributors to the AI revolution we are seeing today.

Humans process images in much the same fashion: some neurons “fire” when they see the feature they need to detect. CNN models learn much more complicated features, but they still serve the same goal.
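To make the idea of a feature detector concrete, here is a minimal sketch in plain Python. It assumes a tiny grayscale image stored as a list of lists and a hand-written 2x2 kernel that responds to a dark-to-bright vertical edge. The names `cross_correlate` and `vertical_edge` are illustrative only; a real CNN learns its kernels during training rather than having them written by hand.

```python
def cross_correlate(image, kernel):
    """Slide `kernel` over `image` (valid positions only) and return the response map."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(
                image[i + di][j + dj] * kernel[di][dj]
                for di in range(kh)
                for dj in range(kw)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


# Hypothetical 3x4 image: dark on the left, bright on the right.
image = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]

# Hand-written kernel that "fires" on a dark-to-bright vertical edge.
vertical_edge = [
    [-1, 1],
    [-1, 1],
]

response = cross_correlate(image, vertical_edge)
for row in response:
    print(row)
```

The response is largest in the column where the dark and bright regions meet, which is the sense in which a filter "detects" a feature.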

How Does Metric Learning Work?

Siamese networks often form the foundation of metric learning. They were first introduced to compare signatures and decide whether two signatures belong to the same person. To train a Siamese network, we pass two images, which are either from the same class or not, through two identical networks that share their weights, and compare the outputs. As more research on Siamese networks has been done, more methods have been introduced to improve the process, like triplet loss networks, but we’ll get to that later.

Let's talk about a use case to illustrate how to use metric learning. Assume we have a data set of images, each containing a logo of a company. Our goal is to identify which logo is in each image. Although we have labeled training data which says which logo is in the image, we are going to train our model in an unsupervised manner.

When we train a network for our task, we define a backbone network that acts as a feature extractor. This network outputs an embedding vector in some predefined space for each image. Then we pass two (or more) embedding vectors to a distance function that calculates the distance between them. The network is optimized on a loss function based on the distance function, which aims to group logos from the same class together in the embedding space and to separate logos from different classes.
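As a rough illustration of a loss built on a distance function, the sketch below uses Euclidean distance and a contrastive-style loss on a pair of embeddings. This is one common choice for Siamese-style training, not necessarily the one used in our models. The three-dimensional embedding values and the `margin` default are made up for the example.

```python
import math


def euclidean_distance(u, v):
    """Euclidean distance between two equal-length embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def contrastive_loss(u, v, same_class, margin=1.0):
    """Pull same-class pairs together; push other pairs at least `margin` apart."""
    d = euclidean_distance(u, v)
    if same_class:
        # Any distance between two views of the same logo is penalized.
        return d ** 2
    # Different logos are penalized only while they are closer than the margin.
    return max(0.0, margin - d) ** 2


# Hypothetical embeddings that a backbone might output.
logo_a_v1 = [0.9, 0.1, 0.0]
logo_a_v2 = [0.8, 0.2, 0.0]
logo_b = [0.0, 0.1, 0.9]

print(contrastive_loss(logo_a_v1, logo_a_v2, same_class=True))
print(contrastive_loss(logo_a_v1, logo_b, same_class=False))
```

In real training the embeddings come from the backbone, and the gradient of this loss is what updates its parameters.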

How to Train a Metric Learner

While Siamese networks are often used in metric learning, they can sometimes lead to an ineffective optimization process. Optimizing NNs is a local process: on each iteration, we calculate the loss function based on the distance function, compute its gradient, and update the network’s trainable parameters. In one step, we may succeed in moving an image of a logo close to images of other variations of that logo. In a later step, we measure and optimize the distance between that logo and a different logo, and occasionally this mistakenly moves the logo far away from the other variations of the same logo.

Triplet loss networks try to solve this issue by comparing three images instead of two. We look at a base image, another one from the same class, and another one from a different class. We then make the base image closer to the one from the same class and farther from the one from a different class. This is another significant difference between classifiers and metric learners: classifiers go over each data point once in each step of the training process, whereas a “data point” in our case is a pair of images for Siamese networks and a triplet of images for triplet loss networks.
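The triplet idea can be sketched in a few lines. The version below assumes Euclidean distance and the standard triplet margin formulation, max(0, d(anchor, positive) - d(anchor, negative) + margin). The two-dimensional embeddings and the margin value are hypothetical.

```python
import math


def euclidean_distance(u, v):
    """Euclidean distance between two equal-length embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(anchor, positive, negative, margin=0.2):
    """Zero once the negative is at least `margin` farther away than the positive."""
    d_pos = euclidean_distance(anchor, positive)
    d_neg = euclidean_distance(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)


# Hypothetical embeddings: base image, same logo, different logo.
anchor = [0.0, 0.0]
positive = [0.3, 0.4]
easy_negative = [0.6, 0.8]  # already far from the anchor
hard_negative = [0.3, 0.0]  # closer to the anchor than the positive is

print(triplet_margin_loss(anchor, positive, easy_negative))
print(triplet_margin_loss(anchor, positive, hard_negative))
```

The easy triplet contributes no loss, while the hard one does. This is why the choice of which triplets to sample matters during training.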

Before we train our metric learner model, there are a few things we need to understand and configure for the training process, such as the distance function and the samplers. A common approach is to build a clustering model on top of the trained network, using the embedding vectors and their labels to train it. We can then treat it like a classifier and ask the model, “to which class is this image most similar?”
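One simple way to ask “to which class is this image most similar?” is a nearest-centroid lookup over the embeddings. The sketch below uses that technique as a stand-in for whatever clustering model you prefer. The brand names and two-dimensional embeddings are made up, and `fit_centroids` and `most_similar_class` are illustrative helper names.

```python
import math
from collections import defaultdict


def euclidean_distance(u, v):
    """Euclidean distance between two equal-length embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def fit_centroids(embeddings, labels):
    """Average the embedding vectors of each label into one centroid per class."""
    sums = {}
    counts = defaultdict(int)
    for vec, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vec)
        sums[label] = [s + x for s, x in zip(sums[label], vec)]
        counts[label] += 1
    return {
        label: [s / counts[label] for s in total]
        for label, total in sums.items()
    }


def most_similar_class(centroids, query):
    """Return the label whose centroid is closest to the query embedding."""
    return min(
        centroids,
        key=lambda label: euclidean_distance(centroids[label], query),
    )


# Hypothetical embeddings produced by a trained backbone.
train_embeddings = [[0, 0], [0, 2], [10, 10], [10, 12]]
train_labels = ["acme", "acme", "globex", "globex"]

centroids = fit_centroids(train_embeddings, train_labels)
print(most_similar_class(centroids, [1, 1]))
print(most_similar_class(centroids, [9, 11]))
```

Because the backbone stays fixed, adding a new brand only requires computing a new centroid from a few example embeddings.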

Why Use Metric Learning?

Here at BlueVoyant, we use metric learning in our web domain investigations. Our models can determine whether a website constitutes phishing, or whether a malicious actor is using brand imagery in an impersonation attempt or spreading disinformation about one of our clients.

You can also think of our logos example: once we have a model that compares different logos, we can ask different questions by using this model without retraining it. For example: do these two logos share some common elements, even if they aren’t identical?
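A question like this can be approximated by turning the distance between two embeddings into a graded answer. The sketch below assumes cosine similarity and two thresholds (0.9 and 0.5) that are purely hypothetical. In practice, thresholds would have to be tuned on validation data for the trained model.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two embedding vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)


def describe_similarity(u, v, same_threshold=0.9, related_threshold=0.5):
    """Map a similarity score to a coarse verdict; the thresholds are assumptions."""
    score = cosine_similarity(u, v)
    if score >= same_threshold:
        return "likely the same logo"
    if score >= related_threshold:
        return "shares some elements"
    return "unrelated"


# Hypothetical 2-dimensional embeddings of three logo images.
print(describe_similarity([1, 0], [1, 0]))
print(describe_similarity([1, 0], [1, 1]))
print(describe_similarity([1, 0], [0, 1]))
```

The same embeddings answer both the strict question (same logo?) and the looser one (related logo?) without retraining the network.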

Let’s play with it a little bit with some code examples. First, download and prepare a toy dataset of logos. We’ve included a Flickr dataset that contains 27 different brands.

Now let’s extract the files:

tar -zxvf flickr_logos_27_dataset.tar.gz
tar -zxvf flickr_logos_27_dataset/flickr_logos_27_dataset_images.tar.gz

Let’s create a folder to store all the images:

mkdir Logos

Here is a Python script to prepare the data:

At the end of this script, we print how many files we have for each brand/logo.
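As a minimal sketch of what such a preparation step might look like, the code below copies each annotated image into a per-brand folder under Logos and prints the number of files per brand. It is not the original script. It assumes the annotation file is named flickr_logos_27_dataset_training_set_annotation.txt, that each of its lines starts with an image filename followed by a brand name, and that the images were extracted to flickr_logos_27_dataset_images. Adjust the paths if your layout differs. The function name `organize_by_brand` is illustrative.

```python
import shutil
from collections import Counter
from pathlib import Path


def organize_by_brand(annotation_file, images_dir, output_dir):
    """Copy each annotated image into output_dir/<brand>/ and return per-brand counts."""
    images_dir, output_dir = Path(images_dir), Path(output_dir)
    seen = set()
    with open(annotation_file) as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            filename, brand = parts[0], parts[1]
            source = images_dir / filename
            # An image can be annotated several times; copy it only once.
            if (filename, brand) in seen or not source.exists():
                continue
            seen.add((filename, brand))
            brand_dir = output_dir / brand
            brand_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy(source, brand_dir / filename)
    return Counter(brand for _, brand in seen)


# Assumed locations after extracting the archives; change them as needed.
ANNOTATIONS = Path(
    "flickr_logos_27_dataset/flickr_logos_27_dataset_training_set_annotation.txt"
)
IMAGES = Path("flickr_logos_27_dataset_images")

if ANNOTATIONS.exists():
    counts = organize_by_brand(ANNOTATIONS, IMAGES, "Logos")
    for brand, n in sorted(counts.items()):
        print(brand, n)
```

A one-folder-per-class layout is a common convention for image datasets, which keeps the dataset-reading code that follows simple.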

Next, download the powerful-benchmarker repository:

git clone -b metric-learning https://github.com/KevinMusgrave/powerful-benchmarker.git

To install the metric learning library that we will use:

Now let’s create a python file to read the dataset:

cd powerful-benchmarker/src/powerful_benchmarker/datasets

Now create a file named logos_dataset.py and paste this inside:

This folder contains a file named __init__.py. Open it and add the following line:

from .logos_dataset import Logos 

That’s it! We can start the training. Go to the running file: 

cd ../../../examples 

Then start a run, for example with triplet margin loss and a resnet18 backbone: