"""Tests for the assertion helpers exposed by ``sklearn.utils.testing``."""
import warnings
import unittest
import sys

from nose.tools import assert_raises

from sklearn.utils.testing import (
    _assert_less,
    _assert_greater,
    assert_less_equal,
    assert_greater_equal,
    assert_warns,
    assert_no_warnings,
    assert_equal,
    set_random_state,
    assert_raise_message)

from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# The nose versions of assert_less / assert_greater may be unavailable; the
# comparison tests below are only defined when the import succeeds.
try:
    from nose.tools import assert_less

    def test_assert_less():
        # Check that the nose implementation of assert_less gives the
        # same thing as the scikit's
        assert_less(0, 1)
        _assert_less(0, 1)
        assert_raises(AssertionError, assert_less, 1, 0)
        assert_raises(AssertionError, _assert_less, 1, 0)

except ImportError:
    pass

try:
    from nose.tools import assert_greater

    def test_assert_greater():
        # Check that the nose implementation of assert_greater gives the
        # same thing as the scikit's
        assert_greater(1, 0)
        _assert_greater(1, 0)
        assert_raises(AssertionError, assert_greater, 0, 1)
        assert_raises(AssertionError, _assert_greater, 0, 1)

except ImportError:
    pass


def test_assert_less_equal():
    # Passes for strictly-less and equal operands, fails otherwise.
    assert_less_equal(0, 1)
    assert_less_equal(1, 1)
    assert_raises(AssertionError, assert_less_equal, 1, 0)


def test_assert_greater_equal():
    # Passes for strictly-greater and equal operands, fails otherwise.
    assert_greater_equal(1, 0)
    assert_greater_equal(1, 1)
    assert_raises(AssertionError, assert_greater_equal, 0, 1)


def test_set_random_state():
    lda = LinearDiscriminantAnalysis()
    tree = DecisionTreeClassifier()
    # Linear Discriminant Analysis doesn't have random state: smoke test
    set_random_state(lda, 3)
    set_random_state(tree, 3)
    assert_equal(tree.random_state, 3)


def test_assert_raise_message():
    def _raise_ValueError(message):
        raise ValueError(message)

    def _no_raise():
        pass

    # Matching exception type and message: passes.
    assert_raise_message(ValueError, "test",
                         _raise_ValueError, "test")

    # Right exception type, wrong message: AssertionError.
    assert_raises(AssertionError,
                  assert_raise_message, ValueError, "something else",
                  _raise_ValueError, "test")

    # Unexpected exception type: the original ValueError propagates.
    assert_raises(ValueError,
                  assert_raise_message, TypeError, "something else",
                  _raise_ValueError, "test")

    # Nothing raised at all: AssertionError.
    assert_raises(AssertionError,
                  assert_raise_message, ValueError, "test",
                  _no_raise)

    # multiple exceptions in a tuple
    assert_raises(AssertionError,
                  assert_raise_message, (ValueError, AttributeError),
                  "test", _no_raise)


# This class is inspired from numpy 1.7 with an alteration to check
# the reset warning filters after calls to assert_warns.
# This assert_warns behavior is specific to scikit-learn because
# `clean_warning_registry()` is called internally by assert_warns
# and clears all previous filters.
class TestWarns(unittest.TestCase):
    def test_warn(self):
        def f():
            warnings.warn("yo")
            return 3

        # catch_warnings() snapshots the process-wide warning filters and
        # restores them on exit, so neither the "ignore" filter installed
        # below nor the filter reset done by assert_warns leaks into tests
        # that run after this one.
        with warnings.catch_warnings():
            # Test that assert_warns is not impacted by externally set
            # filters and is reset internally.
            # This is because `clean_warning_registry()` is called
            # internally by assert_warns and clears all previous filters.
            warnings.simplefilter("ignore", UserWarning)
            assert_equal(assert_warns(UserWarning, f), 3)

            # Test that the warning registry is empty after assert_warns
            assert_equal(sys.modules['warnings'].filters, [])

            assert_raises(AssertionError, assert_no_warnings, f)
            assert_equal(assert_no_warnings(lambda x: x, 1), 1)

    def test_warn_wrong_warning(self):
        def f():
            warnings.warn("yo", DeprecationWarning)

        failed = False
        # Preserve the caller's warning filters around the assert_warns call.
        with warnings.catch_warnings():
            try:
                # Should raise an AssertionError
                assert_warns(UserWarning, f)
                failed = True
            except AssertionError:
                pass

        if failed:
            raise AssertionError("wrong warning caught by assert_warn")