
Customer Segmentation with Ant Colony


This notebook explains the concepts behind using an ant-colony-based algorithm for customer segmentation.

The Business Problem

Modern businesses collect vast amounts of customer data—purchase history, browsing behavior, demographics, and more. The goal of customer segmentation is to partition this diverse customer base into distinct groups or "personas."

This is useful for almost all business processes: targeted marketing, product development, customer support, and so on.

Traditional methods often rely on pre-defined rules (e.g., "high-spenders," "new customers"). An ant-based approach is different: it allows meaningful segments to emerge organically from the data itself, revealing patterns you might not have known to look for.

How Ants Organize

This algorithm is inspired by the remarkable ability of some ant species to sort items in their nests, such as larvae, eggs, or food particles, into neat piles. They achieve this without any central coordination or blueprint. The "intelligence" is a collective, emergent property of the swarm following simple rules.

The core behavior is:

  • An ant wanders randomly.
  • If it stumbles upon an item that is out of place (i.e., not surrounded by similar items), it is likely to pick it up.
  • If it is carrying an item and stumbles upon a cluster of similar items, it is likely to drop it.

Over thousands of these simple pick-up and drop-off actions, a globally organized state (sorted piles) emerges from local interactions. We can apply this exact logic to data points.

The Data Clustering Algorithm

In our model, customer data points replace the ant's larvae, and a virtual 2D grid represents the nest. "Virtual ants" move around this grid, picking up and dropping data points until distinct clusters form.

  1. Each customer is a vector of features (e.g., [age, average_spend, monthly_visits]).
  2. We define a distance metric (like Euclidean distance) to measure how similar two customers are. The smaller the distance, the more similar they are (see the short example after this list).
  3. All customer data points are randomly scattered on a 2D grid.
  4. A number of "virtual ants" are created. Initially, they are not carrying any data points.
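
As a concrete illustration of steps 1 and 2 (with made-up feature values, not from any real dataset), two customers can be compared by scaling their features to a common range and measuring the Euclidean distance between them:

import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical customers: [age, average_spend, monthly_visits]
customers = np.array([
    [25, 120.0, 8],
    [27, 110.0, 9],
    [62,  15.0, 1],
])

# Scale each feature to [0, 1] so no single feature dominates the distance.
scaled = MinMaxScaler().fit_transform(customers)

# Smaller distance means more similar customers.
print(np.linalg.norm(scaled[0] - scaled[1]))  # two similar frequent buyers
print(np.linalg.norm(scaled[0] - scaled[2]))  # two very different customers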

The Simulation Loop

  • An ant is chosen at random.
  • The ant observes the data point at its current location on the grid.
  • If the ant is NOT carrying anything: It decides whether to pick up the data point based on how dissimilar it is to its neighbors.
  • If the ant IS carrying a data point: It decides whether to drop the data point it's holding based on how similar the point is to the items in the new location.
  • Whether or not it picked up or dropped an item, the ant then moves to an adjacent cell on the grid.

This process is repeated, and over time, the data points self-organize into clusters of similar customers.

The Core Mechanics

The decision to "pick up" or "drop" is not deterministic; it's probabilistic. These probabilities are calculated based on the local density of similar items in an ant's neighborhood.

$$
f(d_i) = \max\left( 0,\; 1 - \frac{1}{N^2} \sum_{d_j \in \text{Neighborhood}} \text{distance}(d_i, d_j) \right)
$$

Where $f(d_i)$ is a measure of the similarity of a data point $d_i$ to the other data points in a small, local region of the grid, and $N^2$ is the number of cells in that neighborhood.

Pick Up Probability

An ant is more likely to pick up an item if it sits in a neighborhood of dissimilar items (i.e., $f(d_i)$ is low).

$$
P_{\text{pick}}(d_i) = \left( \frac{k_1}{k_1 + f(d_i)} \right)^2
$$

$k_1$ is a constant that controls the sensitivity of the pick-up action: the smaller $k_1$ is, the more an ant restricts itself to picking up poorly matched, isolated items.

Drop Probability

An ant carrying an item $d_i$ is more likely to drop it in a neighborhood of similar items (i.e., $f(d_i)$ is high).

$$
P_{\text{drop}}(d_i) = \left( \frac{f(d_i)}{k_2 + f(d_i)} \right)^2
$$

$k_2$ is a constant that controls the sensitivity of the drop action. A small $k_2$ makes it easier to form clusters.
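
To build intuition, here is a small worked example with illustrative constants (chosen for round numbers rather than matching the values used later in this notebook), $k_1 = k_2 = 0.1$:

$$
\begin{aligned}
f(d_i) = 0 \ \text{(isolated item)}: &\quad P_{\text{pick}} = \left(\tfrac{0.1}{0.1 + 0}\right)^2 = 1, &\quad P_{\text{drop}} = \left(\tfrac{0}{0.1 + 0}\right)^2 = 0,\\
f(d_i) = 0.9 \ \text{(well-matched item)}: &\quad P_{\text{pick}} = \left(\tfrac{0.1}{0.1 + 0.9}\right)^2 = 0.01, &\quad P_{\text{drop}} = \left(\tfrac{0.9}{0.1 + 0.9}\right)^2 = 0.81.
\end{aligned}
$$

So isolated items are almost always picked up and almost never dropped, while well-matched items are almost never picked up and usually dropped.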

Python Implementation

! pip install -q numpy matplotlib scikit-learn seaborn 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
from sklearn.preprocessing import MinMaxScaler
import random

Parameters

Beware: the algorithm is very sensitive to its hyperparameters, including the number of ants, the grid size, and the neighborhood radius. The grid size should be chosen relative to the size of the data, and there must be enough iterations to make it likely that the ants visit every cell. The radius, k1, and k2 should always be open to adjustment (a quick sanity check on the grid capacity follows the parameter block below).

Consider how the synthetic data settings and the hyperparameters relate in this example:

# Synthetic data
N_SAMPLES   = 800
N_FEATURES  = 2
N_CLUSTERS  = 5
CLUSTER_STD = 0.8
 
# Hyperparameters
GRID_SIZE   = 35
N_ANTS      = 100
ITERATIONS  = 300000
NEIGHBORHOOD_RADIUS = 7    # The radius around a point to check for similarity
K1 = 0.345                 # Pick up sensitivity
K2 = 0.15                  # Drop sensitivity
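
As a quick, hedged sanity check (using only the constants above): the grid must have more cells than data points, otherwise the initialization loop that scatters the customers can never finish, and a low occupancy leaves the ants room to rearrange items.

# Sanity-check the grid capacity before running the (expensive) simulation.
n_cells = GRID_SIZE ** 2
occupancy = N_SAMPLES / n_cells
print(f"{N_SAMPLES} points on {n_cells} cells -> occupancy {occupancy:.0%}")
assert N_SAMPLES < n_cells, "GRID_SIZE is too small for this many samples"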

Core Functions

Get the neighbors of $d_i$:

def get_neighborhood(pos, grid, radius, grid_size):
    """Gets all non-empty grid cells within the radius of a position."""
    x, y = pos
    x_min, x_max = max(0, x - radius), min(grid_size, x + radius + 1)
    y_min, y_max = max(0, y - radius), min(grid_size, y + radius + 1)
    
    neighbors = []
    for i in range(x_min, x_max):
        for j in range(y_min, y_max):
            if grid[i, j] != -1:
                neighbors.append(grid[i, j])
    return neighbors
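
A tiny toy check (with made-up grid values) shows what the function returns:

# Toy 5x5 grid; cell values are data-point indices, -1 means empty.
toy_grid = np.full((5, 5), -1, dtype=int)
toy_grid[1, 1] = 0
toy_grid[2, 3] = 1
# Both occupied cells fall inside the radius-1 window around (2, 2).
print(get_neighborhood((2, 2), toy_grid, radius=1, grid_size=5))  # [0, 1]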

Calculate the similarity $f(d_i)$. Note that this implementation averages the distance over the occupied neighboring cells rather than dividing by $N^2$ as in the formula above; either way it acts as a measure of local density:

def calculate_local_density(data_idx, neighbors, data):
    """Calculates the similarity of a data point to its neighbors."""
    true_neighbors = [n_idx for n_idx in neighbors if n_idx != data_idx]
 
    if not true_neighbors:
        return 0.0
 
    distances = [np.linalg.norm(data[data_idx] - data[n_idx]) for n_idx in true_neighbors]
    avg_dist = np.mean(distances)
    
    # Since data is scaled 0-1, avg_dist is also ~0-1.
    similarity = max(0, 1 - avg_dist)
    return similarity

Probability functions (sensitive to hyperparameters K1, K2):

def pick_up_prob(density):
    """Calculates the probability of picking up an item."""
    return (K1 / (K1 + density)) ** 2
 
def drop_prob(density):
    """Calculates the probability of dropping an item."""
    return (density / (K2 + density)) ** 2 if density > 0 else 0
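
A quick sanity check of these helpers (with illustrative density values, not taken from a real run) confirms that low-density spots favor picking up while high-density spots favor dropping:

# With K1 = 0.345 and K2 = 0.15, pick-up probability falls and drop
# probability rises as the local density increases.
for density in (0.1, 0.5, 0.9):
    print(f"density={density:.1f}  "
          f"pick={pick_up_prob(density):.2f}  "
          f"drop={drop_prob(density):.2f}")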

Initializing Simulation

We initialize a GRID_SIZE x GRID_SIZE grid, scatter the normalized data points across it, and place the ants at random positions:

 
def initialize_simulation(data, grid_size, n_ants):
    """
    Creates and initializes all the state variables for the simulation.
    Returns a dictionary containing the simulation state.
    """
    # Normalize data
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data)
    n_samples = normalized_data.shape[0]
 
    # Create the grid and scatter data
    grid = np.full((grid_size, grid_size), -1, dtype=int)
    for i in range(n_samples):
        while True:
            x, y = np.random.randint(0, grid_size, size=2)
            if grid[x, y] == -1:
                grid[x, y] = i
                break
    
    # Create ants
    ant_positions = np.random.randint(0, grid_size, size=(n_ants, 2))
    ant_payload = np.full(n_ants, -1, dtype=int)
    
    print(f"Data scattered on a {grid_size}x{grid_size} grid. {n_ants} ants ready.\n")
    
    return {
        "grid": grid,
        "data": normalized_data,
        "ant_positions": ant_positions,
        "ant_payload": ant_payload,
    }

Implement Simulation Loop

On each iteration a randomly chosen ant decides whether to pick up or drop a data point at its current cell and then moves to an adjacent cell:

def run_simulation(state, iterations, grid_size, radius):
    """
    Executes the main simulation loop.
    Note: This function MODIFIES the 'state' dictionary in place.
    """
    print(f"Running simulation for {iterations} iterations... \n")
    
    # Unpack state for easier access
    grid = state["grid"]
    data = state["data"]
    ant_positions = state["ant_positions"]
    ant_payload = state["ant_payload"]
    n_ants = len(ant_positions)
 
    for i in range(iterations):
        ant_idx = random.randint(0, n_ants - 1)
        ant_pos = ant_positions[ant_idx]
        
        if ant_payload[ant_idx] == -1: # If ant is empty-handed
            if grid[ant_pos[0], ant_pos[1]] != -1:
                data_idx = grid[ant_pos[0], ant_pos[1]]
                neighbors = get_neighborhood(ant_pos, grid, radius, grid_size)
                density = calculate_local_density(data_idx, neighbors, data)
                
                if random.random() < pick_up_prob(density):
                    ant_payload[ant_idx] = data_idx
                    grid[ant_pos[0], ant_pos[1]] = -1
        else: # If ant is carrying an item
            if grid[ant_pos[0], ant_pos[1]] == -1:
                carried_data_idx = ant_payload[ant_idx]
                neighbors = get_neighborhood(ant_pos, grid, radius, grid_size)
                density = calculate_local_density(carried_data_idx, neighbors, data)
                
                if random.random() < drop_prob(density):
                    grid[ant_pos[0], ant_pos[1]] = carried_data_idx
                    ant_payload[ant_idx] = -1
 
        # Move ant
        move = np.random.randint(-1, 2, size=2)
        ant_positions[ant_idx] = (ant_pos + move) % grid_size
        
        if (i + 1) % 10000 == 0:
            print(f"Iteration {i+1}/{iterations} complete. \n")
            
    print("Simulation finished.")

Label Found Clusters

from scipy.ndimage import label
 
def extract_clusters(grid, data):
    """
    Extracts final clusters from the grid using scipy's label function.
    Returns the data points and their corresponding cluster labels.
    """
    # Create a boolean grid where True means a cell is occupied.
    occupied_grid = grid != -1
    
    # The 'label' function finds all contiguous regions and assigns
    # a unique integer ID to each region.
    labeled_grid, num_clusters = label(occupied_grid)
    
    print(f"Found {num_clusters} potential clusters using scipy.ndimage.label.")
    
    clustered_data = []
    labels = []
    
    # Iterate through the original grid to link data points to their new cluster ID
    rows, cols = np.where(grid != -1) # Find coordinates of all data points
    for r, c in zip(rows, cols):
        data_idx = grid[r, c]
        cluster_id = labeled_grid[r, c]
        
        clustered_data.append(data[data_idx])
        labels.append(cluster_id)
        
    return np.array(clustered_data), np.array(labels)
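
One hedged aside (not part of the original pipeline): by default, scipy.ndimage.label treats only the four orthogonally adjacent cells as connected, so piles that touch only diagonally are counted as separate clusters. If that fragments the segments too much, an 8-connectivity structuring element can be passed instead, as in this sketch:

from scipy.ndimage import label, generate_binary_structure

def count_clusters_8conn(grid):
    """Counts contiguous occupied regions, also merging diagonal neighbors."""
    eight_conn = generate_binary_structure(2, 2)  # 3x3 block of True values
    _, num_clusters = label(grid != -1, structure=eight_conn)
    return num_clusters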

Generate Synthetic Data

X, y_true = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS, 
                           n_features=N_FEATURES, random_state=42, cluster_std=CLUSTER_STD)
 
# 2. Visualize Initial Data
plt.figure(figsize=(10, 8))
sns.scatterplot(x=X[:, 0], y=X[:, 1], s=50, alpha=0.7, edgecolor='k')
plt.title('Initial Unclustered Customer Data', fontsize=16)
plt.xlabel('Customer Feature 1 (e.g., Average Spend)', fontsize=12)
plt.ylabel('Customer Feature 2 (e.g., Website Visits)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

[Figure: Initial Unclustered Customer Data]

Run Simulation

simulation_state = initialize_simulation(X, GRID_SIZE, N_ANTS)
run_simulation(simulation_state, ITERATIONS, GRID_SIZE, NEIGHBORHOOD_RADIUS)
 
# 4. Extract and Visualize Final Clusters
final_grid = simulation_state["grid"]
final_data = simulation_state["data"]
clustered_X, labels = extract_clusters(final_grid, final_data)
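
Before plotting, it can help to print how many customers landed in each discovered segment (a small optional addition):

# Summarize the size of each discovered segment.
segment_ids, segment_sizes = np.unique(labels, return_counts=True)
for seg_id, size in zip(segment_ids, segment_sizes):
    print(f"Segment {seg_id}: {size} customers")

Because any points still being carried by ants when the simulation stops are not on the grid, the counts may sum to slightly less than N_SAMPLES.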

Plot Results

plt.figure(figsize=(12, 10))
unique_labels = np.unique(labels)
palette = sns.color_palette("deep", len(unique_labels))
 
sns.scatterplot(x=clustered_X[:, 0], y=clustered_X[:, 1], hue=labels, 
                palette=palette, s=60, alpha=0.8, edgecolor='k', legend='full')
 
plt.title('Customer Segments Discovered by Ant Clustering', fontsize=16)
plt.xlabel('Customer Feature 1 (e.g., Average Spend)', fontsize=12)
plt.ylabel('Customer Feature 2 (e.g., Website Visits)', fontsize=12)
plt.legend(title='Discovered Segment')
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

[Figure: Customer Segments Discovered by Ant Clustering]

Anton Nesterov © 2025 | vski·science