1 | """Unit tests for collections.defaultdict."""
|
---|
2 |
|
---|
3 | import os
|
---|
4 | import copy
|
---|
5 | import tempfile
|
---|
6 | import unittest
|
---|
7 | from test import test_support
|
---|
8 |
|
---|
9 | from collections import defaultdict
|
---|
10 |
|
---|
11 | def foobar():
|
---|
12 | return list
|
---|
13 |
|
---|
14 | class TestDefaultDict(unittest.TestCase):
|
---|
15 |
|
---|
16 | def test_basic(self):
|
---|
17 | d1 = defaultdict()
|
---|
18 | self.assertEqual(d1.default_factory, None)
|
---|
19 | d1.default_factory = list
|
---|
20 | d1[12].append(42)
|
---|
21 | self.assertEqual(d1, {12: [42]})
|
---|
22 | d1[12].append(24)
|
---|
23 | self.assertEqual(d1, {12: [42, 24]})
|
---|
24 | d1[13]
|
---|
25 | d1[14]
|
---|
26 | self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
|
---|
27 | self.assertTrue(d1[12] is not d1[13] is not d1[14])
|
---|
28 | d2 = defaultdict(list, foo=1, bar=2)
|
---|
29 | self.assertEqual(d2.default_factory, list)
|
---|
30 | self.assertEqual(d2, {"foo": 1, "bar": 2})
|
---|
31 | self.assertEqual(d2["foo"], 1)
|
---|
32 | self.assertEqual(d2["bar"], 2)
|
---|
33 | self.assertEqual(d2[42], [])
|
---|
34 | self.assertIn("foo", d2)
|
---|
35 | self.assertIn("foo", d2.keys())
|
---|
36 | self.assertIn("bar", d2)
|
---|
37 | self.assertIn("bar", d2.keys())
|
---|
38 | self.assertIn(42, d2)
|
---|
39 | self.assertIn(42, d2.keys())
|
---|
40 | self.assertNotIn(12, d2)
|
---|
41 | self.assertNotIn(12, d2.keys())
|
---|
42 | d2.default_factory = None
|
---|
43 | self.assertEqual(d2.default_factory, None)
|
---|
44 | try:
|
---|
45 | d2[15]
|
---|
46 | except KeyError, err:
|
---|
47 | self.assertEqual(err.args, (15,))
|
---|
48 | else:
|
---|
49 | self.fail("d2[15] didn't raise KeyError")
|
---|
50 | self.assertRaises(TypeError, defaultdict, 1)
|
---|
51 |
|
---|
52 | def test_missing(self):
|
---|
53 | d1 = defaultdict()
|
---|
54 | self.assertRaises(KeyError, d1.__missing__, 42)
|
---|
55 | d1.default_factory = list
|
---|
56 | self.assertEqual(d1.__missing__(42), [])
|
---|
57 |
|
---|
58 | def test_repr(self):
|
---|
59 | d1 = defaultdict()
|
---|
60 | self.assertEqual(d1.default_factory, None)
|
---|
61 | self.assertEqual(repr(d1), "defaultdict(None, {})")
|
---|
62 | self.assertEqual(eval(repr(d1)), d1)
|
---|
63 | d1[11] = 41
|
---|
64 | self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
|
---|
65 | d2 = defaultdict(int)
|
---|
66 | self.assertEqual(d2.default_factory, int)
|
---|
67 | d2[12] = 42
|
---|
68 | self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
|
---|
69 | def foo(): return 43
|
---|
70 | d3 = defaultdict(foo)
|
---|
71 | self.assertTrue(d3.default_factory is foo)
|
---|
72 | d3[13]
|
---|
73 | self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
|
---|
74 |
|
---|
75 | def test_print(self):
|
---|
76 | d1 = defaultdict()
|
---|
77 | def foo(): return 42
|
---|
78 | d2 = defaultdict(foo, {1: 2})
|
---|
79 | # NOTE: We can't use tempfile.[Named]TemporaryFile since this
|
---|
80 | # code must exercise the tp_print C code, which only gets
|
---|
81 | # invoked for *real* files.
|
---|
82 | tfn = tempfile.mktemp()
|
---|
83 | try:
|
---|
84 | f = open(tfn, "w+")
|
---|
85 | try:
|
---|
86 | print >>f, d1
|
---|
87 | print >>f, d2
|
---|
88 | f.seek(0)
|
---|
89 | self.assertEqual(f.readline(), repr(d1) + "\n")
|
---|
90 | self.assertEqual(f.readline(), repr(d2) + "\n")
|
---|
91 | finally:
|
---|
92 | f.close()
|
---|
93 | finally:
|
---|
94 | os.remove(tfn)
|
---|
95 |
|
---|
96 | def test_copy(self):
|
---|
97 | d1 = defaultdict()
|
---|
98 | d2 = d1.copy()
|
---|
99 | self.assertEqual(type(d2), defaultdict)
|
---|
100 | self.assertEqual(d2.default_factory, None)
|
---|
101 | self.assertEqual(d2, {})
|
---|
102 | d1.default_factory = list
|
---|
103 | d3 = d1.copy()
|
---|
104 | self.assertEqual(type(d3), defaultdict)
|
---|
105 | self.assertEqual(d3.default_factory, list)
|
---|
106 | self.assertEqual(d3, {})
|
---|
107 | d1[42]
|
---|
108 | d4 = d1.copy()
|
---|
109 | self.assertEqual(type(d4), defaultdict)
|
---|
110 | self.assertEqual(d4.default_factory, list)
|
---|
111 | self.assertEqual(d4, {42: []})
|
---|
112 | d4[12]
|
---|
113 | self.assertEqual(d4, {42: [], 12: []})
|
---|
114 |
|
---|
115 | # Issue 6637: Copy fails for empty default dict
|
---|
116 | d = defaultdict()
|
---|
117 | d['a'] = 42
|
---|
118 | e = d.copy()
|
---|
119 | self.assertEqual(e['a'], 42)
|
---|
120 |
|
---|
121 | def test_shallow_copy(self):
|
---|
122 | d1 = defaultdict(foobar, {1: 1})
|
---|
123 | d2 = copy.copy(d1)
|
---|
124 | self.assertEqual(d2.default_factory, foobar)
|
---|
125 | self.assertEqual(d2, d1)
|
---|
126 | d1.default_factory = list
|
---|
127 | d2 = copy.copy(d1)
|
---|
128 | self.assertEqual(d2.default_factory, list)
|
---|
129 | self.assertEqual(d2, d1)
|
---|
130 |
|
---|
131 | def test_deep_copy(self):
|
---|
132 | d1 = defaultdict(foobar, {1: [1]})
|
---|
133 | d2 = copy.deepcopy(d1)
|
---|
134 | self.assertEqual(d2.default_factory, foobar)
|
---|
135 | self.assertEqual(d2, d1)
|
---|
136 | self.assertTrue(d1[1] is not d2[1])
|
---|
137 | d1.default_factory = list
|
---|
138 | d2 = copy.deepcopy(d1)
|
---|
139 | self.assertEqual(d2.default_factory, list)
|
---|
140 | self.assertEqual(d2, d1)
|
---|
141 |
|
---|
142 | def test_keyerror_without_factory(self):
|
---|
143 | d1 = defaultdict()
|
---|
144 | try:
|
---|
145 | d1[(1,)]
|
---|
146 | except KeyError, err:
|
---|
147 | self.assertEqual(err.args[0], (1,))
|
---|
148 | else:
|
---|
149 | self.fail("expected KeyError")
|
---|
150 |
|
---|
151 | def test_recursive_repr(self):
|
---|
152 | # Issue2045: stack overflow when default_factory is a bound method
|
---|
153 | class sub(defaultdict):
|
---|
154 | def __init__(self):
|
---|
155 | self.default_factory = self._factory
|
---|
156 | def _factory(self):
|
---|
157 | return []
|
---|
158 | d = sub()
|
---|
159 | self.assertTrue(repr(d).startswith(
|
---|
160 | "defaultdict(<bound method sub._factory of defaultdict(..."))
|
---|
161 |
|
---|
162 | # NOTE: printing a subclass of a builtin type does not call its
|
---|
163 | # tp_print slot. So this part is essentially the same test as above.
|
---|
164 | tfn = tempfile.mktemp()
|
---|
165 | try:
|
---|
166 | f = open(tfn, "w+")
|
---|
167 | try:
|
---|
168 | print >>f, d
|
---|
169 | finally:
|
---|
170 | f.close()
|
---|
171 | finally:
|
---|
172 | os.remove(tfn)
|
---|
173 |
|
---|
174 | def test_callable_arg(self):
|
---|
175 | self.assertRaises(TypeError, defaultdict, {})
|
---|
176 |
|
---|
177 | def test_main():
|
---|
178 | test_support.run_unittest(TestDefaultDict)
|
---|
179 |
|
---|
180 | if __name__ == "__main__":
|
---|
181 | test_main()
|
---|