1 | # Tests for rich comparisons
|
---|
2 |
|
---|
3 | import unittest
|
---|
4 | from test import test_support
|
---|
5 |
|
---|
6 | import operator
|
---|
7 |
|
---|
class Number:
    """Wrap a value so that every rich comparison defers to that value.

    Each rich-comparison method compares ``self.x`` directly with *other*
    and returns whatever that comparison produces.  ``__cmp__`` fails the
    test if it is ever invoked, so only the rich methods may be used.
    """

    def __init__(self, x):
        self.x = x

    def __repr__(self):
        return "Number(%r)" % (self.x, )

    def __cmp__(self, other):
        # Reaching the old three-way protocol means rich comparison was bypassed.
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    # Ordering comparisons.
    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    # Equality comparisons.
    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other
36 |
|
---|
class Vector:
    """A sequence whose rich comparisons are applied element by element.

    Comparing a Vector with another Vector, or with a plain sequence of the
    same length, returns a new Vector holding the per-item results instead
    of a single bool.  A length mismatch raises ValueError.  Vectors are
    unhashable, raise TypeError in a Boolean context, and fail the test if
    ``__cmp__`` is ever reached.
    """

    # Vectors cannot be hashed
    __hash__ = None

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "Vector(%r)" % (self.data, )

    # Sequence protocol: delegate straight to the wrapped data.
    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __nonzero__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    def __cast(self, other):
        """Return *other* as a plain sequence, unwrapping a Vector.

        Raises ValueError when its length differs from this vector's.
        """
        seq = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(seq):
            raise ValueError("Cannot compare vectors of different length")
        return seq

    def __apply(self, other, op):
        """Apply binary predicate *op* pairwise and wrap the results."""
        pairs = zip(self.data, self.__cast(other))
        return Vector([op(mine, theirs) for mine, theirs in pairs])

    def __lt__(self, other):
        return self.__apply(other, operator.lt)

    def __le__(self, other):
        return self.__apply(other, operator.le)

    def __eq__(self, other):
        return self.__apply(other, operator.eq)

    def __ne__(self, other):
        return self.__apply(other, operator.ne)

    def __gt__(self, other):
        return self.__apply(other, operator.gt)

    def __ge__(self, other):
        return self.__apply(other, operator.ge)
