Skip to content

ZeroShotClassification

Bases: BaseRepresentation

Zero-shot Classification on topic keywords with candidate labels

Parameters:

Name Type Description Default
candidate_topics List[str]

A list of labels to assign to the topics if their classification score exceeds min_prob

required
model str

Either the name of a Hugging Face model (a string), which is loaded into a "zero-shot-classification" pipeline, or an already initialized transformers pipeline. For example, pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

'facebook/bart-large-mnli'
pipeline_kwargs Mapping[str, Any]

Kwargs that you can pass to the transformers.pipeline when it is called. NOTE: Use {"multi_label": True} to extract multiple labels for each topic.

{}
min_prob float

The minimum probability to assign a candidate label to a topic

0.8

Usage:

from bertopic.representation import ZeroShotClassification
from bertopic import BERTopic

# Create your representation model
candidate_topics = ["space and nasa", "bicycles", "sports"]
representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
Source code in bertopic\representation\_zeroshot.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class ZeroShotClassification(BaseRepresentation):
    """ Zero-shot Classification on topic keywords with candidate labels

    Each topic is described by its keywords joined into a single string. That
    string is classified against `candidate_topics`, and labels whose score
    exceeds `min_prob` replace the keyword representation of the topic.

    Arguments:
        candidate_topics: A list of labels to assign to the topics if their
                          classification score exceeds `min_prob`
        model: Either the name of a Hugging Face model (a string), which is
               loaded into a "zero-shot-classification" pipeline, or an
               already initialized transformers pipeline. For example,
               `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. NOTE: Use `{"multi_label": True}`
                         to extract multiple labels for each topic.
        min_prob: The minimum probability to assign a candidate label to a topic

    Usage:

    ```python
    from bertopic.representation import ZeroShotClassification
    from bertopic import BERTopic

    # Create your representation model
    candidate_topics = ["space and nasa", "bicycles", "sports"]
    representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```
    """
    def __init__(self,
                 candidate_topics: List[str],
                 model: str = "facebook/bart-large-mnli",
                 pipeline_kwargs: Mapping[str, Any] = {},
                 min_prob: float = 0.8
                 ):
        self.candidate_topics = candidate_topics
        if isinstance(model, str):
            self.model = pipeline("zero-shot-classification", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            # Trailing spaces matter: adjacent string literals are concatenated as-is
            raise ValueError("Make sure that the HF model that you "
                             "pass is either a string referring to a "
                             "HF model or a `transformers.pipeline` object.")
        # Only ever read in this class (`.get` and `**` unpacking), never
        # mutated, so the shared default dict is not modified from here
        self.pipeline_kwargs = pipeline_kwargs
        self.min_prob = min_prob

    def extract_topics(self,
                       topic_model,
                       documents: pd.DataFrame,
                       c_tf_idf: csr_matrix,
                       topics: Mapping[str, List[Tuple[str, float]]]
                       ) -> Mapping[str, List[Tuple[str, float]]]:
        """ Extract topics

        Arguments:
            topic_model: Not used
            documents: Not used
            c_tf_idf: Not used
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. The lists in
                            `topics` are left unmodified; an empty `topics`
                            yields an empty dictionary.
        """
        # Nothing to classify; avoid calling the model with no sequences
        if not topics:
            return {}

        # Classify topics: every topic is described by its keywords joined with spaces
        topic_descriptions = [" ".join(entry[0] for entry in topics[topic]) for topic in topics]
        classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)

        # NOTE(review): defensively handles a pipeline that hands back a bare
        # dict instead of a list for a single input -- confirm against the
        # installed transformers version
        if isinstance(classifications, dict):
            classifications = [classifications]

        n_items = 10
        multi_label = self.pipeline_kwargs.get("multi_label")

        # Extract labels
        updated_topics = {}
        for topic, classification in zip(topics, classifications):
            # Work on a copy so that padding never alters the caller's lists
            keywords = list(topics[topic])
            labels, scores = classification["labels"], classification["scores"]

            # Multi-label assignment: keep every label above the threshold
            if multi_label:
                topic_description = [(label, score)
                                     for label, score in zip(labels, scores)
                                     if score > self.min_prob]

            # Single label assignment: only the first returned label is considered
            elif scores[0] > self.min_prob:
                topic_description = [(labels[0], scores[0])]

            # No label is confident enough: keep the keyword representation
            else:
                topic_description = keywords

            # Make sure that 10 items are returned
            if not topic_description:
                topic_description = keywords
            elif len(topic_description) < n_items:
                # Build a new list rather than extending in place
                topic_description = topic_description + [("", 0)] * (n_items - len(topic_description))
            updated_topics[topic] = topic_description

        return updated_topics

extract_topics(topic_model, documents, c_tf_idf, topics)

Extract topics

Parameters:

Name Type Description Default
topic_model

Not used

required
documents DataFrame

Not used

required
c_tf_idf csr_matrix

Not used

required
topics Mapping[str, List[Tuple[str, float]]]

The candidate topics as calculated with c-TF-IDF

required

Returns:

Name Type Description
updated_topics Mapping[str, List[Tuple[str, float]]]

Updated topic representations

Source code in bertopic\representation\_zeroshot.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def extract_topics(self,
                   topic_model,
                   documents: pd.DataFrame,
                   c_tf_idf: csr_matrix,
                   topics: Mapping[str, List[Tuple[str, float]]]
                   ) -> Mapping[str, List[Tuple[str, float]]]:
    """ Extract topics

    Each topic's keywords are joined into one string and classified against
    `self.candidate_topics`. Labels whose score exceeds `self.min_prob`
    replace the keyword representation; otherwise the keywords are kept.

    Arguments:
        topic_model: Not used
        documents: Not used
        c_tf_idf: Not used
        topics: The candidate topics as calculated with c-TF-IDF

    Returns:
        updated_topics: Updated topic representations. The lists in
                        `topics` are left unmodified; an empty `topics`
                        yields an empty dictionary.
    """
    # Nothing to classify; avoid calling the model with no sequences
    if not topics:
        return {}

    # Classify topics: every topic is described by its keywords joined with spaces
    topic_descriptions = [" ".join(entry[0] for entry in topics[topic]) for topic in topics]
    classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)

    # NOTE(review): defensively handles a pipeline that hands back a bare
    # dict instead of a list for a single input -- confirm against the
    # installed transformers version
    if isinstance(classifications, dict):
        classifications = [classifications]

    n_items = 10
    multi_label = self.pipeline_kwargs.get("multi_label")

    # Extract labels
    updated_topics = {}
    for topic, classification in zip(topics, classifications):
        # Work on a copy so that padding never alters the caller's lists
        keywords = list(topics[topic])
        labels, scores = classification["labels"], classification["scores"]

        # Multi-label assignment: keep every label above the threshold
        if multi_label:
            topic_description = [(label, score)
                                 for label, score in zip(labels, scores)
                                 if score > self.min_prob]

        # Single label assignment: only the first returned label is considered
        elif scores[0] > self.min_prob:
            topic_description = [(labels[0], scores[0])]

        # No label is confident enough: keep the keyword representation
        else:
            topic_description = keywords

        # Make sure that 10 items are returned
        if not topic_description:
            topic_description = keywords
        elif len(topic_description) < n_items:
            # Build a new list rather than extending in place
            topic_description = topic_description + [("", 0)] * (n_items - len(topic_description))
        updated_topics[topic] = topic_description

    return updated_topics