Using Attention Analysis To Explain Model Decisions

The real business value of data analytics lies in the ability to explain why and how a model made a decision. Most classic models work as black boxes, and it is usually the data scientist's job to find the answers. For example, predicting customer churn is usually framed as a binary classification problem, and the resulting models predict whether a customer is a churner or not, without answering why and how the customer became one. Finding insights in ML models is the core idea behind Explainable AI (XAI).

In this notebook we'll build a customer churn predictor with an attention mechanism, which will help us analyze the customer's journey and possibly find insights into when and why customers churned.

The Dataset

We'll use the Customer Churn Dataset from Kaggle:

import warnings; warnings.simplefilter('ignore')
import os
import kagglehub
import pandas as pd
import numpy as np
 
from IPython.display import display, Markdown
print_df = lambda df: display(Markdown(df.to_markdown()))
 
 
path = kagglehub.dataset_download("sahideseker/customer-churn-prediction-dataset")
df = pd.read_csv(os.path.join(path, 'customer_churn_dataset.csv'))
 
print_df( df.head(3) )
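|    | customer_id | age | tenure | service_type | monthly_fee | churn |
|---:|:------------|----:|-------:|:-------------|------------:|------:|
|  0 | C1000       |  56 |     14 | bundle       |       47.54 |     0 |
|  1 | C1001       |  69 |     63 | tv           |       31.89 |     0 |
|  2 | C1002       |  46 |     27 | tv           |      129.43 |     0 |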

Normalize Data

First we need to build the sequential input for the LSTM layers.
The dataset contains only static features, so we simulate a monthly engagement history for each customer from the static columns in the DataFrame:

import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
 
def create_sequences_from_df(df, max_seq_length):
    sequences = []
    labels = []
    
    for _, row in df.iterrows():
        tenure = int(row['tenure'])
        is_churner = row['churn'] == 1
        base_engagement = row['monthly_fee'] / 150.0 # Normalize fee to get a base engagement level
        
        # Create a sequence based on churn status
        if is_churner:
            # Churners have declining engagement over their tenure
            sequence = np.linspace(base_engagement, 0.1, tenure)
            # Add some noise
            sequence += np.random.normal(0, 0.05, size=sequence.shape)
        else:
            # Non-churners have stable (but noisy) engagement
            sequence = np.full(tenure, base_engagement)
            sequence += np.random.normal(0, 0.05, size=sequence.shape)
        
        sequences.append(np.clip(sequence, 0, 1)) # Ensure engagement is between 0 and 1
        labels.append(row['churn'])
        
    # Pad sequences to ensure they all have the same length for the model
    padded_sequences = pad_sequences(sequences, maxlen=max_seq_length, dtype='float32', padding='pre', truncating='pre')
    return padded_sequences, np.array(labels)
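The effect of `padding='pre'` and `truncating='pre'` can be illustrated with a small NumPy-only sketch. The helper `pad_pre` is hypothetical and only mimics the behaviour for a single 1-D sequence: short histories are filled with zeros on the left, and long histories keep only their most recent steps.

```python
import numpy as np

def pad_pre(sequence, max_len, value=0.0):
    """Hypothetical helper: left-pad or left-truncate a 1-D sequence to max_len."""
    # Keep only the most recent max_len steps (this is what truncating='pre' does)
    seq = np.asarray(sequence, dtype="float32")[-max_len:]
    out = np.full(max_len, value, dtype="float32")
    if len(seq) > 0:
        # Real values go at the end, padding at the start (padding='pre')
        out[-len(seq):] = seq
    return out

short = pad_pre([0.3, 0.2, 0.1], max_len=5)       # shorter than max_len: padded
long = pad_pre(np.arange(7) / 10.0, max_len=5)    # longer than max_len: truncated
print(short)
print(long)
```

With this layout the last time step of every row is always the customer's most recent month, and any leading zeros are padding rather than real engagement.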

The Attention Layer

We'll create a simple attention layer, following the pattern shown in Keras tutorials:

class Attention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super(Attention, self).__init__(**kwargs)
 
  def build(self, input_shape):
    self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
                              initializer='random_normal', trainable=True)
    self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
                              initializer='zeros', trainable=True)
    super(Attention, self).build(input_shape)
 
  def call(self, x):
    e = tf.tanh(tf.tensordot(x, self.W, axes=1) + self.b)
    a = tf.nn.softmax(e, axis=1)
    output = x * a
    return tf.reduce_sum(output, axis=1), a
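To make the computation inside `call` easier to follow, here is a NumPy sketch of the same three steps: a score per time step, a softmax over the time axis, and a weighted sum. The shapes and random values are illustrative assumptions, not taken from the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, timesteps, features = 2, 5, 4          # illustrative sizes only

x = rng.normal(size=(batch, timesteps, features))   # stands in for the LSTM output
W = rng.normal(scale=0.05, size=(features, 1))      # like 'attention_weight'
b = np.zeros((timesteps, 1))                        # like 'attention_bias'

# 1. One score per time step: shape (batch, timesteps, 1)
e = np.tanh(x @ W + b)

# 2. Softmax over the time axis, so each sample's weights sum to 1
exp_e = np.exp(e - e.max(axis=1, keepdims=True))
a = exp_e / exp_e.sum(axis=1, keepdims=True)

# 3. Weighted sum of the time steps: shape (batch, features)
context = (x * a).sum(axis=1)

print(a.shape, context.shape)
print(a.sum(axis=1).ravel())
```

Because the weights of one sample always sum to 1, a bar in the attention plots is best read relative to the other bars of the same customer.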

Build The Model

We'll use the Attention Layer after LSTM:

MAX_TIMESTEPS = 60
 
X, y = create_sequences_from_df(df, MAX_TIMESTEPS)
X = X.reshape((len(X), MAX_TIMESTEPS, 1))

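The loss dictionaries below refer to the output layers by name, so you may need to give the layers explicit names. This example uses the autogenerated names 'dense' and 'attention'; Keras appends a numeric suffix (for example 'dense_1') when layers of the same type have already been created in the session: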

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout
 
inputs = Input(shape=(MAX_TIMESTEPS, 1))
lstm_out = LSTM(32, return_sequences=True)(inputs)
context_vector, attention_weights = Attention()(lstm_out)
dropout_out = Dropout(0.2)(context_vector)
outputs = Dense(1, activation='sigmoid')(dropout_out)
 
model = Model(inputs=inputs, outputs=[outputs, attention_weights])
 
model.compile(optimizer='adam',
              loss={'dense': 'binary_crossentropy', 'attention': 'mse'}, # Dummy loss for attention
              loss_weights={'dense': 1., 'attention': 0.}, # We only care about the prediction loss
              metrics={'dense': 'accuracy'})
# model.summary()

Training

Notice that we provide dummy labels for the attention layer output:

 
dummy_attention_labels = np.zeros((X.shape[0], X.shape[1], 1))
model.fit(X, [y, dummy_attention_labels], epochs=20, batch_size=32, validation_split=0.2)
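One possible alternative to the dummy labels, sketched here under stated assumptions rather than as the notebook's method, is to train a single-output model and build a second model over the same layers only for reading the attention weights. The names `train_model`, `attention_model`, `X_toy` and `y_toy`, the layer sizes, and the random toy data are all illustrative.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model

class Attention(tf.keras.layers.Layer):
    # Same layer as defined earlier, repeated so this sketch runs on its own
    def build(self, input_shape):
        self.W = self.add_weight(name='attention_weight', shape=(input_shape[-1], 1),
                                 initializer='random_normal', trainable=True)
        self.b = self.add_weight(name='attention_bias', shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, x):
        a = tf.nn.softmax(tf.tanh(tf.tensordot(x, self.W, axes=1) + self.b), axis=1)
        return tf.reduce_sum(x * a, axis=1), a

timesteps = 6  # small illustrative value, not MAX_TIMESTEPS
inputs = Input(shape=(timesteps, 1))
context, weights = Attention()(LSTM(8, return_sequences=True)(inputs))
outputs = Dense(1, activation='sigmoid')(context)

train_model = Model(inputs, outputs)       # single output, so no dummy labels
attention_model = Model(inputs, weights)   # shares the same layers and weights

train_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Random toy data, only to show that the two models work together
X_toy = np.random.rand(16, timesteps, 1).astype('float32')
y_toy = np.random.randint(0, 2, size=(16, 1))
train_model.fit(X_toy, y_toy, epochs=1, batch_size=8, verbose=0)

toy_attention = attention_model.predict(X_toy[:4], verbose=0)
print(toy_attention.shape)
```

The trade-off is two `predict` calls (one per model) when both the probability and the attention weights are needed, in exchange for a simpler `compile` and `fit`.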

Visualizations

We will select some customers and visualize churn and model attention on one plot:

import matplotlib.pyplot as plt
 
num_churners_to_plot = 3
num_non_churners_to_plot = 2
 
churner_indices = df[df['churn'] == 1].index[:num_churners_to_plot]
non_churner_indices = df[df['churn'] == 0].index[:num_non_churners_to_plot]
 
indices_to_plot = list(churner_indices) + list(non_churner_indices)
num_plots = len(indices_to_plot)
 
fig, axes = plt.subplots(num_plots, 1, figsize=(16, 6 * num_plots))
fig.suptitle('Attention Analysis for Multiple Customers', fontsize=20, y=0.99)
 
for i, customer_index in enumerate(indices_to_plot):
    ax1 = axes[i]
    
    # Get data for this specific customer
    customer_info = df.loc[customer_index]
    customer_sequence = X[customer_index:customer_index+1]
    
    # Make prediction and get attention weights
    prediction, attention = model.predict(customer_sequence)
    churn_prob = prediction[0][0]
    attention_scores = attention[0].flatten()
    
    # Plotting logic for this subplot
    actual_status = 'Churner' if customer_info['churn'] == 1 else 'Non-Churner'
    plot_title = (f"Customer: {customer_info['customer_id']} | Actual Status: {actual_status}\n"
                  f"Predicted Churn Probability: {churn_prob:.2f} | Tenure: {customer_info['tenure']} months")
    ax1.set_title(plot_title, fontsize=14)
 
    # Plot the simulated engagement data
    color = 'tab:blue'
    ax1.set_xlabel('Time (Months Before Analysis)')
    ax1.set_ylabel('Simulated Engagement', color=color)
    ax1.plot(range(MAX_TIMESTEPS), customer_sequence.flatten(), color=color, label='Engagement')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.set_ylim(0, 1)
    ax1.legend(loc='upper left')
 
    # Plot attention weights on a second y-axis
    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('Attention Weight', color=color)
    ax2.bar(range(MAX_TIMESTEPS), attention_scores, color=color, alpha=0.6, label='Attention')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.set_ylim(0, max(attention_scores) * 1.15 if max(attention_scores) > 0 else 0.1)
    ax2.legend(loc='upper right')
 
fig.tight_layout(rect=[0, 0, 1, 0.98]) # Adjust layout to make room for suptitle
plt.show()

[Figure: Attention Analysis for Multiple Customers. One panel per customer, showing simulated engagement (line, left axis) and attention weight (bars, right axis) over time.]

Explaining The Visualization

Plots like these help us understand when the "turning point" happened for a particular customer. This way we can identify when the customer became a churner and proceed with fundamental analysis to find out why.

In the charts above, you can clearly see:
  • For roughly the first 40 time steps, the model pays little attention.
  • Around time step 45, there's a dip in engagement, and the model's attention spikes dramatically. This is the turning point.
  • The model continues to pay high attention after this point, confirming that the events around time step 45 were critical to its final decision.

Now, a business analyst can go back and look at the raw data for that specific customer around time step 45 (each step is one month in this simulation). What happened? Did they have a bad support interaction? Did a new competitor launch an ad campaign? You've turned a black box prediction into a targeted, actionable insight.
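Reading the turning point off a chart does not scale to many customers, so the same lookup can be sketched in code. The helper `attention_peaks` below is hypothetical; it assumes the pre-padded layout used in this notebook (real data at the end of the sequence) and simply ranks the non-padded time steps by attention weight.

```python
import numpy as np

def attention_peaks(attention_scores, tenure, top_k=3):
    """Hypothetical helper: return (month, weight) for the top_k most-attended real steps."""
    scores = np.asarray(attention_scores, dtype=float)
    n_real = min(int(tenure), len(scores))     # real steps that survived truncation
    first_real = len(scores) - n_real          # pre-padding puts real data at the end
    real_scores = scores[first_real:]
    order = np.argsort(real_scores)[::-1][:top_k]
    # Month of tenure, counted from 1; accounts for months dropped by truncation
    offset = int(tenure) - n_real
    return [(offset + int(i) + 1, float(real_scores[i])) for i in order]

# Toy example: 6 time steps, of which the first 2 are padding (tenure = 4 months)
toy_scores = [0.0, 0.0, 0.05, 0.2, 0.6, 0.15]
peaks = attention_peaks(toy_scores, tenure=4, top_k=2)
print(peaks)
```

The months returned this way could then be joined against support tickets or campaign logs for that customer, which is the follow-up analysis described above.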

Conclusion

Customer churn demonstrates well how attention analysis can be applied to a DNN; however, its applications go beyond this one use case.

Anton Nesterov © 2025 | vski·science