86 |
|
---|
# Each comparison name maps to three equivalent spellings of that operator:
# the infix operator (via a lambda), the plain operator-module function and
# its dunder alias.  Tests iterate over all three.
opmap = dict(
    lt=(lambda a, b: a < b, operator.lt, operator.__lt__),
    le=(lambda a, b: a <= b, operator.le, operator.__le__),
    eq=(lambda a, b: a == b, operator.eq, operator.__eq__),
    ne=(lambda a, b: a != b, operator.ne, operator.__ne__),
    gt=(lambda a, b: a > b, operator.gt, operator.__gt__),
    ge=(lambda a, b: a >= b, operator.ge, operator.__ge__),
)
95 |
|
---|
96 | class VectorTest(unittest.TestCase):
|
---|
97 |
|
---|
98 | def checkfail(self, error, opname, *args):
|
---|
99 | for op in opmap[opname]:
|
---|
100 | self.assertRaises(error, op, *args)
|
---|
101 |
|
---|
102 | def checkequal(self, opname, a, b, expres):
|
---|
103 | for op in opmap[opname]:
|
---|
104 | realres = op(a, b)
|
---|
105 | # can't use assertEqual(realres, expres) here
|
---|
106 | self.assertEqual(len(realres), len(expres))
|
---|
107 | for i in xrange(len(realres)):
|
---|
108 | # results are bool, so we can use "is" here
|
---|
109 | self.assertTrue(realres[i] is expres[i])
|
---|
110 |
|
---|
111 | def test_mixed(self):
|
---|
112 | # check that comparisons involving Vector objects
|
---|
113 | # which return rich results (i.e. Vectors with itemwise
|
---|
114 | # comparison results) work
|
---|
115 | a = Vector(range(2))
|
---|
116 | b = Vector(range(3))
|
---|
117 | # all comparisons should fail for different length
|
---|
118 | for opname in opmap:
|
---|
119 | self.checkfail(ValueError, opname, a, b)
|
---|
120 |
|
---|
121 | a = range(5)
|
---|
122 | b = 5 * [2]
|
---|
123 | # try mixed arguments (but not (a, b) as that won't return a bool vector)
|
---|
124 | args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
|
---|
125 | for (a, b) in args:
|
---|
126 | self.checkequal("lt", a, b, [True, True, False, False, False])
|
---|
127 | self.checkequal("le", a, b, [True, True, True, False, False])
|
---|
128 | self.checkequal("eq", a, b, [False, False, True, False, False])
|
---|
129 | self.checkequal("ne", a, b, [True, True, False, True, True ])
|
---|
130 | self.checkequal("gt", a, b, [False, False, False, True, True ])
|
---|
131 | self.checkequal("ge", a, b, [False, False, True, True, True ])
|
---|
132 |
|
---|
133 | for ops in opmap.itervalues():
|
---|
134 | for op in ops:
|
---|
135 | # calls __nonzero__, which should fail
|
---|
136 | self.assertRaises(TypeError, bool, op(a, b))
|
---|
137 |
|
---|
138 | class NumberTest(unittest.TestCase):
|
---|
139 |
|
---|
140 | def test_basic(self):
|
---|
141 | # Check that comparisons involving Number objects
|
---|
142 | # give the same results give as comparing the
|
---|
143 | # corresponding ints
|
---|
144 | for a in xrange(3):
|
---|
145 | for b in xrange(3):
|
---|
146 | for typea in (int, Number):
|
---|
147 | for typeb in (int, Number):
|
---|
148 | if typea==typeb==int:
|
---|
149 | continue # the combination int, int is useless
|
---|
150 | ta = typea(a)
|
---|
151 | tb = typeb(b)
|
---|
152 | for ops in opmap.itervalues():
|
---|
153 | for op in ops:
|
---|
154 | realoutcome = op(a, b)
|
---|
155 | testoutcome = op(ta, tb)
|
---|
156 | self.assertEqual(realoutcome, testoutcome)
|
---|
157 |
|
---|
158 | def checkvalue(self, opname, a, b, expres):
|
---|
159 | for typea in (int, Number):
|
---|
160 | for typeb in (int, Number):
|
---|
161 | ta = typea(a)
|
---|
162 | tb = typeb(b)
|
---|
163 | for op in opmap[opname]:
|
---|
164 | realres = op(ta, tb)
|
---|
165 | realres = getattr(realres, "x", realres)
|
---|
166 | self.assertTrue(realres is expres)
|
---|
167 |
|
---|
168 | def test_values(self):
|
---|
169 | # check all operators and all comparison results
|
---|
170 | self.checkvalue("lt", 0, 0, False)
|
---|
171 | self.checkvalue("le", 0, 0, True )
|
---|
172 | self.checkvalue("eq", 0, 0, True )
|
---|
173 | self.checkvalue("ne", 0, 0, False)
|
---|
174 | self.checkvalue("gt", 0, 0, False)
|
---|
175 | self.checkvalue("ge", 0, 0, True )
|
---|
176 |
|
---|
177 | self.checkvalue("lt", 0, 1, True )
|
---|
178 | self.checkvalue("le", 0, 1, True )
|
---|
179 | self.checkvalue("eq", 0, 1, False)
|
---|
180 | self.checkvalue("ne", 0, 1, True )
|
---|
181 | self.checkvalue("gt", 0, 1, False)
|
---|
182 | self.checkvalue("ge", 0, 1, False)
|
---|
183 |
|
---|
184 | self.checkvalue("lt", 1, 0, False)
|
---|
185 | self.checkvalue("le", 1, 0, False)
|
---|
186 | self.checkvalue("eq", 1, 0, False)
|
---|
187 | self.checkvalue("ne", 1, 0, True )
|
---|
188 | self.checkvalue("gt", 1, 0, True )
|
---|
189 | self.checkvalue("ge", 1, 0, True )
|
---|
190 |
|
---|
191 | class MiscTest(unittest.TestCase):
|
---|
192 |
|
---|
193 | def test_misbehavin(self):
|
---|
194 | class Misb:
|
---|
195 | def __lt__(self_, other): return 0
|
---|
196 | def __gt__(self_, other): return 0
|
---|
197 | def __eq__(self_, other): return 0
|
---|
198 | def __le__(self_, other): self.fail("This shouldn't happen")
|
---|
199 | def __ge__(self_, other): self.fail("This shouldn't happen")
|
---|
200 | def __ne__(self_, other): self.fail("This shouldn't happen")
|
---|
201 | def __cmp__(self_, other): raise RuntimeError, "expected"
|
---|
202 | a = Misb()
|
---|
203 | b = Misb()
|
---|
204 | self.assertEqual(a<b, 0)
|
---|
205 | self.assertEqual(a==b, 0)
|
---|
206 | self.assertEqual(a>b, 0)
|
---|
207 | self.assertRaises(RuntimeError, cmp, a, b)
|
---|
208 |
|
---|
209 | def test_not(self):
|
---|
210 | # Check that exceptions in __nonzero__ are properly
|
---|
211 | # propagated by the not operator
|
---|
212 | import operator
|
---|
213 | class Exc(Exception):
|
---|
214 | pass
|
---|
215 | class Bad:
|
---|
216 | def __nonzero__(self):
|
---|
217 | raise Exc
|
---|
218 |
|
---|
219 | def do(bad):
|
---|
220 | not bad
|
---|
221 |
|
---|
222 | for func in (do, operator.not_):
|
---|
223 | self.assertRaises(Exc, func, Bad())
|
---|
224 |
|
---|
225 | def test_recursion(self):
|
---|
226 | # Check that comparison for recursive objects fails gracefully
|
---|
227 | from UserList import UserList
|
---|
228 | a = UserList()
|
---|
229 | b = UserList()
|
---|
230 | a.append(b)
|
---|
231 | b.append(a)
|
---|
232 | self.assertRaises(RuntimeError, operator.eq, a, b)
|
---|
233 | self.assertRaises(RuntimeError, operator.ne, a, b)
|
---|
234 | self.assertRaises(RuntimeError, operator.lt, a, b)
|
---|
235 | self.assertRaises(RuntimeError, operator.le, a, b)
|
---|
236 | self.assertRaises(RuntimeError, operator.gt, a, b)
|
---|
237 | self.assertRaises(RuntimeError, operator.ge, a, b)
|
---|
238 |
|
---|
239 | b.append(17)
|
---|
240 | # Even recursive lists of different lengths are different,
|
---|
241 | # but they cannot be ordered
|
---|
242 | self.assertTrue(not (a == b))
|
---|
243 | self.assertTrue(a != b)
|
---|
244 | self.assertRaises(RuntimeError, operator.lt, a, b)
|
---|
245 | self.assertRaises(RuntimeError, operator.le, a, b)
|
---|
246 | self.assertRaises(RuntimeError, operator.gt, a, b)
|
---|
247 | self.assertRaises(RuntimeError, operator.ge, a, b)
|
---|
248 | a.append(17)
|
---|
249 | self.assertRaises(RuntimeError, operator.eq, a, b)
|
---|
250 | self.assertRaises(RuntimeError, operator.ne, a, b)
|
---|
251 | a.insert(0, 11)
|
---|
252 | b.insert(0, 12)
|
---|
253 | self.assertTrue(not (a == b))
|
---|
254 | self.assertTrue(a != b)
|
---|
255 | self.assertTrue(a < b)
|
---|
256 |
|
---|
257 | class DictTest(unittest.TestCase):
|
---|
258 |
|
---|
259 | def test_dicts(self):
|
---|
260 | # Verify that __eq__ and __ne__ work for dicts even if the keys and
|
---|
261 | # values don't support anything other than __eq__ and __ne__ (and
|
---|
262 | # __hash__). Complex numbers are a fine example of that.
|
---|
263 | import random
|
---|
264 | imag1a = {}
|
---|
265 | for i in range(50):
|
---|
266 | imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
|
---|
267 | items = imag1a.items()
|
---|
268 | random.shuffle(items)
|
---|
269 | imag1b = {}
|
---|
270 | for k, v in items:
|
---|
271 | imag1b[k] = v
|
---|
272 | imag2 = imag1b.copy()
|
---|
273 | imag2[k] = v + 1.0
|
---|
274 | self.assertTrue(imag1a == imag1a)
|
---|
275 | self.assertTrue(imag1a == imag1b)
|
---|
276 | self.assertTrue(imag2 == imag2)
|
---|
277 | self.assertTrue(imag1a != imag2)
|
---|
278 | for opname in ("lt", "le", "gt", "ge"):
|
---|
279 | for op in opmap[opname]:
|
---|
280 | self.assertRaises(TypeError, op, imag1a, imag2)
|
---|
281 |
|
---|
282 | class ListTest(unittest.TestCase):
|
---|
283 |
|
---|
284 | def test_coverage(self):
|
---|
285 | # exercise all comparisons for lists
|
---|
286 | x = [42]
|
---|
287 | self.assertIs(x<x, False)
|
---|
288 | self.assertIs(x<=x, True)
|
---|
289 | self.assertIs(x==x, True)
|
---|
290 | self.assertIs(x!=x, False)
|
---|
291 | self.assertIs(x>x, False)
|
---|
292 | self.assertIs(x>=x, True)
|
---|
293 | y = [42, 42]
|
---|
294 | self.assertIs(x<y, True)
|
---|
295 | self.assertIs(x<=y, True)
|
---|
296 | self.assertIs(x==y, False)
|
---|
297 | self.assertIs(x!=y, True)
|
---|
298 | self.assertIs(x>y, False)
|
---|
299 | self.assertIs(x>=y, False)
|
---|
300 |
|
---|
301 | def test_badentry(self):
|
---|
302 | # make sure that exceptions for item comparison are properly
|
---|
303 | # propagated in list comparisons
|
---|
304 | class Exc(Exception):
|
---|
305 | pass
|
---|
306 | class Bad:
|
---|
307 | def __eq__(self, other):
|
---|
308 | raise Exc
|
---|
309 |
|
---|
310 | x = [Bad()]
|
---|
311 | y = [Bad()]
|
---|
312 |
|
---|
313 | for op in opmap["eq"]:
|
---|
314 | self.assertRaises(Exc, op, x, y)
|
---|
315 |
|
---|
316 | def test_goodentry(self):
|
---|
317 | # This test exercises the final call to PyObject_RichCompare()
|
---|
318 | # in Objects/listobject.c::list_richcompare()
|
---|
319 | class Good:
|
---|
320 | def __lt__(self, other):
|
---|
321 | return True
|
---|
322 |
|
---|
323 | x = [Good()]
|
---|
324 | y = [Good()]
|
---|
325 |
|
---|
326 | for op in opmap["lt"]:
|
---|
327 | self.assertIs(op(x, y), True)
|
---|
328 |
|
---|
def test_main():
    """Run every rich-comparison TestCase defined in this module.

    DictTest runs inside test_support.check_py3k_warnings, with a filter for
    the DeprecationWarning about dict inequality comparisons in 3.x.
    """
    test_support.run_unittest(VectorTest, NumberTest, MiscTest, ListTest)
    dict_inequality_warning = ("dict inequality comparisons "
                               "not supported in 3.x",
                               DeprecationWarning)
    with test_support.check_py3k_warnings(dict_inequality_warning):
        test_support.run_unittest(DictTest)


# Entry point when the module is executed directly as a script.
if __name__ == "__main__":
    test_main()