#!/usr/bin/env python # -*- coding: utf-8 -*- # # Copyright (C) 2011 Radim Rehurek # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html """ Automated tests for probability estimation algorithms in the probability_estimation module. """ import logging import unittest from gensim.corpora.dictionary import Dictionary from gensim.corpora.hashdictionary import HashDictionary from gensim.topic_coherence import probability_estimation class BaseTestCases: class ProbabilityEstimationBase(unittest.TestCase): texts = [ ['human', 'interface', 'computer'], ['eps', 'user', 'interface', 'system'], ['system', 'human', 'system', 'eps'], ['user', 'response', 'time'], ['trees'], ['graph', 'trees'] ] dictionary = None def build_segmented_topics(self): # Suppose the segmented topics from s_one_pre are: token2id = self.dictionary.token2id computer_id = token2id['computer'] system_id = token2id['system'] user_id = token2id['user'] graph_id = token2id['graph'] self.segmented_topics = [ [ (system_id, graph_id), (computer_id, graph_id), (computer_id, system_id) ], [ (computer_id, graph_id), (user_id, graph_id), (user_id, computer_id) ] ] self.computer_id = computer_id self.system_id = system_id self.user_id = user_id self.graph_id = graph_id def setup_dictionary(self): raise NotImplementedError def setUp(self): self.setup_dictionary() self.corpus = [self.dictionary.doc2bow(text) for text in self.texts] self.build_segmented_topics() def test_p_boolean_document(self): """Test p_boolean_document()""" accumulator = probability_estimation.p_boolean_document( self.corpus, self.segmented_topics) obtained = accumulator.index_to_dict() expected = { self.graph_id: {5}, self.user_id: {1, 3}, self.system_id: {1, 2}, self.computer_id: {0} } self.assertEqual(expected, obtained) def test_p_boolean_sliding_window(self): """Test p_boolean_sliding_window()""" # Test with window size as 2. window_id is zero indexed. accumulator = probability_estimation.p_boolean_sliding_window( self.texts, self.segmented_topics, self.dictionary, 2) self.assertEqual(1, accumulator[self.computer_id]) self.assertEqual(3, accumulator[self.user_id]) self.assertEqual(1, accumulator[self.graph_id]) self.assertEqual(4, accumulator[self.system_id]) class TestProbabilityEstimation(BaseTestCases.ProbabilityEstimationBase): def setup_dictionary(self): self.dictionary = HashDictionary(self.texts) class TestProbabilityEstimationWithNormalDictionary(BaseTestCases.ProbabilityEstimationBase): def setup_dictionary(self): self.dictionary = Dictionary(self.texts) self.dictionary.id2token = {v: k for k, v in self.dictionary.token2id.items()} if __name__ == '__main__': logging.root.setLevel(logging.WARNING) unittest.main()