ZeroShotClassification
Zero-shot Classification on topic keywords with candidate labels.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `candidate_topics` | `List[str]` | A list of labels to assign to the topics if they exceed `min_prob` | *required* |
| `model` | `str` | Either the name of a Hugging Face model or a transformers pipeline that should be initialized as "zero-shot-classification". For example, `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")` | `'facebook/bart-large-mnli'` |
| `pipeline_kwargs` | `Mapping[str, Any]` | Kwargs that you can pass to the transformers.pipeline when it is called. NOTE: Use `{"multi_label": True}` to extract multiple labels for each topic. | `{}` |
| `min_prob` | `float` | The minimum probability to assign a candidate label to a topic | `0.8` |
Usage:

```python
from bertopic.representation import ZeroShotClassification
from bertopic import BERTopic

# Create your representation model
candidate_topics = ["space and nasa", "bicycles", "sports"]
representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
Source code in bertopic\representation\_zeroshot.py
class ZeroShotClassification(BaseRepresentation):
    """Zero-shot Classification on topic keywords with candidate labels.

    Arguments:
        candidate_topics: A list of labels to assign to the topics if they
                          exceed `min_prob`
        model: Either the name of a HF model (a pipeline is then created with
               task "zero-shot-classification") or an already initialized
               `transformers` pipeline. For example,
               `pipeline("zero-shot-classification", model="facebook/bart-large-mnli")`
        pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
                         when it is called. NOTE: Use `{"multi_label": True}`
                         to extract multiple labels for each topic.
                         Defaults to an empty dict.
        min_prob: The minimum probability to assign a candidate label to a topic;
                  a label is only assigned when its score is strictly greater
                  than this value.

    Raises:
        ValueError: If `model` is neither a string nor a `transformers` Pipeline.

    Usage:

    ```python
    from bertopic.representation import ZeroShotClassification
    from bertopic import BERTopic

    # Create your representation model
    candidate_topics = ["space and nasa", "bicycles", "sports"]
    representation_model = ZeroShotClassification(candidate_topics, model="facebook/bart-large-mnli")

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```
    """

    def __init__(
        self,
        candidate_topics: List[str],
        model: str = "facebook/bart-large-mnli",
        pipeline_kwargs: Mapping[str, Any] = None,
        min_prob: float = 0.8,
    ):
        self.candidate_topics = candidate_topics
        if isinstance(model, str):
            self.model = pipeline("zero-shot-classification", model=model)
        elif isinstance(model, Pipeline):
            self.model = model
        else:
            raise ValueError(
                "Make sure that the HF model that you "
                "pass is either a string referring to a "
                "HF model or a `transformers.pipeline` object."
            )
        # `None` default instead of `{}` so instances never share one mutable dict.
        self.pipeline_kwargs = {} if pipeline_kwargs is None else pipeline_kwargs
        self.min_prob = min_prob

    def extract_topics(
        self,
        topic_model,
        documents: pd.DataFrame,
        c_tf_idf: csr_matrix,
        topics: Mapping[str, List[Tuple[str, float]]],
    ) -> Mapping[str, List[Tuple[str, float]]]:
        """Extract topics.

        The keywords of each topic are joined with spaces and classified
        against `self.candidate_topics` with the zero-shot pipeline.

        Arguments:
            topic_model: Not used
            documents: Not used
            c_tf_idf: Not used
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. A topic receives the
                candidate label(s) whose score exceeds `min_prob` (only the
                top-scoring label unless `pipeline_kwargs["multi_label"]` is
                truthy), padded with `("", 0)` up to 10 entries. If no label
                qualifies, the original keywords are kept; in single-label mode
                they are also padded to 10 entries. `topics` itself is not
                modified.
        """
        if not topics:
            # Nothing to classify; avoid calling the pipeline with an empty batch.
            return {}

        # Classify topics: one space-joined keyword string per topic
        topic_descriptions = [" ".join(entry[0] for entry in topics[topic]) for topic in topics]
        classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)
        # Defensive: some transformers versions may unwrap a one-item batch into a bare dict.
        if isinstance(classifications, dict):
            classifications = [classifications]

        # Extract labels
        updated_topics = {}
        for topic, classification in zip(topics, classifications):
            topic_description = topics[topic]

            # Multi-label assignment
            if self.pipeline_kwargs.get("multi_label"):
                topic_description = [
                    (label, score)
                    for label, score in zip(classification["labels"], classification["scores"])
                    if score > self.min_prob
                ]

            # Single label assignment
            elif classification["scores"][0] > self.min_prob:
                topic_description = [(classification["labels"][0], classification["scores"][0])]

            # Make sure that 10 items are returned
            if len(topic_description) == 0:
                topic_description = topics[topic]
            elif len(topic_description) < 10:
                # Build a new list: `topic_description` may still be the caller's
                # `topics[topic]`, which an in-place `+=` would have padded.
                topic_description = topic_description + [("", 0) for _ in range(10 - len(topic_description))]
            updated_topics[topic] = topic_description

        return updated_topics
extract_topics(self, topic_model, documents, c_tf_idf, topics)
Extract topics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `topic_model` | | Not used | *required* |
| `documents` | `DataFrame` | Not used | *required* |
| `c_tf_idf` | `csr_matrix` | Not used | *required* |
| `topics` | `Mapping[str, List[Tuple[str, float]]]` | The candidate topics as calculated with c-TF-IDF | *required* |

Returns:

| Type | Description |
|---|---|
| `Mapping[str, List[Tuple[str, float]]]` | `updated_topics`: Updated topic representations |
Source code in bertopic\representation\_zeroshot.py
def extract_topics(
    self,
    topic_model,
    documents: pd.DataFrame,
    c_tf_idf: csr_matrix,
    topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
    """Extract topics.

    The keywords of each topic are joined with spaces and classified
    against `self.candidate_topics` with the zero-shot pipeline `self.model`.

    Arguments:
        topic_model: Not used
        documents: Not used
        c_tf_idf: Not used
        topics: The candidate topics as calculated with c-TF-IDF

    Returns:
        updated_topics: Updated topic representations. A topic receives the
            candidate label(s) whose score exceeds `self.min_prob` (only the
            top-scoring label unless `self.pipeline_kwargs["multi_label"]` is
            truthy), padded with `("", 0)` up to 10 entries. If no label
            qualifies, the original keywords are kept; in single-label mode
            they are also padded to 10 entries. `topics` itself is not
            modified.
    """
    if not topics:
        # Nothing to classify; avoid calling the pipeline with an empty batch.
        return {}

    # Classify topics: one space-joined keyword string per topic
    topic_descriptions = [" ".join(entry[0] for entry in topics[topic]) for topic in topics]
    classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)
    # Defensive: some transformers versions may unwrap a one-item batch into a bare dict.
    if isinstance(classifications, dict):
        classifications = [classifications]

    # Extract labels
    updated_topics = {}
    for topic, classification in zip(topics, classifications):
        topic_description = topics[topic]

        # Multi-label assignment
        if self.pipeline_kwargs.get("multi_label"):
            topic_description = [
                (label, score)
                for label, score in zip(classification["labels"], classification["scores"])
                if score > self.min_prob
            ]

        # Single label assignment
        elif classification["scores"][0] > self.min_prob:
            topic_description = [(classification["labels"][0], classification["scores"][0])]

        # Make sure that 10 items are returned
        if len(topic_description) == 0:
            topic_description = topics[topic]
        elif len(topic_description) < 10:
            # Build a new list: `topic_description` may still be the caller's
            # `topics[topic]`, which an in-place `+=` would have padded.
            topic_description = topic_description + [("", 0) for _ in range(10 - len(topic_description))]
        updated_topics[topic] = topic_description

    return updated_topics