source: python/trunk/Lib/test/test_functools.py

Last change on this file was 391, checked in by dmik, 11 years ago

python: Merge vendor 2.7.6 to trunk.

  • Property svn:eol-style set to native
File size: 16.6 KB
Line 
1import functools
2import sys
3import unittest
4from test import test_support
5from weakref import proxy
6import pickle
7
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Build a pure-Python stand-in for functools.partial(func, *args, **keywords).

    The returned closure prepends *args* to the positional arguments it is
    called with and merges *keywords* beneath the call-time keywords, so
    call-time keywords take precedence.  The stored keywords dict is copied
    on every call, so the caller's mapping is never mutated.  The closure
    also carries ``func``, ``args`` and ``keywords`` attributes that mirror
    a real partial object.  The staticmethod wrapper lets the function be
    stored as a class attribute (``thetype``) without being bound to the
    test instance.
    """
    def newfunc(*extra_args, **extra_keywords):
        merged = dict(keywords)
        merged.update(extra_keywords)
        return func(*(args + extra_args), **merged)
    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
19
def capture(*args, **kw):
    """Echo back the call's arguments as an ``(args_tuple, kwargs_dict)`` pair."""
    received = (args, kw)
    return received
23
def signature(part):
    """Summarize a partial-like object as a comparable tuple.

    Returns ``(part.func, part.args, part.keywords, part.__dict__)``, which
    lets two partial objects (e.g. an original and its pickled copy) be
    compared for equality field by field.
    """
    field_names = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, name) for name in field_names)
27
class TestPartial(unittest.TestCase):
    """Behavioral tests for partial objects.

    Subclasses rebind ``thetype`` to run the same checks against other
    partial implementations (a subclass of functools.partial and a
    pure-Python closure).
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; only real types are checked,
        # the pure-Python closure version allows arbitrary attributes
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(TypeError, setattr, p, 'func', map)
        self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
        self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            # The TypeError may come from construction (C partial) or from
            # the call (pure-Python closure); either is acceptable.
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly; separate
        # assertEqual calls report exactly which part differs on failure
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a':1, 'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x // y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # dropping the only strong reference must invalidate the proxy
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = map(str, range(10))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Looks like a 4-item state tuple, but __getitem__ builds a
            # fresh (large) args tuple on each access.
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.thetype(object)
        self.assertRaises(SystemError, f.__setstate__, BadSequence())
170
class PartialSubclass(functools.partial):
    """Trivial functools.partial subclass used to test subclassing support."""
173
class TestPartialSubclass(TestPartial):
    """Run every TestPartial check against a subclass of functools.partial."""
    thetype = PartialSubclass
177
class TestPythonPartial(TestPartial):
    """Run the TestPartial checks against the pure-Python PythonPartial."""

    thetype = PythonPartial

    # PythonPartial returns a nested closure: it can't be pickled and has
    # no __setstate__, so these two inherited checks are disabled.
    def test_pickle(self):
        pass

    def test_setstate_refcount(self):
        pass
185
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Each attribute named in *assigned* must be the identical object on
        both functions.  For each mapping attribute named in *updated*, every
        key present on *wrapped* must map to the identical object on
        *wrapper*.  assertIs is used so a failure shows both objects.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Wrap a docstring'd function with default settings; return (wrapper, f)."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    @test_support.requires_docstrings
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
261
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        """Decorate a docstring'd function with default wraps(); return (wrapper, f).

        The (wrapper, wrapped) pair matches the base class's return shape, so
        any inherited test that unpacks the result keeps working here.
        """
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # give the wrapper an empty dict_attr for wraps() to update
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
318
319
320class TestReduce(unittest.TestCase):
321
322 def test_reduce(self):
323 class Squares:
324
325 def __init__(self, max):
326 self.max = max
327 self.sofar = []
328
329 def __len__(self): return len(self.sofar)
330
331 def __getitem__(self, i):
332 if not 0 <= i < self.max: raise IndexError
333 n = len(self.sofar)
334 while n <= i:
335 self.sofar.append(n*n)
336 n += 1
337 return self.sofar[i]
338
339 reduce = functools.reduce
340 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
341 self.assertEqual(
342 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
343 ['a','c','d','w']
344 )
345 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
346 self.assertEqual(
347 reduce(lambda x, y: x*y, range(2,21), 1L),
348 2432902008176640000L
349 )
350 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
351 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
352 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
353 self.assertRaises(TypeError, reduce)
354 self.assertRaises(TypeError, reduce, 42, 42)
355 self.assertRaises(TypeError, reduce, 42, 42, 42)
356 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
357 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
358 self.assertRaises(TypeError, reduce, 42, (42, 42))
359
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key.

    NOTE(review): this class is not listed in test_main()'s test_classes in
    this file, so running the module through test_main() skips it; consider
    registering it there.
    """

    def test_cmp_to_key(self):
        # a reversed comparison must produce a descending sort
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_hash(self):
        # key wrappers produced by cmp_to_key must be unhashable
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        # Pass the callable and its argument separately: writing hash(k)
        # here would raise before assertRaises could catch the TypeError.
        self.assertRaises(TypeError, hash, k)
373
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _assert_ordered(self, cls, low, high):
        # Shared check for every root-method variant: with cls(low) < cls(high),
        # all four rich comparisons (plus reflexive <= and >=) must hold.
        # Fresh instances are built for each comparison.
        self.assertTrue(cls(low) < cls(high))
        self.assertTrue(cls(high) > cls(low))
        self.assertTrue(cls(low) <= cls(high))
        self.assertTrue(cls(high) >= cls(low))
        self.assertTrue(cls(high) <= cls(high))
        self.assertTrue(cls(high) >= cls(high))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A, 1, 2)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A, 1, 2)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A, 1, 2)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A, 1, 2)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing ones inherited from str
        @functools.total_ordering
        class A(str):
            pass
        self._assert_ordered(A, "a", "b")

    def test_no_operations_defined(self):
        # a class with no ordering method at all must be rejected
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        # a TypeError raised by the root method must propagate through
        # the derived comparison instead of being swallowed
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return isinstance(other, TestTO) and self.value == other.value
            def __lt__(self, other):
                if not isinstance(other, TestTO):
                    raise TypeError
                return self.value < other.value
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
473
def test_main(verbose=None):
    """Run the registered functools test classes through test_support.

    When *verbose* is true and sys.gettotalrefcount exists (only on builds
    that provide it), the suite is re-run five more times and the list of
    total reference counts taken after each run is printed, so a leak shows
    up as a steadily growing sequence.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestWraps,
        TestReduce,
    )
    test_support.run_unittest(*test_classes)

    # verify reference counting (requires a build exposing gettotalrefcount)
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
495
if __name__ == '__main__':
    # Script entry point: run the suite with verbose refcount reporting.
    test_main(True)
Note: See TracBrowser for help on using the repository browser.