Using Attention Analysis To Explain Model Decisions
The real business value of data analytics lies in the ability to explain why and how a model made a decision. Most classic models work as a black box, and it is usually the data scientist's job to find the answers. For example, predicting customer churn is usually framed as a binary classification problem, and the resulting models predict whether a customer is a churner or not, without answering why and how the customer became one. Finding such insights in ML models is the core idea behind Explainable AI (XAI).
In this notebook we'll build a customer churn predictor with an attention mechanism, which will help us analyze the customer journey and possibly give us insights into when and why customers churned.
The Dataset
We'll use the Customer Churn Dataset from Kaggle:
import warnings; warnings.simplefilter('ignore')
import os
import kagglehub
import pandas as pd
import numpy as np
from IPython.display import display, Markdown
print_df = lambda df: display(Markdown(df.to_markdown()))
path = kagglehub.dataset_download("sahideseker/customer-churn-prediction-dataset")
df = pd.read_csv(os.path.join(path, 'customer_churn_dataset.csv'))
print_df( df.head(3) )
|   | customer_id | age | tenure | service_type | monthly_fee | churn |
|---|---|---|---|---|---|---|
| 0 | C1000 | 56 | 14 | bundle | 47.54 | 0 |
| 1 | C1001 | 69 | 63 | tv | 31.89 | 0 |
| 2 | C1002 | 46 | 27 | tv | 129.43 | 0 |
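Before simulating sequences, it is worth a quick look at the class balance and the tenure range, since tenure will drive the sequence lengths (a small optional check using the columns shown above):
# Optional sanity check on the raw data
print(df['churn'].value_counts(normalize=True))  # share of churners vs. non-churners
print(df['tenure'].describe())                   # tenure in months; informs the max sequence length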
Normalize Data
First we need to build a sequential input for the LSTM layer.
Since the dataset contains only static features, we simulate historical engagement sequences from them:
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
def create_sequences_from_df(df, max_seq_length):
sequences = []
labels = []
for _, row in df.iterrows():
tenure = int(row['tenure'])
is_churner = row['churn'] == 1
base_engagement = row['monthly_fee'] / 150.0 # Normalize fee to get a base engagement level
# Create a sequence based on churn status
if is_churner:
# Churners have declining engagement over their tenure
sequence = np.linspace(base_engagement, 0.1, tenure)
# Add some noise
sequence += np.random.normal(0, 0.05, size=sequence.shape)
else:
# Non-churners have stable (but noisy) engagement
sequence = np.full(tenure, base_engagement)
sequence += np.random.normal(0, 0.05, size=sequence.shape)
sequences.append(np.clip(sequence, 0, 1)) # Ensure engagement is between 0 and 1
labels.append(row['churn'])
# Pad sequences to ensure they all have the same length for the model
padded_sequences = pad_sequences(sequences, maxlen=max_seq_length, dtype='float32', padding='pre', truncating='pre')
    return padded_sequences, np.array(labels)
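As a quick illustration of what pad_sequences does here (toy values, not from the dataset): shorter histories are left-padded with zeros, so the most recent months always end up at the right-hand end of the array:
# Toy example with made-up values: a 3-month history padded to length 5
pad_sequences([[0.7, 0.5, 0.2]], maxlen=5, dtype='float32', padding='pre')
# array([[0. , 0. , 0.7, 0.5, 0.2]], dtype=float32)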
The Attention Layer
We'll create a simple attention layer, similar to the one shown in the Keras tutorials:
class Attention(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(Attention, self).__init__(**kwargs)
def build(self, input_shape):
self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
initializer='random_normal', trainable=True)
self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
initializer='zeros', trainable=True)
super(Attention, self).build(input_shape)
def call(self, x):
e = tf.tanh(tf.tensordot(x, self.W, axes=1) + self.b)
a = tf.nn.softmax(e, axis=1)
output = x * a
        return tf.reduce_sum(output, axis=1), a
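Conceptually, the layer scores each timestep with e = tanh(x·W + b), turns the scores into weights a = softmax(e) along the time axis, and returns the attention-weighted sum of the hidden states as a context vector, together with the weights a themselves. Those weights are what we will later visualize as attention.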
Build The Model
We'll use the Attention Layer after the LSTM:
MAX_TIMESTEPS = 60
X, y = create_sequences_from_df(df, MAX_TIMESTEPS)
X = X.reshape((len(X), MAX_TIMESTEPS, 1))
You may want to give the output layers explicit names; for this example we rely on the autogenerated 'dense' and 'attention' names:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout
inputs = Input(shape=(MAX_TIMESTEPS, 1))
lstm_out = LSTM(32, return_sequences=True)(inputs)
context_vector, attention_weights = Attention()(lstm_out)
dropout_out = Dropout(0.2)(context_vector)
outputs = Dense(1, activation='sigmoid')(dropout_out)
model = Model(inputs=inputs, outputs=[outputs, attention_weights])
model.compile(optimizer='adam',
loss={'dense': 'binary_crossentropy', 'attention': 'mse'}, # Dummy loss for attention
loss_weights={'dense': 1., 'attention': 0.}, # We only care about the prediction loss
metrics={'dense': 'accuracy'})
# model.summary()
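If you prefer explicit names over the autogenerated ones, a variant like the sketch below should work; the names 'churn_output' and 'attention_weights' are our own choice, not anything required by Keras:
# Hypothetical variant of the same model with explicitly named output layers
named_inputs = Input(shape=(MAX_TIMESTEPS, 1))
named_lstm = LSTM(32, return_sequences=True)(named_inputs)
named_context, named_attn = Attention(name='attention_weights')(named_lstm)
named_pred = Dense(1, activation='sigmoid', name='churn_output')(Dropout(0.2)(named_context))
named_model = Model(inputs=named_inputs, outputs=[named_pred, named_attn])
named_model.compile(optimizer='adam',
                    loss={'churn_output': 'binary_crossentropy', 'attention_weights': 'mse'},
                    loss_weights={'churn_output': 1., 'attention_weights': 0.},
                    metrics={'churn_output': 'accuracy'})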
Training
Notice that we provide dummy labels for the attention layer output:
dummy_attention_labels = np.zeros((X.shape[0], X.shape[1], 1))
model.fit(X, [y, dummy_attention_labels], epochs=20, batch_size=32, validation_split=0.2)
Visualizations
We will select a few customers and visualize their simulated engagement together with the model's attention on a single plot:
import matplotlib.pyplot as plt
num_churners_to_plot = 3
num_non_churners_to_plot = 2
churner_indices = df[df['churn'] == 1].index[:num_churners_to_plot]
non_churner_indices = df[df['churn'] == 0].index[:num_non_churners_to_plot]
indices_to_plot = list(churner_indices) + list(non_churner_indices)
num_plots = len(indices_to_plot)
fig, axes = plt.subplots(num_plots, 1, figsize=(16, 6 * num_plots))
fig.suptitle('Attention Analysis for Multiple Customers', fontsize=20, y=0.99)
for i, customer_index in enumerate(indices_to_plot):
ax1 = axes[i]
# Get data for this specific customer
customer_info = df.loc[customer_index]
customer_sequence = X[customer_index:customer_index+1]
# Make prediction and get attention weights
prediction, attention = model.predict(customer_sequence)
churn_prob = prediction[0][0]
attention_scores = attention[0].flatten()
# Plotting logic for this subplot
actual_status = 'Churner' if customer_info['churn'] == 1 else 'Non-Churner'
plot_title = (f"Customer: {customer_info['customer_id']} | Actual Status: {actual_status}\n"
f"Predicted Churn Probability: {churn_prob:.2f} | Tenure: {customer_info['tenure']} months")
ax1.set_title(plot_title, fontsize=14)
# Plot the simulated engagement data
color = 'tab:blue'
ax1.set_xlabel('Time (Months Before Analysis)')
ax1.set_ylabel('Simulated Engagement', color=color)
ax1.plot(range(MAX_TIMESTEPS), customer_sequence.flatten(), color=color, label='Engagement')
ax1.tick_params(axis='y', labelcolor=color)
ax1.set_ylim(0, 1)
ax1.legend(loc='upper left')
# Plot attention weights on a second y-axis
ax2 = ax1.twinx()
color = 'tab:red'
ax2.set_ylabel('Attention Weight', color=color)
ax2.bar(range(MAX_TIMESTEPS), attention_scores, color=color, alpha=0.6, label='Attention')
ax2.tick_params(axis='y', labelcolor=color)
ax2.set_ylim(0, max(attention_scores) * 1.15 if max(attention_scores) > 0 else 0.1)
ax2.legend(loc='upper right')
fig.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout to make room for suptitle
plt.show()
Explaining The Visualization
These plots help us understand when the "turning point" happened for a particular customer. This way we can identify when the customer became a churner and proceed with a deeper analysis to find out why.
In the charts above, you can clearly see:
- For the first ~40 months, the model pays little attention.
- Around month 45, there is a dip in engagement, and the model's attention spikes dramatically. This is the turning point.
- The model continues to pay high attention after this point, confirming that the events around month 45 were critical to its final decision.
Now a business analyst can go back and look at the raw data for that specific customer around month 45. What happened? Did they have a bad support interaction? Did a new competitor launch an ad campaign? You've turned a black-box prediction into a targeted, actionable insight.
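The same idea can be made programmatic: reduce each customer's attention vector to a single turning point, e.g. the timestep with the highest attention weight. The sketch below reuses the model and X defined above; treating the argmax as the turning point and the column name attention_peak_months_ago are our own simplifications:
# Sketch: locate the attention peak for every customer
predictions, attentions = model.predict(X, verbose=0)
peak_timestep = attentions.reshape(len(X), MAX_TIMESTEPS).argmax(axis=1)
# Sequences are 'pre'-padded, so higher indices correspond to more recent months
df['attention_peak_months_ago'] = MAX_TIMESTEPS - peak_timestep
print_df(df.loc[df['churn'] == 1, ['customer_id', 'tenure', 'attention_peak_months_ago']].head())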
Conclusion
The customer churn example demonstrates how attention analysis can be applied to a DNN; however, its applications go well beyond this one use case.