Detecting Anomalies with ML
Anton A. Nesterov | an (at) vski.sh
Version 1.0
Understanding normal user behavior is key to identifying potential security threats or user experience issues.
This notebook will demonstrate how to use an Autoencoder, a type of neural network, to automatically detect anomalies. An autoencoder is an ideal choice for this unsupervised task because it learns to identify the normal, underlying patterns in the data. When an unusual pattern (an anomaly) appears, the model struggles to reconstruct it accurately, and this high reconstruction error signals an anomaly.
This is a common unsupervised anomaly detection technique. A similar but more interpretable approach is the Isolation Forest, which is better suited to human-supervised analytics and performs well on smaller datasets.
Anomaly detection techniques are applied in Fraud Detection, Industrial Sensor Maintenance, DevOps, Trading Algorithms, User Behavior Analysis, and more.
We'll focus on a DDoS detection use case: our example will try to recognize a sudden, unexpected spike in server requests that could indicate a distributed denial-of-service attack.
Generate Synthetic Data
For this example, we'll create a synthetic time series representing server requests over a few days. We'll add a clear trend, some daily seasonality, and then inject a few obvious anomalies to test our model.
!pip install tensorflow
import os; os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
# Generate a time series with a trend and seasonality
np.random.seed(42)
n_points = 24 * 7 # One week of hourly data
time_index = pd.date_range('2024-01-01', periods=n_points, freq='H')
# Base trend and seasonality
requests = np.linspace(500, 1500, n_points) # Upward trend
daily_seasonality = 200 * np.sin(np.arange(n_points) * 2 * np.pi / 24)
noise = np.random.normal(0, 50, n_points)
requests = requests + daily_seasonality + noise
# Introduce anomalies
requests[50] += 5000
requests[100] += 3000
requests[130:135] = requests[130:135] / 10
# Create a DataFrame
server_data = pd.DataFrame({
'timestamp': time_index,
'requests': requests
}).set_index('timestamp')
# Visualize the data
plt.figure(figsize=(15, 6))
plt.plot(server_data.index, server_data['requests'], color='skyblue')
plt.title('Simulated Server Request Rates with Anomalies')
plt.xlabel('Time')
plt.ylabel('Requests per Hour')
plt.grid(True)
plt.show()
Preprocess and Train the Autoencoder
Before training, we must scale the data to a standard range. An autoencoder should ideally learn only the "normal" behavior; in this small example we train on the full series and rely on the anomalies being rare, so the network still learns the dominant pattern and fails to reconstruct the outliers. If you prefer to filter obvious anomalies out of the training set first, a simple option is sketched below.
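The snippet below is a minimal, optional sketch of such a filter using wide interquartile-range fences; the `train_data` frame it produces is illustrative only, and the rest of the notebook keeps training on the full series.
# Optional pre-filter: drop blatant outliers before fitting (illustrative only)
q1, q3 = server_data['requests'].quantile([0.25, 0.75])
iqr = q3 - q1
lower, upper = q1 - 3 * iqr, q3 + 3 * iqr  # wide fences so only extreme points are removed
train_data = server_data[(server_data['requests'] >= lower) & (server_data['requests'] <= upper)]
print(f"Kept {len(train_data)} of {len(server_data)} points for training")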
# Scale the data
scaler = StandardScaler()
scaled_data = scaler.fit_transform(server_data[['requests']])
# Reshape data for the Autoencoder
X = scaled_data
input_dim = X.shape[1]
encoding_dim = 1 # A compressed representation of the data
input_layer = Input(shape=(input_dim,))
# Encoder (compresses the data)
encoder = Dense(2, activation="relu")(input_layer)
encoder = Dense(encoding_dim, activation="relu")(encoder)
# Decoder (reconstructs the data)
decoder = Dense(2, activation="relu")(encoder)
decoder = Dense(input_dim, activation="linear")(decoder)  # linear output so standardized (negative) values can be reconstructed
# Autoencoder model
autoencoder = Model(inputs=input_layer, outputs=decoder)
autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
# Train the Autoencoder
# We use a small batch size for this simple example and early stopping to prevent overfitting
history = autoencoder.fit(
X, X,
epochs=50,
batch_size=32,
validation_split=0.1,
callbacks=[EarlyStopping(monitor='val_loss', patience=5)],
verbose=0)
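Optionally (this check is not part of the original flow), you can plot the loss curves recorded in `history` to confirm that early stopping halted training before the validation loss started to climb.
# Optional: inspect the learning curves recorded during training
plt.figure(figsize=(8, 4))
plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('MSE loss')
plt.legend()
plt.grid(True)
plt.show()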
Detect Anomalies with Reconstruction Error
After training, the autoencoder will have learned to reconstruct normal data points with very low error. To detect anomalies, we'll feed the entire dataset back through the trained autoencoder. We will then calculate the reconstruction error for each data point and set a threshold. Any data point with a reconstruction error above this threshold will be flagged as an anomaly.
# Get the reconstruction of the original data
reconstructions = autoencoder.predict(X)
# Calculate the reconstruction error (Mean Squared Error) for each data point
mse_loss = np.mean(np.power(X - reconstructions, 2), axis=1)
# Find a reasonable threshold for anomalies
# We'll use the 95th percentile of the reconstruction loss as our threshold
threshold = np.percentile(mse_loss, 95)
# print(f"\nAnomaly Threshold: {threshold:.4f}")
# Predict anomalies by comparing reconstruction loss to the threshold
anomalies = server_data.iloc[mse_loss > threshold]
print(f"\nNumber of anomalies detected: {len(anomalies)}")
# Visualize the results
plt.figure(figsize=(15, 6))
plt.plot(server_data.index, server_data['requests'], label='Requests', color='skyblue')
plt.scatter(anomalies.index, anomalies['requests'], color='red', marker='o', s=50, label='Anomalies')
plt.title('Server Request Rates with Detected Anomalies (Autoencoder)')
plt.xlabel('Time')
plt.ylabel('Requests per Hour')
plt.legend()
plt.grid(True)
plt.show()
Number of anomalies detected: 9
The plot above shows that the Autoencoder successfully identified the large spikes and the sharp drop as anomalous events. This demonstrates how a simple neural network can be an invaluable tool for automatically monitoring and alerting on critical server behavior without the need for labeled anomaly data.
Isolation Forest
An Isolation Forest can be used for similar purposes. It is easier to interpret and well suited to analyzing static datasets or looking for outliers in general.
The Isolation Forest algorithm works on a simple principle: anomalies are "few and different." It builds a series of random decision trees to isolate observations. Normal data points require more splits to be isolated, while anomalies, being different, are isolated in fewer steps. This makes it highly effective for identifying outliers in a dataset.
from sklearn.ensemble import IsolationForest
# Create an Isolation Forest model
# contamination: The amount of contamination of the data set, i.e.,
# the proportion of outliers in the data set. Used when fitting to define the threshold.
model = IsolationForest(n_estimators=100, contamination=0.02, random_state=42)
# Train the model (it's unsupervised, so we just pass the data)
model.fit(server_data[['requests']])
# Predict the anomalies. -1 means anomaly, 1 means normal.
server_data['anomaly'] = model.predict(server_data[['requests']])
# Extract the anomalies
anomalies = server_data[server_data['anomaly'] == -1]
print(f"Number of anomalies detected: {len(anomalies)}")
print("\nDetected Anomalies:")
print(anomalies)
# Plot the original data and the detected anomalies
plt.figure(figsize=(15, 6))
plt.plot(server_data.index, server_data['requests'], label='Requests', color='skyblue')
plt.scatter(anomalies.index, anomalies['requests'], color='red', marker='o', s=50, label='Anomalies')
plt.title('Server Request Rates with Detected Anomalies')
plt.xlabel('Time')
plt.ylabel('Requests per Hour')
plt.legend()
plt.grid(True)
plt.show()
Number of anomalies detected: 4
Detected Anomalies:
requests anomaly
timestamp
2024-01-03 02:00:00 5915.605396 -1
2024-01-05 04:00:00 4201.238939 -1
2024-01-06 11:00:00 133.962310 -1
2024-01-06 14:00:00 115.642400 -1
As you can see from the plot, the model successfully identified the large spikes and the sharp drop as anomalous events.
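To connect this back to the "few splits" intuition, scikit-learn exposes the underlying anomaly scores: `decision_function` returns a score derived from the average isolation depth, shifted so that flagged points fall below zero (`score_samples` gives the unshifted version).
# Inspect the raw scores behind the -1/1 labels; the most negative scores are the easiest points to isolate
server_data['score'] = model.decision_function(server_data[['requests']])
print(server_data.sort_values('score')[['requests', 'anomaly', 'score']].head())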
Other Use Cases
These approaches can be easily extended to more complex systems and data streams, providing a powerful layer of defense and insight for any operation; a small scoring sketch follows the list below.
Common Use Cases:
- Fraud Detection in Financial Transactions 💸
- Predictive Maintenance for Industrial Sensors ⚙️
- User Behavior Analysis on a Website 💻
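As a closing sketch of the streaming case mentioned above, the snippet below scores a hypothetical batch of newly arrived hourly counts with the Isolation Forest trained earlier; the `new_requests` values are made up for illustration.
# Score a hypothetical batch of new hourly request counts with the trained model
new_requests = pd.DataFrame(
    {'requests': [1450.0, 1520.0, 9800.0, 1490.0]},  # 9800 simulates a sudden spike
    index=pd.date_range('2024-01-08', periods=4, freq='H'))
new_requests['anomaly'] = model.predict(new_requests[['requests']])
print(new_requests)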