from datetime import datetime
import numpy as np
from loguru import logger
from sklearn.cluster import KMeans
from sklearn.preprocessing import OneHotEncoder
from ..preprocessor import c_tf_idf, extract_tfidf_topics
from ..utils.dataset import TMDataset
from .abstract_helper_models.base import BaseModel, TrainingStatus
from .abstract_helper_models.mixins import SentenceEncodingMixin
time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
MODEL_NAME = "KmeansTM"
EMBEDDING_MODEL_NAME = "paraphrase-MiniLM-L3-v2"
# logger.add(f"{MODEL_NAME}_{time}.log", backtrace=True, diagnose=True)
[docs]class KmeansTM(BaseModel, SentenceEncodingMixin):
"""
A topic modeling class that uses K-Means clustering on text data.
This class inherits from the BaseModel class and utilizes sentence embeddings,
UMAP for dimensionality reduction, and K-Means for clustering text data into topics.
Parameters
----------
embedding_model_name : str
Name of the sentence embedding model to use.
umap_args : dict
Arguments for UMAP dimensionality reduction.
kmeans_args : dict
Arguments for K-Means clustering.
embeddings_path : str
Path to the folder containing embeddings.
embeddings_file_path : str
Path to the file containing embeddings.
trained : bool
Flag indicating whether the model has been trained.
save_embeddings : bool
Whether to save generated embeddings.
n_topics : int or None
Number of topics to extract.
"""
def __init__(
self,
embedding_model_name: str = EMBEDDING_MODEL_NAME,
umap_args: dict = None,
kmeans_args: dict = None,
random_state: int = None,
embeddings_folder_path: str = None,
embeddings_file_path: str = None,
save_embeddings: bool = False,
**kwargs,
):
"""
Initialize the KmeansTM model.
Parameters
----------
num_topics : int, optional
Number of topics to extract, by default 20
embedding_model_name : str, optional
Name of the sentence embedding model to use, by default "paraphrase-MiniLM-L3-v2"
umap_args : dict, optional
Arguments for UMAP dimensionality reduction, by default {}
kmeans_args : dict, optional
Arguments for K-Means clustering, by default {}
random_state : int, optional
Random state for UMAP, by default None
embeddings_folder_path : str, optional
Path to folder to save embeddings, by default None
embeddings_file_path : str, optional
Path to specific embeddings file, by default None
**kwargs
Additional keyword arguments passed to the superclass.
"""
super().__init__(use_pretrained_embeddings=True, **kwargs)
self.save_hyperparameters(
ignore=[
"embeddings_file_path",
"embeddings_folder_path",
"random_state",
"save_embeddings",
]
)
self.embedding_model_name = self.hparams.get(
"embedding_model_name", embedding_model_name
)
self.umap_args = self.hparams.get(
"umap_args",
umap_args
or {
"n_neighbors": 15,
"n_components": 15,
"metric": "cosine",
},
)
self.kmeans_args = self.hparams.get("kmeans_args", kmeans_args or {})
if random_state is not None:
self.umap_args["random_state"] = random_state
self.embeddings_path = embeddings_folder_path
self.embeddings_file_path = embeddings_file_path
self.save_embeddings = save_embeddings
self.n_topics = None
self._status = TrainingStatus.NOT_STARTED
[docs] def get_info(self):
"""
Get information about the model.
Returns
-------
dict
Dictionary containing model information including model name,
number of topics, embedding model name, UMAP arguments,
K-Means arguments, and training status.
"""
info = {
"model_name": MODEL_NAME,
"num_topics": self.n_topics,
"embedding_model": self.embedding_model_name,
"umap_args": self.umap_args,
"kmeans_args": self.kmeans_args,
"trained": self._status.name,
}
return info
def _prepare_embeddings(self, dataset):
"""
Prepares the dataset for clustering.
Parameters
----------
dataset : Dataset
The dataset to be used for clustering.
"""
if dataset.has_embeddings(self.embedding_model_name):
logger.info(
f"--- Loading precomputed {EMBEDDING_MODEL_NAME} embeddings ---"
)
self.embeddings = dataset.get_embeddings(
self.embedding_model_name,
self.embeddings_path,
self.embeddings_file_path,
)
self.dataframe = dataset.dataframe
else:
logger.info(
f"--- Creating {EMBEDDING_MODEL_NAME} document embeddings ---")
self.embeddings = self.encode_documents(
dataset.texts, encoder_model=self.embedding_model_name, use_average=True
)
if self.save_embeddings:
dataset.save_embeddings(
self.embeddings,
self.embedding_model_name,
self.embeddings_path,
self.embeddings_file_path,
)
self.dataframe = dataset.dataframe
def _clustering(self):
"""
Applies K-Means clustering to the reduced embeddings.
Raises
------
ValueError
If an error occurs during clustering.
"""
assert (
hasattr(
self, "reduced_embeddings") and self.reduced_embeddings is not None
), "Reduced embeddings must be generated before clustering."
try:
logger.info("--- Creating document cluster ---")
self.clustering_model = KMeans(
n_clusters=self.n_topics, **self.kmeans_args)
self.clustering_model.fit(self.reduced_embeddings)
self.labels = self.clustering_model.labels_
labels = np.array(self.labels)
self.topic_centroids = []
for label in np.unique(labels):
label_embeddings = self.embeddings[labels == label]
mean_embedding = np.mean(label_embeddings, axis=0)
self.topic_centroids.append(mean_embedding)
except Exception as e:
raise RuntimeError(f"Error in clustering: {e}") from e
[docs] def fit(
self,
dataset: TMDataset = None,
n_topics: int = 20,
):
"""
Trains the K-Means topic model on the provided dataset.
Parameters
----------
dataset : Dataset
The dataset to train the model on.
n_topics : int, optional
Number of topics to extract, by default 20
Raises
------
AssertionError
If the dataset is not an instance of TMDataset.
"""
assert isinstance(
dataset, TMDataset
), "The dataset must be an instance of TMDataset."
self.n_topics = n_topics
if self.n_topics <= 0:
raise ValueError("Number of topics must be greater than 0.")
self._status = TrainingStatus.INITIALIZED
try:
logger.info(f"--- Training {MODEL_NAME} topic model ---")
self._status = TrainingStatus.RUNNING
self.dataframe, self.embeddings = self.prepare_embeddings(
dataset, logger)
self.reduced_embeddings = self.dim_reduction(logger)
self._clustering()
self.dataframe["predictions"] = self.labels
docs_per_topic = self.dataframe.groupby(
["predictions"], as_index=False
).agg({"text": " ".join})
tfidf, count = c_tf_idf(
docs_per_topic["text"].values, m=len(self.dataframe)
)
self.topic_dict = extract_tfidf_topics(
tfidf, count, docs_per_topic, n=100)
one_hot_encoder = OneHotEncoder(sparse=False)
predictions_one_hot = one_hot_encoder.fit_transform(
self.dataframe[["predictions"]]
)
self.beta = tfidf
self.theta = predictions_one_hot
except Exception as e:
logger.error(f"Error in training: {e}")
self._status = TrainingStatus.FAILED
raise
except KeyboardInterrupt:
logger.error("Training interrupted.")
self._status = TrainingStatus.INTERRUPTED
raise
logger.info("--- Training completed successfully. ---")
self._status = TrainingStatus.SUCCEEDED
[docs] def predict(self, texts):
"""
Predict topics for new documents.
Parameters
----------
texts : list of str
List of texts to predict topics for.
Returns
-------
list of int
List of predicted topic labels.
Raises
------
ValueError
If the model has not been trained yet.
"""
if self._status != TrainingStatus.SUCCEEDED:
raise RuntimeError("Model has not been trained yet or failed.")
embeddings = self.encode_documents(
texts, encoder_model=self.embedding_model_name, use_average=True
)
reduced_embeddings = self.reducer.transform(embeddings)
labels = self.clustering_model.predict(reduced_embeddings)
return labels