Source code for stream.models.CEDC

from datetime import datetime

import numpy as np
import pandas as pd
from loguru import logger
from sentence_transformers import SentenceTransformer
from sklearn.mixture import GaussianMixture

from ..preprocessor import clean_topics
from ..preprocessor.topic_extraction import TopicExtractor
from ..utils.check_dataset_steps import check_dataset_steps
from ..utils.dataset import TMDataset
from .abstract_helper_models.base import BaseModel, TrainingStatus
from .abstract_helper_models.mixins import SentenceEncodingMixin

DATADIR = "../datasets/preprocessed_datasets"
MODEL_NAME = "CEDC"
EMBEDDING_MODEL_NAME = "paraphrase-MiniLM-L3-v2"
time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# logger.add(f"{MODEL_NAME}_{time}.log", backtrace=True, diagnose=True)


[docs]class CEDC(BaseModel, SentenceEncodingMixin): """ Class for Clustering-based Embedding-driven Document Clustering (CEDC). Inherits from BaseModel and SentenceEncodingMixin. Parameters ---------- n_topics : int or None Number of topics to extract. embedding_model_name : str Name of the embedding model to use. umap_args : dict Arguments for UMAP dimensionality reduction. gmm_args : dict Arguments for Gaussian Mixture Model (GMM) clustering. embeddings_path : str Path to the folder containing embeddings. embeddings_file_path : str Path to the file containing embeddings. trained : bool Flag indicating whether the model has been trained. save_embeddings : bool Whether to save generated embeddings. """ def __init__( self, embedding_model_name: str = EMBEDDING_MODEL_NAME, umap_args: dict = None, random_state: int = None, gmm_args: dict = None, embeddings_folder_path: str = None, embeddings_file_path: str = None, save_embeddings: bool = False, **kwargs, ): """ Initializes the CEDC model. Parameters ---------- embedding_model_name : str, optional Name of the embedding model (default is "paraphrase-MiniLM-L3-v2"). umap_args : dict, optional Arguments for UMAP dimensionality reduction. random_state : int, optional Random state for reproducibility. gmm_args : dict, optional Arguments for Gaussian Mixture Model (GMM) clustering. embeddings_folder_path : str, optional Path to the folder containing embeddings. embeddings_file_path : str, optional Path to the file containing embeddings. save_embeddings : bool, optional Whether to save generated embeddings. **kwargs Additional keyword arguments passed to super().__init__(). """ super().__init__(use_pretrained_embeddings=True, **kwargs) self.save_hyperparameters( ignore=[ "embeddings_file_path", "embeddings_folder_path", "random_state", "save_embeddings", ] ) # Initialize hyperparameters from self.hparams self.n_topics = None self.embedding_model_name = self.hparams.get( "embedding_model_name", embedding_model_name ) self.umap_args = self.hparams.get( "umap_args", umap_args or { "n_neighbors": 15, "n_components": 15, "metric": "cosine", }, ) self.gmm_args = self.hparams.get( "gmm_args", gmm_args or { "n_components": None, "covariance_type": "full", "tol": 0.001, "reg_covar": 0.000001, "max_iter": 100, "n_init": 1, "init_params": "kmeans", }, ) if random_state is not None: self.umap_args["random_state"] = random_state self.embeddings_path = embeddings_folder_path self.embeddings_file_path = embeddings_file_path self.save_embeddings = save_embeddings self._status = TrainingStatus.NOT_STARTED
[docs] def get_info(self): """ Get information about the model. Returns ------- dict Dictionary containing model information including model name, number of topics, embedding model name, UMAP arguments, K-Means arguments, and training status. """ info = { "model_name": MODEL_NAME, "num_topics": self.n_topics, "embedding_model": self.embedding_model_name, "umap_args": self.umap_args, "kmeans_args": self.gmm_args, "trained": self._status.name, } return info
def _clustering(self): """ Applies GMM clustering to the reduced embeddings. Raises ------ ValueError If an error occurs during clustering. """ assert ( hasattr( self, "reduced_embeddings") and self.reduced_embeddings is not None ), "Reduced embeddings must be generated before clustering." self.gmm_args["n_components"] = self.n_topics try: logger.info("--- Creating document cluster ---") self.GMM = GaussianMixture( **self.gmm_args, ).fit(self.reduced_embeddings) gmm_predictions = self.GMM.predict_proba(self.reduced_embeddings) self.soft_labels = pd.DataFrame(gmm_predictions) self.labels = self.GMM.predict(self.reduced_embeddings) except Exception as e: raise RuntimeError(f"Error in clustering: {e}") from e
[docs] def fit( self, dataset: TMDataset = None, n_topics: int = 20, only_nouns: bool = False, clean: bool = False, clean_threshold: float = 0.85, expansion_corpus: str = "octis", n_words: int = 20, ): """ Trains the CEDC model on the provided dataset. Parameters ---------- dataset : Dataset Dataset containing texts to cluster. n_topics : int, optional Number of topics to extract (default is 20). only_nouns : bool, optional Whether to consider only nouns during topic extraction (default is False). clean : bool, optional Whether to clean topics based on similarity (default is False). clean_threshold : float, optional Threshold for cleaning topics based on similarity (default is 0.85). expansion_corpus : str, optional Corpus for expanding topics (default is 'octis'). n_words : int, optional Number of top words to include in each topic (default is 20). Returns ------- None """ assert isinstance( dataset, TMDataset ), "The dataset must be an instance of TMDataset." check_dataset_steps(dataset, logger, MODEL_NAME) self.n_topics = n_topics if self.n_topics <= 0: raise ValueError("Number of topics must be greater than 0.") self._status = TrainingStatus.INITIALIZED try: logger.info(f"--- Training {MODEL_NAME} topic model ---") self._status = TrainingStatus.RUNNING self.dataframe, self.embeddings = self.prepare_embeddings( dataset, logger) self.reduced_embeddings = self.dim_reduction(logger) self._clustering() assert ( hasattr(self, "soft_labels") and self.soft_labels is not None ), "Clustering must generate labels." TE = TopicExtractor( dataset=dataset, topic_assignments=self.soft_labels, n_topics=self.n_topics, embedding_model=SentenceTransformer(self.embedding_model_name), ) logger.info("--- Extract topics ---") topics, self.topic_centroids = TE._noun_extractor_haystack( self.embeddings, n=n_words + 20, corpus=expansion_corpus, only_nouns=only_nouns, ) except Exception as e: logger.error(f"Error in training: {e}") self._status = TrainingStatus.FAILED raise except KeyboardInterrupt: logger.error("Training interrupted.") self._status = TrainingStatus.INTERRUPTED raise if clean: logger.info("--- Cleaning topics ---") cleaned_topics, cleaned_centroids = clean_topics( topics, similarity=clean_threshold, embedding_model=self.embedding_model ) topics = cleaned_topics self.topic_centroids = cleaned_centroids logger.info("--- Training completed successfully. ---") self._status = TrainingStatus.SUCCEEDED self.topic_dict = topics self.theta = np.array(self.soft_labels) self.beta = self.get_beta()
[docs] def predict(self, texts, proba=True): """ Predict topics for new documents. Parameters ---------- texts : list of str List of texts to predict topics for. Returns ------- list of int List of predicted topic labels. Raises ------ ValueError If the model has not been trained yet. """ if self._status != TrainingStatus.SUCCEEDED: raise RuntimeError("Model has not been trained yet or failed.") embeddings = self.encode_documents( texts, encoder_model=self.embedding_model_name, use_average=True ) reduced_embeddings = self.reducer.transform(embeddings) if proba: labels = self.GMM.predict_proba(reduced_embeddings) else: labels = self.GMM.predict(reduced_embeddings) return labels
[docs] def get_beta(self): """ Constructs a topic-word matrix from the given topic dictionary. Parameters ---------- topic_dict : dict Dictionary where keys are topic indices and values are lists of (word, prevalence) tuples. Returns ------- ndarray Topic-word matrix where rows represent topics and columns represent words. Notes ----- The topic-word matrix is constructed by assigning prevalences of words in topics. Words are sorted alphabetically across all topics. Raises ------ ValueError If the model has not been trained yet. """ if self._status != TrainingStatus.SUCCEEDED: raise RuntimeError("Model has not been trained yet or failed.") assert hasattr(self, "topic_dict"), "Model has no topic_dict." # Extract all unique words and sort them all_words = set(word for topic in self.topic_dict.values() for word, _ in topic) sorted_words = sorted(all_words) # Create an empty DataFrame with sorted words as rows and topics as columns topic_word_matrix = pd.DataFrame( index=sorted_words, columns=sorted(self.topic_dict.keys()), data=0.0 ) # Populate the DataFrame with prevalences for topic, words in self.topic_dict.items(): for word, prevalence in words: if word in topic_word_matrix.index: topic_word_matrix.at[word, topic] = prevalence self.beta = np.array(topic_word_matrix) return self.beta