Cohere
¶
Use the Cohere API to generate topic labels based on their generative model.
Find more about their models here: https://docs.cohere.ai/docs
Parameters:
Name | Type | Description | Default |
---|---|---|---|
client |
A `cohere.Client` |
required | |
model |
str |
Model to use within Cohere, defaults to `"xlarge"`. |
'xlarge' |
prompt |
str |
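The prompt to be used in the model. If no prompt is given, `self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. |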
None |
delay_in_seconds |
float |
The delay in seconds between consecutive prompts in order to prevent RateLimitErrors. |
None |
nr_docs |
int |
The number of documents to pass to Cohere if a prompt
with the `"[DOCUMENTS]"` tag is used. |
4 |
diversity |
float |
The diversity of documents to pass to Cohere. Accepts values between 0 and 1. A higher value results in passing more diverse documents, whereas a lower value passes more similar documents. |
None |
doc_length |
int |
The maximum length of each document. If a document is longer, it will be truncated. If None, the entire document is passed. |
None |
tokenizer |
Union[str, Callable] |
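The tokenizer used to split the document into segments,
which are counted to determine the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and truncated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length` |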
None |
Usage:
To use this, you will need to install cohere first:
pip install cohere
Then, get yourself an API key and use Cohere's API as follows:
import cohere
from bertopic.representation import Cohere
from bertopic import BERTopic
# Create your representation model
co = cohere.Client(my_api_key)
representation_model = Cohere(co)
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
You can also use a custom prompt:
prompt = "I have the following documents: [DOCUMENTS]. What topic do they contain?"
representation_model = Cohere(co, prompt=prompt)
Source code in bertopic\representation\_cohere.py
class Cohere(BaseRepresentation):
    """Use the Cohere API to generate topic labels based on their
    generative model.

    Find more about their models here:
    https://docs.cohere.ai/docs

    Arguments:
        client: A `cohere.Client`
        model: Model to use within Cohere, defaults to `"xlarge"`.
        prompt: The prompt to be used in the model. If no prompt is given,
                `self.default_prompt_` is used instead.
                NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
                to decide where the keywords and documents need to be
                inserted.
        delay_in_seconds: The delay in seconds between consecutive prompts
                          in order to prevent RateLimitErrors.
        nr_docs: The number of documents to pass to Cohere if a prompt
                 with the `"[DOCUMENTS]"` tag is used.
        diversity: The diversity of documents to pass to Cohere.
                   Accepts values between 0 and 1. A higher
                   value results in passing more diverse documents
                   whereas a lower value passes more similar documents.
        doc_length: The maximum length of each document. If a document is longer,
                    it will be truncated. If None, the entire document is passed.
        tokenizer: The tokenizer used to split the document into segments,
                   which are counted to determine the length of a document.
                       * If tokenizer is 'char', then the document is split up
                         into characters which are counted to adhere to `doc_length`
                       * If tokenizer is 'whitespace', the document is split up
                         into words separated by whitespaces. These words are counted
                         and truncated depending on `doc_length`
                       * If tokenizer is 'vectorizer', then the internal CountVectorizer
                         is used to tokenize the document. These tokens are counted
                         and truncated depending on `doc_length`
                       * If tokenizer is a callable, then that callable is used to tokenize
                         the document. These tokens are counted and truncated depending
                         on `doc_length`

    Attributes:
        default_prompt_: The module-level `DEFAULT_PROMPT`.
        prompts_: Every prompt sent to the client, in call order. The list is
                  never cleared, so it grows across `extract_topics` calls.

    Usage:

    To use this, you will need to install cohere first:

    `pip install cohere`

    Then, get yourself an API key and use Cohere's API as follows:

    ```python
    import cohere
    from bertopic.representation import Cohere
    from bertopic import BERTopic

    # Create your representation model
    co = cohere.Client(my_api_key)
    representation_model = Cohere(co)

    # Use the representation model in BERTopic on top of the default pipeline
    topic_model = BERTopic(representation_model=representation_model)
    ```

    You can also use a custom prompt:

    ```python
    prompt = "I have the following documents: [DOCUMENTS]. What topic do they contain?"
    representation_model = Cohere(co, prompt=prompt)
    ```
    """

    def __init__(
        self,
        client,
        model: str = "xlarge",
        prompt: str = None,
        delay_in_seconds: float = None,
        nr_docs: int = 4,
        diversity: float = None,
        doc_length: int = None,
        tokenizer: Union[str, Callable] = None,
    ):
        self.client = client
        self.model = model
        self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
        self.default_prompt_ = DEFAULT_PROMPT
        self.delay_in_seconds = delay_in_seconds
        self.nr_docs = nr_docs
        self.diversity = diversity
        self.doc_length = doc_length
        self.tokenizer = tokenizer
        self.prompts_ = []

    def extract_topics(
        self,
        topic_model,
        documents: pd.DataFrame,
        c_tf_idf: csr_matrix,
        topics: Mapping[str, List[Tuple[str, float]]],
    ) -> Mapping[str, List[Tuple[str, float]]]:
        """Generate one label per topic with Cohere's generative model.

        Arguments:
            topic_model: The BERTopic model. Its `_extract_representative_docs`
                         method selects the documents for each prompt, it is
                         handed to `truncate_document`, and its `verbose` flag
                         toggles the progress bar.
            documents: Forwarded to `topic_model._extract_representative_docs`
            c_tf_idf: Forwarded to `topic_model._extract_representative_docs`
            topics: The candidate topics as calculated with c-TF-IDF

        Returns:
            updated_topics: Updated topic representations. Each topic maps to
                            ten tuples: the generated label with weight 1,
                            followed by nine `("", 0)` placeholders.
        """
        # Extract `self.nr_docs` representative documents per topic.
        # NOTE(review): 500 is passed positionally; presumably the number of
        # candidate documents sampled per topic -- confirm against BERTopic.
        repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
            c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
        )

        # Generate using Cohere's Language Model
        updated_topics = {}
        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
            truncated_docs = [
                truncate_document(topic_model, self.doc_length, self.tokenizer, doc)
                for doc in docs
            ]
            prompt = self._create_prompt(truncated_docs, topic, topics)
            self.prompts_.append(prompt)

            # Throttle consecutive requests; skipped when the delay is None or 0
            if self.delay_in_seconds:
                time.sleep(self.delay_in_seconds)

            request = self.client.generate(
                model=self.model,
                prompt=prompt,
                max_tokens=50,
                num_generations=1,
                stop_sequences=["\n"],  # stop generating at the first newline
            )
            label = request.generations[0].text.strip()

            # One real label followed by nine empty placeholders (ten entries in total)
            updated_topics[topic] = [(label, 1)] + [("", 0)] * 9

        return updated_topics

    def _create_prompt(self, docs, topic, topics):
        """Fill the `"[KEYWORDS]"` and `"[DOCUMENTS]"` tags of `self.prompt`.

        Arguments:
            docs: The (truncated) documents to insert at `"[DOCUMENTS]"`
            topic: The key of the topic in `topics`
            topics: Mapping from topic to its `(keyword, weight)` tuples

        Returns:
            prompt: The prompt with every tag occurrence replaced. A tag that
                    does not occur in the prompt is simply left out.
        """
        # Take the first element of every entry; unlike `list(zip(*entries))[0]`
        # this yields an empty keyword string for a topic without keywords
        # instead of raising an IndexError.
        keywords = ", ".join(entry[0] for entry in topics[topic])

        # `str.replace` leaves the text untouched when the tag is absent, so
        # the default and custom prompts can share a single code path.
        prompt = self.prompt.replace("[KEYWORDS]", keywords)
        return self._replace_documents(prompt, docs)

    @staticmethod
    def _replace_documents(prompt, docs):
        """Replace `"[DOCUMENTS]"` in `prompt` with one `- <doc>` line per document."""
        to_replace = "".join(f"- {doc}\n" for doc in docs)
        return prompt.replace("[DOCUMENTS]", to_replace)
extract_topics(self, topic_model, documents, c_tf_idf, topics)
¶
Extract topics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
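topic_model |
The BERTopic model; used to extract the representative documents and to truncate them |
required | |
documents |
DataFrame |
Passed on to the extraction of representative documents |
required |
c_tf_idf |
csr_matrix |
Passed on to the extraction of representative documents |
required |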
topics |
Mapping[str, List[Tuple[str, float]]] |
The candidate topics as calculated with c-TF-IDF |
required |
Returns:
Type | Description |
---|---|
updated_topics |
Updated topic representations |
Source code in bertopic\representation\_cohere.py
def extract_topics(
    self,
    topic_model,
    documents: pd.DataFrame,
    c_tf_idf: csr_matrix,
    topics: Mapping[str, List[Tuple[str, float]]],
) -> Mapping[str, List[Tuple[str, float]]]:
    """Generate one label per topic with Cohere's generative model.

    Arguments:
        topic_model: The BERTopic model. Its `_extract_representative_docs`
                     method selects the documents for each prompt, it is
                     handed to `truncate_document`, and its `verbose` flag
                     toggles the progress bar.
        documents: Forwarded to `topic_model._extract_representative_docs`
        c_tf_idf: Forwarded to `topic_model._extract_representative_docs`
        topics: The candidate topics as calculated with c-TF-IDF

    Returns:
        updated_topics: Updated topic representations. Each topic maps to
                        ten tuples: the generated label with weight 1,
                        followed by nine `("", 0)` placeholders.
    """
    # Extract `self.nr_docs` representative documents per topic.
    # NOTE(review): 500 is passed positionally; presumably the number of
    # candidate documents sampled per topic -- confirm against BERTopic.
    repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
        c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
    )

    # Generate using Cohere's Language Model
    updated_topics = {}
    for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
        truncated_docs = [
            truncate_document(topic_model, self.doc_length, self.tokenizer, doc)
            for doc in docs
        ]
        prompt = self._create_prompt(truncated_docs, topic, topics)
        # Every prompt is recorded; the list is not reset between calls
        self.prompts_.append(prompt)

        # Throttle consecutive requests; skipped when the delay is None or 0
        if self.delay_in_seconds:
            time.sleep(self.delay_in_seconds)

        request = self.client.generate(
            model=self.model,
            prompt=prompt,
            max_tokens=50,
            num_generations=1,
            stop_sequences=["\n"],  # stop generating at the first newline
        )
        label = request.generations[0].text.strip()

        # One real label followed by nine empty placeholders (ten entries in total)
        updated_topics[topic] = [(label, 1)] + [("", 0)] * 9

    return updated_topics