Skip to content

Heatmap

Visualize a heatmap of the topic's similarity matrix

Based on the cosine similarity matrix between topic embeddings, a heatmap is created showing the similarity between topics.

Parameters:

Name Type Description Default
topic_model

A fitted BERTopic instance.

required
topics List[int]

A selection of topics to visualize.

None
top_n_topics int

Only select the top n most frequent topics.

None
n_clusters int

Create n clusters and order the similarity matrix by those clusters.

None
custom_labels Union[bool, str]

If bool, whether to use custom topic labels that were defined using topic_model.set_topic_labels. If str, it uses labels from other aspects, e.g., "Aspect1".

False
title str

Title of the plot.

'<b>Similarity Matrix</b>'
width int

The width of the figure.

800
height int

The height of the figure.

800

Returns:

Name Type Description
fig Figure

A plotly figure

Examples:

To visualize the similarity matrix of topics simply run:

topic_model.visualize_heatmap()

Or if you want to save the resulting figure:

fig = topic_model.visualize_heatmap()
fig.write_html("path/to/file.html")
Source code in bertopic\plotting\_heatmap.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def visualize_heatmap(topic_model,
                      topics: List[int] = None,
                      top_n_topics: int = None,
                      n_clusters: int = None,
                      custom_labels: Union[bool, str] = False,
                      title: str = "<b>Similarity Matrix</b>",
                      width: int = 800,
                      height: int = 800) -> go.Figure:
    """ Visualize a heatmap of the topic's similarity matrix

    Based on the cosine similarity matrix between topic embeddings,
    a heatmap is created showing the similarity between topics.

    Arguments:
        topic_model: A fitted BERTopic instance.
        topics: A selection of topics to visualize.
        top_n_topics: Only select the top n most frequent topics.
        n_clusters: Create n clusters and order the similarity
                    matrix by those clusters.
        custom_labels: If bool, whether to use custom topic labels that were defined using 
                       `topic_model.set_topic_labels`.
                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
        title: Title of the plot.
        width: The width of the figure.
        height: The height of the figure.

    Returns:
        fig: A plotly figure

    Examples:

    To visualize the similarity matrix of
    topics simply run:

    ```python
    topic_model.visualize_heatmap()
    ```

    Or if you want to save the resulting figure:

    ```python
    fig = topic_model.visualize_heatmap()
    fig.write_html("path/to/file.html")
    ```
    <iframe src="../../getting_started/visualization/heatmap.html"
    style="width:1000px; height: 720px; border: 0px;""></iframe>
    """

    # Select topic embeddings
    if topic_model.topic_embeddings_ is not None:
        embeddings = np.array(topic_model.topic_embeddings_)[topic_model._outliers:]
    else:
        embeddings = topic_model.c_tf_idf_[topic_model._outliers:]

    # Select topics based on top_n and topics args
    freq_df = topic_model.get_topic_freq()
    freq_df = freq_df.loc[freq_df.Topic != -1, :]
    if topics is not None:
        topics = list(topics)
    elif top_n_topics is not None:
        topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
    else:
        topics = sorted(freq_df.Topic.to_list())

    # Order heatmap by similar clusters of topics
    sorted_topics = topics
    if n_clusters:
        if n_clusters >= len(set(topics)):
            raise ValueError("Make sure to set `n_clusters` lower than "
                             "the total number of unique topics.")

        distance_matrix = cosine_similarity(embeddings[topics])
        Z = linkage(distance_matrix, 'ward')
        clusters = fcluster(Z, t=n_clusters, criterion='maxclust')

        # Extract new order of topics
        mapping = {cluster: [] for cluster in clusters}
        for topic, cluster in zip(topics, clusters):
            mapping[cluster].append(topic)
        mapping = [cluster for cluster in mapping.values()]
        sorted_topics = [topic for cluster in mapping for topic in cluster]

    # Select embeddings
    indices = np.array([topics.index(topic) for topic in sorted_topics])
    embeddings = embeddings[indices]
    distance_matrix = cosine_similarity(embeddings)

    # Create labels
    if isinstance(custom_labels, str):
        new_labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in sorted_topics]
        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]
    elif topic_model.custom_labels_ is not None and custom_labels:
        new_labels = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in sorted_topics]
    else:
        new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics]
        new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
        new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]

    fig = px.imshow(distance_matrix,
                    labels=dict(color="Similarity Score"),
                    x=new_labels,
                    y=new_labels,
                    color_continuous_scale='GnBu'
                    )

    fig.update_layout(
        title={
            'text': f"{title}",
            'y': .95,
            'x': 0.55,
            'xanchor': 'center',
            'yanchor': 'top',
            'font': dict(
                size=22,
                color="Black")
        },
        width=width,
        height=height,
        hoverlabel=dict(
            bgcolor="white",
            font_size=16,
            font_family="Rockwell"
        ),
    )
    fig.update_layout(showlegend=True)
    fig.update_layout(legend_title_text='Trend')

    return fig