Customer Segmentation with Ant Colony Clustering
This notebook explains the concepts behind using an ant-colony-based algorithm for customer segmentation.
The Business Problem
Modern businesses collect vast amounts of customer data—purchase history, browsing behavior, demographics, and more. The goal of customer segmentation is to partition this diverse customer base into distinct groups or "personas."
This is useful for many business processes, such as targeted marketing, product development, and customer support.
Traditional methods often rely on pre-defined rules (e.g., "high-spenders," "new customers"). An ant-based approach is different: it allows meaningful segments to emerge organically from the data itself, revealing patterns you might not have known to look for.
How Ants Organize
This algorithm is inspired by the remarkable ability of some ant species to sort items in their nests, such as larvae, eggs, or food particles, into neat piles. They achieve this without any central coordination or blueprint. The "intelligence" is a collective, emergent property of the swarm following simple rules.
The core behavior is:
- An ant wanders randomly.
- If it stumbles upon an item that is out of place (i.e., not surrounded by similar items), it is likely to pick it up.
- If it is carrying an item and stumbles upon a cluster of similar items, it is likely to drop it.
Over thousands of these simple pick-up and drop-off actions, a globally organized state (sorted piles) emerges from local interactions. We can apply this exact logic to data points.
The Data Clustering Algorithm
In our model, customer data points replace the ant's larvae, and a virtual 2D grid represents the nest. "Virtual ants" move around this grid, picking up and dropping data points until distinct clusters form.
- Each customer is a vector of features (e.g., [age, average_spend, monthly_visits]).
- We define a distance metric (like Euclidean distance) to measure how similar two customers are. The smaller the distance, the more similar they are.
- All customer data points are randomly scattered on a 2D grid.
- A number of "virtual ants" are created. Initially, they are not carrying any data points.
- An ant is chosen at random.
- The ant looks at the grid cell it is currently on.
- If the ant is NOT carrying anything and the cell holds a data point: it decides whether to pick the point up based on how dissimilar it is to its neighbors.
- If the ant IS carrying a data point and the cell is empty: it decides whether to drop the point it is holding based on how similar that point is to the items around the new location.
- The ant then takes one random step to a neighboring cell (it may also stay in place), whether or not it picked up or dropped anything.
This process is repeated, and over time, the data points self-organize into clusters of similar customers.
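To make the distance metric concrete, here is a small sketch with three made-up customers described by [age, average_spend, monthly_visits]. The values are invented for illustration. The features are min-max scaled first, as the implementation below does, so that no single feature (such as spend in dollars) dominates the Euclidean distance.

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Three hypothetical customers: [age, average_spend, monthly_visits]
customers = np.array([
    [25, 40.0, 12],   # customer A
    [27, 45.0, 10],   # customer B: similar to A
    [60, 300.0, 1],   # customer C: very different from A
])

# Scale each feature to [0, 1] so age, spend and visits are comparable
scaled = MinMaxScaler().fit_transform(customers)

def distance(a, b):
    """Euclidean distance between two scaled customer vectors."""
    return float(np.linalg.norm(a - b))

d_ab = distance(scaled[0], scaled[1])
d_ac = distance(scaled[0], scaled[2])
print(f"A-B: {d_ab:.3f}, A-C: {d_ac:.3f}")  # A is much closer to B than to C
```

A and C sit at opposite ends of every scaled feature, so their distance is the largest possible in three dimensions (the square root of 3).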
The Core Mechanics
The decision to "pick up" or "drop" is not deterministic; it's probabilistic. These probabilities are calculated based on the local density of similar items in an ant's neighborhood.
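$$f(i) = \max\left(0,\ \frac{1}{s^2} \sum_{j \in N(i)} \bigl(1 - d(i, j)\bigr)\right)$$
Where $f(i)$ is a measure of the similarity of data point $i$ to the other data points $j$ in a small, local region $N(i)$ of the grid, $d(i, j)$ is the distance between the two points, and $s^2$ is the number of cells in the neighborhood. The implementation below uses a simplified variant: it averages the distances over the occupied neighboring cells only, $f(i) = \max\left(0,\ 1 - \bar{d}(i)\right)$, so empty cells do not lower the density.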
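Two ways of computing the local density $f(i)$ can give very different values. The first divides the summed similarities $1 - d(i, j)$ by the total number of cells $s^2$ in the neighborhood, so empty cells pull the density down. The second averages the distances over occupied neighbors only, as `calculate_local_density` does later in this notebook. The sketch below compares the two on a toy neighborhood. The function names and toy values are hypothetical, and it assumes data already scaled to [0, 1].

```python
import numpy as np

def density_per_cell(point, neighbor_points, n_cells):
    """Sum of (1 - d) over occupied neighbor cells, divided by the total cell count s^2."""
    if len(neighbor_points) == 0:
        return 0.0
    d = np.linalg.norm(np.asarray(neighbor_points) - point, axis=1)
    return max(0.0, float(np.sum(1 - d)) / n_cells)

def density_occupied_average(point, neighbor_points):
    """Notebook-style variant: 1 - mean distance to occupied neighbors, clipped at 0."""
    if len(neighbor_points) == 0:
        return 0.0
    d = np.linalg.norm(np.asarray(neighbor_points) - point, axis=1)
    return max(0.0, 1.0 - float(np.mean(d)))

toy_point = np.array([0.5, 0.5])
toy_neighbors = np.array([[0.5, 0.6], [0.4, 0.5]])  # two very similar points
n_cells = 9  # radius-1 window: s^2 = 3 * 3 cells, mostly empty

print(density_per_cell(toy_point, toy_neighbors, n_cells))   # about 0.2
print(density_occupied_average(toy_point, toy_neighbors))    # about 0.9
```

With the per-cell form, a point in a sparse region scores low even when its few neighbors are very similar. The occupied-only average ignores how crowded the neighborhood is.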
Pick Up Probability
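$$P_{\text{pick}}(i) = \left(\frac{k_1}{k_1 + f(i)}\right)^2$$
An ant is more likely to pick up an item if it is in a neighborhood of dissimilar items (i.e., $f(i)$ is low). $k_1$ is a constant that controls the sensitivity of the pick-up action.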
Drop Probability
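$$P_{\text{drop}}(i) = \left(\frac{f(i)}{k_2 + f(i)}\right)^2$$
An ant carrying an item is more likely to drop it in a neighborhood of similar items (i.e., $f(i)$ is high).
$k_2$ is a constant that controls the sensitivity of the drop action. A small $k_2$ raises the drop probability for any given $f(i)$, which makes it easier for clusters to form.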
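To see how the two probabilities trade off, the sketch below evaluates them at a few density values. It uses the K1 and K2 values from the Parameters cell later in the notebook, with K1 for picking up and K2 for dropping, as the text describes. The loop values of f are arbitrary examples.

```python
K1 = 0.345  # pick-up sensitivity (value copied from the Parameters cell)
K2 = 0.15   # drop sensitivity (value copied from the Parameters cell)

def pick_up_prob(f):
    """(k1 / (k1 + f))^2: high when the local density f is low."""
    return (K1 / (K1 + f)) ** 2

def drop_prob(f):
    """(f / (k2 + f))^2: high when the local density f is high."""
    return (f / (K2 + f)) ** 2 if f > 0 else 0.0

for f in (0.0, 0.1, 0.5, 0.9):
    print(f"f={f:.1f}  pick={pick_up_prob(f):.3f}  drop={drop_prob(f):.3f}")
```

An isolated item (f = 0) is always picked up and never dropped. As f grows, picking up becomes rare and dropping becomes likely.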
Python Implementation
!pip install -q numpy scipy matplotlib scikit-learn seaborn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
from sklearn.preprocessing import MinMaxScaler
import random
Parameters
Beware: the algorithm is very sensitive to its hyperparameters, including the number of ants, the grid size, and the neighborhood radius. Choose the grid size relative to the number of data points. Each cell holds at most one point, so the grid needs at least as many cells as there are points. Also run enough iterations that each cell is likely to be visited by an ant. Treat the radius, k1, and k2 as values to tune. Compare the synthetic-data settings with the hyperparameters in this example:
## Synthetic Data
N_SAMPLES = 800
N_FEATURES = 2
N_CLUSTERS = 5
CLUSTER_STD = 0.8
# Hyperparameters
GRID_SIZE = 35
N_ANTS = 100
ITERATIONS = 300000
NEIGHBORHOOD_RADIUS = 7 # The radius around a point to check for similarity
K1 = 0.345 # Pick up sensitivity
K2 = 0.15 # Drop sensitivity
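One way to sanity-check these settings before a long run is to compute a few ratios. The particular checks below are a suggestion, not a rule. They show how crowded the grid is, how many steps each ant takes on average, and how much of the grid a single neighborhood covers.

```python
# Values copied from the Parameters cell above
GRID_SIZE = 35
N_SAMPLES = 800
N_ANTS = 100
ITERATIONS = 300000
NEIGHBORHOOD_RADIUS = 7

n_cells = GRID_SIZE ** 2
# Each cell holds at most one data point, so the grid must be large enough
assert N_SAMPLES <= n_cells, "grid too small: scattering would never finish"

occupancy = N_SAMPLES / n_cells                    # fraction of cells holding a point
steps_per_ant = ITERATIONS / N_ANTS                # average steps per ant (ants are chosen at random)
visits_per_cell = ITERATIONS / n_cells             # average ant visits per cell
window_cells = (2 * NEIGHBORHOOD_RADIUS + 1) ** 2  # cells in a full (unclipped) neighborhood

print(f"occupancy={occupancy:.2f}, steps/ant={steps_per_ant:.0f}, "
      f"visits/cell={visits_per_cell:.0f}, window={window_cells} cells "
      f"({window_cells / n_cells:.0%} of the grid)")
```

With these settings, about two thirds of the cells are occupied, and a radius of 7 gives a 15 x 15 window covering roughly a fifth of the grid.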
Core Functions
Get neighbors:
def get_neighborhood(pos, grid, radius, grid_size):
    """Gets all non-empty grid cells within the radius of a position."""
    x, y = pos
    x_min, x_max = max(0, x - radius), min(grid_size, x + radius + 1)
    y_min, y_max = max(0, y - radius), min(grid_size, y + radius + 1)
    neighbors = []
    for i in range(x_min, x_max):
        for j in range(y_min, y_max):
            if grid[i, j] != -1:
                neighbors.append(grid[i, j])
    return neighbors
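The double loop above visits up to 15 x 15 cells per call with the default radius, on every iteration. The same lookup can be written with a NumPy slice and boolean indexing, which should be noticeably faster. This is a sketch under a hypothetical name, `get_neighborhood_vectorized`. It keeps the same clipping at the grid edges and returns indices in the same row-major order as the loops.

```python
import numpy as np

def get_neighborhood_vectorized(pos, grid, radius, grid_size):
    """Data indices of all non-empty cells within `radius` of pos (clipped at the edges)."""
    x, y = pos
    window = grid[max(0, x - radius):min(grid_size, x + radius + 1),
                  max(0, y - radius):min(grid_size, y + radius + 1)]
    # Boolean indexing returns occupied cells in row-major order,
    # the same order as the nested loops in get_neighborhood
    return window[window != -1].tolist()

# Tiny 4x4 grid: -1 marks an empty cell, other values are data indices
toy_grid = np.array([
    [0, -1, 1, -1],
    [-1, 2, -1, -1],
    [-1, -1, -1, 3],
    [4, -1, -1, -1],
])
print(get_neighborhood_vectorized((0, 0), toy_grid, 1, 4))  # prints [0, 2] (window clipped at the corner)
```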
Calculate similarity:
def calculate_local_density(data_idx, neighbors, data):
    """Calculates the similarity of a data point to its neighbors."""
    true_neighbors = [n_idx for n_idx in neighbors if n_idx != data_idx]
    if not true_neighbors:
        return 0.0
    distances = [np.linalg.norm(data[data_idx] - data[n_idx]) for n_idx in true_neighbors]
    avg_dist = np.mean(distances)
    # Features are scaled to [0, 1], so each distance lies in [0, sqrt(n_features)];
    # clipping at 0 keeps the similarity in [0, 1].
    similarity = max(0, 1 - avg_dist)
    return similarity
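The per-neighbor list comprehension above can also be replaced by a single vectorized `np.linalg.norm(..., axis=1)` call. The sketch below has a hypothetical name and computes the same quantity: 1 minus the mean distance to the other occupied neighbors, clipped at 0.

```python
import numpy as np

def calculate_local_density_vectorized(data_idx, neighbors, data):
    """1 - mean distance from data[data_idx] to its other neighbors, clipped at 0."""
    idx = np.array([n for n in neighbors if n != data_idx], dtype=int)
    if idx.size == 0:
        return 0.0
    distances = np.linalg.norm(data[idx] - data[data_idx], axis=1)
    return max(0.0, 1.0 - float(distances.mean()))

toy_data = np.array([[0.0, 0.0], [0.0, 0.3], [0.4, 0.0], [1.0, 1.0]])
print(calculate_local_density_vectorized(0, [0, 1, 2], toy_data))  # about 0.65: 1 - mean(0.3, 0.4)
print(calculate_local_density_vectorized(0, [3], toy_data))        # 0.0: distance sqrt(2) > 1 is clipped
```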
Probability functions (sensitive to hyperparameters K1, K2):
def pick_up_prob(density):
    """Calculates the probability of picking up an item."""
    return (K1 / (K1 + density)) ** 2

def drop_prob(density):
    """Calculates the probability of dropping an item."""
    # K2 controls the sensitivity of the drop action
    return (density / (K2 + density)) ** 2 if density > 0 else 0
Initializing the Simulation
We create a GRID_SIZE x GRID_SIZE grid, scatter the normalized data points onto random empty cells, and place the ants at random positions:
def initialize_simulation(data, grid_size, n_ants):
    """
    Creates and initializes all the state variables for the simulation.
    Returns a dictionary containing the simulation state.
    """
    # Normalize data
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data)
    n_samples = normalized_data.shape[0]
    # Each cell holds at most one point; otherwise the scatter loop would never end
    if n_samples > grid_size * grid_size:
        raise ValueError(f"{n_samples} points do not fit on a {grid_size}x{grid_size} grid")
    # Create the grid (-1 marks an empty cell) and scatter data on random empty cells
    grid = np.full((grid_size, grid_size), -1, dtype=int)
    for i in range(n_samples):
        while True:
            x, y = np.random.randint(0, grid_size, size=2)
            if grid[x, y] == -1:
                grid[x, y] = i
                break
    # Create ants at random positions; -1 in ant_payload means "not carrying anything"
    ant_positions = np.random.randint(0, grid_size, size=(n_ants, 2))
    ant_payload = np.full(n_ants, -1, dtype=int)
    print(f"Data scattered on a {grid_size}x{grid_size} grid. {n_ants} ants ready.\n")
    return {
        "grid": grid,
        "data": normalized_data,
        "ant_positions": ant_positions,
        "ant_payload": ant_payload,
    }
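The retry loop used for scattering gets slower as the grid fills up, because more random draws land on occupied cells. An alternative is to sample distinct cells directly, without replacement, using NumPy's `Generator.choice`. The sketch below assumes a hypothetical helper named `scatter_data` that would replace the grid-filling part of `initialize_simulation`.

```python
import numpy as np

def scatter_data(n_samples, grid_size, rng):
    """Place data indices 0..n_samples-1 on distinct random cells of an empty grid."""
    if n_samples > grid_size * grid_size:
        raise ValueError("grid has fewer cells than data points")
    grid = np.full((grid_size, grid_size), -1, dtype=int)
    # Draw distinct flat cell positions without replacement: no retry loop needed
    cells = rng.choice(grid_size * grid_size, size=n_samples, replace=False)
    grid.flat[cells] = np.arange(n_samples)
    return grid

rng = np.random.default_rng(0)
scattered = scatter_data(800, 35, rng)
```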
Implement Simulation Loop
On each iteration we pick a random ant, let it decide whether to pick up or drop a data point at its current cell, and then move it one random step on the grid:
def run_simulation(state, iterations, grid_size, radius):
    """
    Executes the main simulation loop.
    Note: This function MODIFIES the 'state' dictionary in place.
    """
    print(f"Running simulation for {iterations} iterations... \n")
    # Unpack state for easier access
    grid = state["grid"]
    data = state["data"]
    ant_positions = state["ant_positions"]
    ant_payload = state["ant_payload"]
    n_ants = len(ant_positions)
    for i in range(iterations):
        ant_idx = random.randint(0, n_ants - 1)
        ant_pos = ant_positions[ant_idx]
        if ant_payload[ant_idx] == -1:  # If ant is empty-handed
            if grid[ant_pos[0], ant_pos[1]] != -1:
                data_idx = grid[ant_pos[0], ant_pos[1]]
                neighbors = get_neighborhood(ant_pos, grid, radius, grid_size)
                density = calculate_local_density(data_idx, neighbors, data)
                if random.random() < pick_up_prob(density):
                    ant_payload[ant_idx] = data_idx
                    grid[ant_pos[0], ant_pos[1]] = -1
        else:  # If ant is carrying an item
            if grid[ant_pos[0], ant_pos[1]] == -1:
                carried_data_idx = ant_payload[ant_idx]
                neighbors = get_neighborhood(ant_pos, grid, radius, grid_size)
                density = calculate_local_density(carried_data_idx, neighbors, data)
                if random.random() < drop_prob(density):
                    grid[ant_pos[0], ant_pos[1]] = carried_data_idx
                    ant_payload[ant_idx] = -1
        # Move ant one random step (each coordinate changes by -1, 0 or +1),
        # wrapping around the grid edges
        move = np.random.randint(-1, 2, size=2)
        ant_positions[ant_idx] = (ant_pos + move) % grid_size
        if (i + 1) % 10000 == 0:
            print(f"Iteration {i+1}/{iterations} complete. \n")
    # Put down any items still being carried on the nearest empty cell
    # (Manhattan distance, ignoring wraparound) so no data point is missing from the grid
    for ant_idx in np.flatnonzero(ant_payload != -1):
        empty_cells = np.argwhere(grid == -1)
        dist = np.abs(empty_cells - ant_positions[ant_idx]).sum(axis=1)
        x, y = empty_cells[np.argmin(dist)]
        grid[x, y] = ant_payload[ant_idx]
        ant_payload[ant_idx] = -1
    print("Simulation finished.")
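A data point held by an ant is not on the grid, so it would be missing from the clusters read off the grid unless it is put down before the loop ends. The sketch below uses a hypothetical helper, `count_points`, to check that the grid and the ants' payloads together still account for every data point. It is shown on a toy state with the same keys as the simulation state.

```python
import numpy as np

def count_points(state):
    """Return (points on the grid, points still carried by ants)."""
    on_grid = int(np.count_nonzero(state["grid"] != -1))
    carried = int(np.count_nonzero(state["ant_payload"] != -1))
    return on_grid, carried

# Toy state: 3 points on a 3x3 grid, one more point held by an ant
toy_state = {
    "grid": np.array([[0, -1, -1], [-1, 1, -1], [-1, -1, 2]]),
    "ant_payload": np.array([-1, 3, -1]),
}
on_grid, carried = count_points(toy_state)
assert on_grid + carried == 4, "a data point was lost or duplicated"
print(on_grid, carried)  # prints 3 1
```

After a real run, `sum(count_points(simulation_state))` should equal N_SAMPLES, and `carried` should be 0 if every item was put down.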
Label Found Clusters
from scipy.ndimage import label

def extract_clusters(grid, data):
    """
    Extracts final clusters from the grid using scipy's label function.
    Returns the data points and their corresponding cluster labels.
    """
    # Create a boolean grid where True means a cell is occupied.
    occupied_grid = grid != -1
    # The 'label' function finds all contiguous regions and assigns
    # a unique integer ID to each region.
    labeled_grid, num_clusters = label(occupied_grid)
    print(f"Found {num_clusters} potential clusters using scipy.ndimage.label.")
    clustered_data = []
    labels = []
    # Iterate through the original grid to link data points to their new cluster ID
    rows, cols = np.where(grid != -1)  # Find coordinates of all data points
    for r, c in zip(rows, cols):
        data_idx = grid[r, c]
        cluster_id = labeled_grid[r, c]
        clustered_data.append(data[data_idx])
        labels.append(cluster_id)
    return np.array(clustered_data), np.array(labels)
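`extract_clusters` calls `scipy.ndimage.label` with its default structuring element, which in 2D only connects cells that share an edge. Cells touching diagonally end up in different clusters. Passing `structure=generate_binary_structure(2, 2)` switches to 8-connectivity. Which choice works better likely depends on how crowded the grid is, since 8-connectivity merges more regions. The toy example below shows the difference.

```python
import numpy as np
from scipy.ndimage import label, generate_binary_structure

# Two occupied cells that touch only diagonally
occupied = np.array([[True, False],
                     [False, True]])

_, n_four = label(occupied)  # default structure: 4-connectivity (no diagonals)
_, n_eight = label(occupied, structure=generate_binary_structure(2, 2))  # 8-connectivity

print(n_four, n_eight)  # prints 2 1
```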
Generate Synthetic Data
X, y_true = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS,
                       n_features=N_FEATURES, random_state=42, cluster_std=CLUSTER_STD)
# Visualize the initial data
plt.figure(figsize=(10, 8))
sns.scatterplot(x=X[:, 0], y=X[:, 1], s=50, alpha=0.7, edgecolor='k')
plt.title('Initial Unclustered Customer Data', fontsize=16)
plt.xlabel('Customer Feature 1 (e.g., Average Spend)', fontsize=12)
plt.ylabel('Customer Feature 2 (e.g., Website Visits)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()
Run Simulation
simulation_state = initialize_simulation(X, GRID_SIZE, N_ANTS)
run_simulation(simulation_state, ITERATIONS, GRID_SIZE, NEIGHBORHOOD_RADIUS)
# Extract the final clusters from the grid
final_grid = simulation_state["grid"]
final_data = simulation_state["data"]
clustered_X, labels = extract_clusters(final_grid, final_data)
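Because the data is synthetic, the discovered segments can be scored against the `make_blobs` ground truth with scikit-learn's `adjusted_rand_score`. `extract_clusters` returns points in grid-scan order rather than data-index order, so the sketch below uses a hypothetical helper, `labels_by_index`, to look up each point's region label by its data index first. It is shown on a toy grid with made-up ground-truth ids.

```python
import numpy as np
from scipy.ndimage import label
from sklearn.metrics import adjusted_rand_score

def labels_by_index(grid, n_samples):
    """Cluster label for each data index (0 if the point is not on the grid)."""
    occupied = grid != -1
    labeled_grid, _ = label(occupied)
    labels = np.zeros(n_samples, dtype=int)
    labels[grid[occupied]] = labeled_grid[occupied]
    return labels

# Toy grid: points 0 and 1 form one region, points 2 and 3 another
toy_grid = np.array([[0, 1, -1, -1],
                     [-1, -1, -1, 2],
                     [-1, -1, -1, 3]])
found = labels_by_index(toy_grid, 4)
toy_truth = np.array([7, 7, 9, 9])  # hypothetical ground-truth segment ids
print(found.tolist(), adjusted_rand_score(toy_truth, found))  # ARI of 1.0 means a perfect match
```

On the notebook's output this would be `adjusted_rand_score(y_true, labels_by_index(final_grid, N_SAMPLES))`. Values near 0 indicate agreement no better than chance.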
Plot Results
plt.figure(figsize=(12, 10))
unique_labels = np.unique(labels)
palette = sns.color_palette("deep", len(unique_labels))
sns.scatterplot(x=clustered_X[:, 0], y=clustered_X[:, 1], hue=labels,
palette=palette, s=60, alpha=0.8, edgecolor='k', legend='full')
plt.title('Customer Segments Discovered by Ant Clustering', fontsize=16)
plt.xlabel('Customer Feature 1 (e.g., Average Spend)', fontsize=12)
plt.ylabel('Customer Feature 2 (e.g., Website Visits)', fontsize=12)
plt.legend(title='Discovered Segment')
plt.grid(True, linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()
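Labeling by grid contiguity turns every isolated occupied cell into its own cluster, so the result can contain many tiny segments. One option is to mark clusters below a minimum size as noise before interpreting the segments. The sketch below uses a hypothetical helper, `mark_small_clusters_as_noise`, and the `min_size` threshold is an assumption to tune.

```python
import numpy as np

def mark_small_clusters_as_noise(labels, min_size):
    """Return a copy of labels where clusters with fewer than min_size points become -1."""
    ids, counts = np.unique(labels, return_counts=True)
    small_ids = ids[counts < min_size]
    return np.where(np.isin(labels, small_ids), -1, labels)

toy_labels = np.array([1, 1, 1, 2, 3, 3, 3, 3])
print(mark_small_clusters_as_noise(toy_labels, min_size=3).tolist())  # prints [1, 1, 1, -1, 3, 3, 3, 3]
```

Applied to the notebook's output, this would be something like `mark_small_clusters_as_noise(labels, min_size=10)`, where 10 is only an illustrative threshold.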