LSTM for Anomaly Detection
Author: Anton A. Nesterov | an (at) vski.sh
Version: 1.0
In this notebook we'll use an LSTM autoencoder to detect unusual patterns in sensor data.
How an LSTM Autoencoder Works
- Train on Normalcy - Train an LSTM model only on data that we know is "normal."
- Reconstruct - The model's job isn't to forecast, but to learn the underlying patterns of the normal data so well that it can reconstruct it. It takes a sequence as input and tries to predict that same sequence as output.
- Calculate Reconstruction Error - When we present the trained model with new, unseen data, we measure how well it was able to reconstruct it.
- Set a Threshold - If the model sees normal data, the reconstruction error will be very low because the pattern looks familiar; if it sees an anomaly, the error will be high because the pattern is alien to what it learned. We flag any window whose error exceeds a threshold derived from the training errors (see the short sketch after this list).
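As a minimal sketch of that rule (the numbers below are made up purely for illustration; the mean + 2*std threshold matches the one used later in this notebook):
import numpy as np
# Hypothetical per-window reconstruction errors (MAE) measured on normal training data
toy_errors = np.array([0.010, 0.012, 0.009, 0.011, 0.013])
# Threshold rule used later: mean + 2 * standard deviation of the training errors
toy_threshold = toy_errors.mean() + 2 * toy_errors.std()
# A new window with a much larger reconstruction error is flagged as an anomaly
new_window_error = 0.08
print(new_window_error > toy_threshold)  # True -> anomaly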
Python Implementation
We'll create a perfect, predictable sine wave as our "normal" data. Then, we will manually inject some sharp spikes into it to act as our anomalies.
import warnings; warnings.simplefilter('ignore')
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, RepeatVector, TimeDistributed
Sine wave:
time = np.arange(0, 500, 0.5)
normal_data = np.sin(time / 20) * 5
normal_data.shape
Inject a few sharp spikes as anomalies:
anomaly_data = normal_data.copy()
anomaly_data[100] += 10
anomaly_data[300] -= 8
anomaly_data[700] += 12
anomaly_data[900] -= 10
Scale the data. The scaler is fit on the normal data only, so the injected anomalies don't influence the scaling:
scaler = MinMaxScaler(feature_range=(0, 1))
scaler.fit(normal_data.reshape(-1, 1))
scaled_normal_data = scaler.transform(normal_data.reshape(-1, 1))
scaled_anomaly_data = scaler.transform(anomaly_data.reshape(-1, 1))
def create_sequences(data, time_steps):
    # Build overlapping sliding windows of length `time_steps`
    X = []
    for i in range(len(data) - time_steps):
        X.append(data[i:(i + time_steps)])
    return np.array(X)
Build the training and test sets. We train on the normal data and test on the series containing the injected anomalies:
TIME_STEPS = 20
# Create sequences from the data
X_train = create_sequences(scaled_normal_data, TIME_STEPS)
X_test = create_sequences(scaled_anomaly_data, TIME_STEPS)
# Reshape for LSTM [samples, time_steps, features]
X_train = X_train.reshape(X_train.shape[0], TIME_STEPS, 1)
X_test = X_test.reshape(X_test.shape[0], TIME_STEPS, 1)
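As a quick optional sanity check: with 1000 data points and a 20-step window we expect 980 overlapping windows per set, shaped [samples, time_steps, features]:
print(X_train.shape)  # expected: (980, 20, 1)
print(X_test.shape)   # expected: (980, 20, 1)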
Create LSTM Autoencoder:
model = Sequential([
    # Encoder: compresses each 20-step window into a single 32-dimensional vector
    LSTM(64, activation='relu', input_shape=(TIME_STEPS, 1), return_sequences=True),
    LSTM(32, activation='relu', return_sequences=False),
    # Repeat the encoded vector once per time step so the decoder can unfold it
    RepeatVector(TIME_STEPS),
    # Decoder: reconstructs the full sequence from the encoded representation
    LSTM(32, activation='relu', return_sequences=True),
    LSTM(64, activation='relu', return_sequences=True),
    TimeDistributed(Dense(1))
])
model.compile(optimizer='adam', loss='mae')
history = model.fit(X_train, X_train, epochs=20, batch_size=32, validation_split=0.1, verbose=1)
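Optionally, check that the reconstruction loss actually converges before trusting the model:
# Plot the training and validation loss recorded by model.fit
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.xlabel('Epoch')
plt.ylabel('MAE loss')
plt.legend()
plt.show()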
Compute the MAE on the training data and define the threshold:
X_train_pred = model.predict(X_train)
# Calculate the Mean Absolute Error for each sequence in the training data
train_mae_loss = np.mean(np.abs(X_train_pred - X_train), axis=1)
# Set the anomaly threshold (e.g., mean + 2 * standard deviation)
threshold = np.mean(train_mae_loss) + 2 * np.std(train_mae_loss)
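An optional way to sanity-check the threshold is to look at the distribution of training errors; almost all of them should fall below the line:
# Histogram of per-window training errors with the threshold marked
plt.hist(train_mae_loss.flatten(), bins=50)
plt.axvline(threshold, color='r', linestyle='--', label='threshold')
plt.xlabel('Training MAE per window')
plt.legend()
plt.show()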
Compute the MAE on the test data and find the anomalies:
X_test_pred = model.predict(X_test)
# Calculate the reconstruction loss for the test data
test_mae_loss = np.mean(np.abs(X_test_pred - X_test), axis=1)
# Identify the anomalies
anomalies = test_mae_loss > threshold
anomaly_indices = np.where(anomalies)[0]
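Each flagged index refers to a test window, so it maps back to a position in the original series by adding TIME_STEPS (as the plot below does). An optional quick summary:
# Report how many windows exceeded the threshold and roughly where they sit in time
print(f"Anomalous windows: {len(anomaly_indices)}")
print("Approximate time indices:", anomaly_indices + TIME_STEPS)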
Visualise results:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 12), sharex=True)
fig.suptitle('LSTM Anomaly Detection', fontsize=16)
# Plot 1: Original Data with Detected Anomalies
ax1.plot(time, anomaly_data, label='Original Data', color='blue')
ax1.scatter(time[anomaly_indices + TIME_STEPS], anomaly_data[anomaly_indices + TIME_STEPS],
color='red', marker='o', s=80, label='Detected Anomaly')
ax1.set_ylabel('Value')
ax1.set_title('Data with Detected Anomalies')
ax1.legend()
ax1.grid(True)
# Plot 2: Reconstruction Error
ax2.plot(time[TIME_STEPS:], test_mae_loss.flatten(), label='Reconstruction Error', color='orange')
ax2.axhline(y=threshold, color='r', linestyle='--', label='Anomaly Threshold')
ax2.set_ylabel('Mean Absolute Error')
ax2.set_xlabel('Time')
ax2.set_title('Reconstruction Error vs. Threshold')
ax2.legend()
ax2.grid(True)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()