1 | #-*- coding: ISO-8859-1 -*-
|
---|
2 | # pysqlite2/test/userfunctions.py: tests for user-defined functions and
|
---|
3 | # aggregates.
|
---|
4 | #
|
---|
5 | # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
|
---|
6 | #
|
---|
7 | # This file is part of pysqlite.
|
---|
8 | #
|
---|
9 | # This software is provided 'as-is', without any express or implied
|
---|
10 | # warranty. In no event will the authors be held liable for any damages
|
---|
11 | # arising from the use of this software.
|
---|
12 | #
|
---|
13 | # Permission is granted to anyone to use this software for any purpose,
|
---|
14 | # including commercial applications, and to alter it and redistribute it
|
---|
15 | # freely, subject to the following restrictions:
|
---|
16 | #
|
---|
17 | # 1. The origin of this software must not be misrepresented; you must not
|
---|
18 | # claim that you wrote the original software. If you use this software
|
---|
19 | # in a product, an acknowledgment in the product documentation would be
|
---|
20 | # appreciated but is not required.
|
---|
21 | # 2. Altered source versions must be plainly marked as such, and must not be
|
---|
22 | # misrepresented as being the original software.
|
---|
23 | # 3. This notice may not be removed or altered from any source distribution.
|
---|
24 |
|
---|
25 | import unittest
|
---|
26 | import sqlite3 as sqlite
|
---|
27 |
|
---|
def func_returntext():
    """Return a native (byte) str literal for the text-return test."""
    text = "foo"
    return text
def func_returnunicode():
    """Return a unicode literal for the unicode-return test."""
    result = u"bar"
    return result
def func_returnint():
    """Return a small integer for the int-return test."""
    answer = 42
    return answer
def func_returnfloat():
    """Return a float for the float-return test."""
    value = 3.14
    return value
def func_returnnull():
    """Return None (mapped to SQL NULL) for the null-return test."""
def func_returnblob():
    """Return a buffer object for the blob-return test."""
    data = "blob"
    return buffer(data)
def func_returnlonglong():
    """Return 2**31, a value just past the signed 32-bit range."""
    big = 1 << 31
    return big
def func_raiseexception():
    """Always fail with ZeroDivisionError so the error path can be tested."""
    divisor = 0
    return 5 // divisor
44 |
|
---|
def func_isstring(v):
    """Report whether *v* is exactly of type unicode (subclasses do not count)."""
    actual = type(v)
    return actual is unicode
def func_isint(v):
    """Report whether *v* is exactly of type int (bool and long do not count)."""
    actual = type(v)
    return actual is int
def func_isfloat(v):
    """Report whether *v* is exactly of type float."""
    actual = type(v)
    return actual is float
def func_isnone(v):
    """Report whether *v* is None (SQL NULL)."""
    # NoneType has exactly one instance, so an identity test is equivalent
    # to comparing the type against type(None).
    return v is None
def func_isblob(v):
    """Report whether *v* is exactly of type buffer (how blobs arrive here)."""
    kind = type(v)
    return kind is buffer
def func_islonglong(v):
    """Report whether *v* is an integer of at least 2**31."""
    if not isinstance(v, (int, long)):
        return False
    return v >= 1 << 31
57 |
|
---|
class AggrNoStep:
    """Aggregate that deliberately has no ``step`` method.

    Kept as a classic (old-style) class on purpose: CheckAggrNoStep expects
    the classic-class AttributeError text "... instance has no attribute".
    """

    def __init__(self):
        """Nothing to initialise."""

    def finalize(self):
        result = 1
        return result
64 |
|
---|
class AggrNoFinalize:
    """Aggregate that deliberately has no ``finalize`` method."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Accept and ignore every value."""
71 |
|
---|
class AggrExceptionInInit:
    """Aggregate whose constructor always raises ZeroDivisionError."""

    def __init__(self):
        zero = 0
        5 // zero

    def step(self, x):
        """Accept and ignore every value."""

    def finalize(self):
        """Produce None."""
81 |
|
---|
class AggrExceptionInStep:
    """Aggregate whose ``step`` always raises ZeroDivisionError."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        zero = 0
        5 // zero

    def finalize(self):
        answer = 42
        return answer
91 |
|
---|
class AggrExceptionInFinalize:
    """Aggregate whose ``finalize`` always raises ZeroDivisionError."""

    def __init__(self):
        """Nothing to initialise."""

    def step(self, x):
        """Accept and ignore every value."""

    def finalize(self):
        zero = 0
        5 // zero
101 |
|
---|
class AggrCheckType:
    """Aggregate reporting whether its most recent value had the named type.

    ``finalize`` yields 1 on a match, 0 on a mismatch, and None if ``step``
    was never called.
    """

    def __init__(self):
        self.val = None

    def step(self, whichType, val):
        # Translate the label passed from SQL into the exact Python type
        # the value is expected to arrive as.
        expected = {
            "str": unicode,
            "int": int,
            "float": float,
            "None": type(None),
            "blob": buffer,
        }[whichType]
        self.val = 1 if type(val) is expected else 0

    def finalize(self):
        return self.val
112 |
|
---|
class AggrSum:
    """Aggregate that adds up its inputs, starting from the float 0.0."""

    def __init__(self):
        self.val = 0.0

    def step(self, val):
        self.val = self.val + val

    def finalize(self):
        return self.val
122 |
|
---|
class FunctionTests(unittest.TestCase):
    """Tests for scalar user-defined functions (Connection.create_function)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")

        # Zero-argument functions, one per return type under test.
        self.con.create_function("returntext", 0, func_returntext)
        self.con.create_function("returnunicode", 0, func_returnunicode)
        self.con.create_function("returnint", 0, func_returnint)
        self.con.create_function("returnfloat", 0, func_returnfloat)
        self.con.create_function("returnnull", 0, func_returnnull)
        self.con.create_function("returnblob", 0, func_returnblob)
        self.con.create_function("returnlonglong", 0, func_returnlonglong)
        self.con.create_function("raiseexception", 0, func_raiseexception)

        # One-argument predicates that check the Python type of the value
        # handed to the function.
        self.con.create_function("isstring", 1, func_isstring)
        self.con.create_function("isint", 1, func_isint)
        self.con.create_function("isfloat", 1, func_isfloat)
        self.con.create_function("isnone", 1, func_isnone)
        self.con.create_function("isblob", 1, func_isblob)
        self.con.create_function("islonglong", 1, func_islonglong)

    def tearDown(self):
        self.con.close()

    def CheckFuncErrorOnCreate(self):
        # Registering with an invalid argument count (-100) must fail.
        try:
            self.con.create_function("bla", -100, lambda x: 2*x)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckFuncRefCount(self):
        # Register a freshly created closure that nothing but the connection
        # refers to once create_function returns. Calling it afterwards only
        # works if the connection keeps its own reference. (Keeping another
        # reference here, e.g. in a global, would hide exactly that bug.)
        def getfunc():
            def f():
                return 1
            return f
        self.con.create_function("reftest", 0, getfunc())
        cur = self.con.cursor()
        cur.execute("select reftest()")
        self.assertEqual(cur.fetchone()[0], 1)

    def CheckFuncReturnText(self):
        cur = self.con.cursor()
        cur.execute("select returntext()")
        val = cur.fetchone()[0]
        # A native str returned by the function comes back as unicode.
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, "foo")

    def CheckFuncReturnUnicode(self):
        cur = self.con.cursor()
        cur.execute("select returnunicode()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), unicode)
        self.assertEqual(val, u"bar")

    def CheckFuncReturnInt(self):
        cur = self.con.cursor()
        cur.execute("select returnint()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), int)
        self.assertEqual(val, 42)

    def CheckFuncReturnFloat(self):
        cur = self.con.cursor()
        cur.execute("select returnfloat()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), float)
        # Compare with a tolerance rather than exact equality.
        if val < 3.139 or val > 3.141:
            self.fail("wrong value")

    def CheckFuncReturnNull(self):
        cur = self.con.cursor()
        cur.execute("select returnnull()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), type(None))
        self.assertEqual(val, None)

    def CheckFuncReturnBlob(self):
        cur = self.con.cursor()
        cur.execute("select returnblob()")
        val = cur.fetchone()[0]
        self.assertEqual(type(val), buffer)
        self.assertEqual(val, buffer("blob"))

    def CheckFuncReturnLongLong(self):
        cur = self.con.cursor()
        cur.execute("select returnlonglong()")
        val = cur.fetchone()[0]
        self.assertEqual(val, 1<<31)

    def CheckFuncException(self):
        # An exception inside the Python function surfaces as an
        # OperationalError with a fixed message.
        cur = self.con.cursor()
        try:
            cur.execute("select raiseexception()")
            cur.fetchone()
            self.fail("should have raised OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], 'user-defined function raised exception')

    def CheckParamString(self):
        cur = self.con.cursor()
        cur.execute("select isstring(?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select isint(?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select isfloat(?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select isnone(?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select isblob(?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckParamLongLong(self):
        cur = self.con.cursor()
        cur.execute("select islonglong(?)", (1<<42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)
258 |
|
---|
class AggregateTests(unittest.TestCase):
    """Tests for user-defined aggregates (Connection.create_aggregate)."""

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        cur = self.con.cursor()
        cur.execute("""
            create table test(
                t text,
                i integer,
                f float,
                n,
                b blob
                )
            """)
        # One row holding a value of each storage class.
        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
            ("foo", 5, 3.14, None, buffer("blob"),))

        self.con.create_aggregate("nostep", 1, AggrNoStep)
        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
        self.con.create_aggregate("checkType", 2, AggrCheckType)
        self.con.create_aggregate("mysum", 1, AggrSum)

    def tearDown(self):
        # Release the per-test in-memory database instead of leaking it.
        self.con.close()

    def CheckAggrErrorOnCreate(self):
        # Registering an aggregate with an invalid argument count (-100)
        # must fail. This goes through create_aggregate, not
        # create_function, so the aggregate registration path is exercised.
        try:
            self.con.create_aggregate("bla", -100, AggrSum)
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError:
            pass

    def CheckAggrNoStep(self):
        # The missing 'step' surfaces as the raw AttributeError.
        cur = self.con.cursor()
        try:
            cur.execute("select nostep(t) from test")
            self.fail("should have raised an AttributeError")
        except AttributeError as e:
            self.assertEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")

    def CheckAggrNoFinalize(self):
        cur = self.con.cursor()
        try:
            cur.execute("select nofinalize(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrExceptionInInit(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excInit(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")

    def CheckAggrExceptionInStep(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excStep(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'step' method raised error")

    def CheckAggrExceptionInFinalize(self):
        cur = self.con.cursor()
        try:
            cur.execute("select excFinalize(t) from test")
            cur.fetchone()
            self.fail("should have raised an OperationalError")
        except sqlite.OperationalError as e:
            self.assertEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")

    def CheckAggrCheckParamStr(self):
        cur = self.con.cursor()
        cur.execute("select checkType('str', ?)", ("foo",))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamInt(self):
        cur = self.con.cursor()
        cur.execute("select checkType('int', ?)", (42,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamFloat(self):
        cur = self.con.cursor()
        cur.execute("select checkType('float', ?)", (3.14,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamNone(self):
        cur = self.con.cursor()
        cur.execute("select checkType('None', ?)", (None,))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckParamBlob(self):
        cur = self.con.cursor()
        cur.execute("select checkType('blob', ?)", (buffer("blob"),))
        val = cur.fetchone()[0]
        self.assertEqual(val, 1)

    def CheckAggrCheckAggrSum(self):
        # Replace the fixture row with known integers and sum them.
        cur = self.con.cursor()
        cur.execute("delete from test")
        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
        cur.execute("select mysum(i) from test")
        val = cur.fetchone()[0]
        self.assertEqual(val, 60)
376 |
|
---|
class AuthorizerTests(unittest.TestCase):
    """Tests for Connection.set_authorizer.

    Subclasses swap in a different ``authorizer_cb`` and inherit setUp,
    tearDown and both test methods, so every variant must still make the
    guarded queries fail with a "...prohibited" DatabaseError.
    """

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Deny every action other than SQLITE_SELECT, and additionally deny
        # anything naming column c2 or table t2.
        # NOTE(review): because every non-SELECT action (including column
        # reads) is already denied by the first check, the c2/t2 check below
        # only ever sees SQLITE_SELECT callbacks. Confirm that is intended.
        if action != sqlite.SQLITE_SELECT:
            return sqlite.SQLITE_DENY
        if arg2 == 'c2' or arg1 == 't2':
            return sqlite.SQLITE_DENY
        return sqlite.SQLITE_OK

    def setUp(self):
        self.con = sqlite.connect(":memory:")
        self.con.executescript("""
            create table t1 (c1, c2);
            create table t2 (c1, c2);
            insert into t1 (c1, c2) values (1, 2);
            insert into t2 (c1, c2) values (4, 5);
            """)

        # For our security test: run a query before the authorizer is
        # installed (presumably to check that statements compiled earlier
        # cannot bypass it).
        self.con.execute("select c2 from t2")

        self.con.set_authorizer(self.authorizer_cb)

    def tearDown(self):
        # Close the per-test connection; this also covers every subclass.
        self.con.close()

    def test_table_access(self):
        try:
            self.con.execute("select * from t2")
        except sqlite.DatabaseError as e:
            if not e.args[0].endswith("prohibited"):
                self.fail("wrong exception text: %s" % e.args[0])
            return
        self.fail("should have raised an exception due to missing privileges")

    def test_column_access(self):
        try:
            self.con.execute("select c2 from t1")
        except sqlite.DatabaseError as e:
            if not e.args[0].endswith("prohibited"):
                self.fail("wrong exception text: %s" % e.args[0])
            return
        self.fail("should have raised an exception due to missing privileges")
420 |
|
---|
class AuthorizerRaiseExceptionTests(AuthorizerTests):
    """Re-run the authorizer tests with a callback that raises instead of
    returning SQLITE_DENY; the inherited tests still expect the guarded
    queries to fail with a "...prohibited" error."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        # Same rule as the base class, but signalled by raising.
        forbidden = (action != sqlite.SQLITE_SELECT
                     or arg2 == 'c2'
                     or arg1 == 't2')
        if forbidden:
            raise ValueError
        return sqlite.SQLITE_OK
429 |
|
---|
class AuthorizerIllegalTypeTests(AuthorizerTests):
    """Re-run the authorizer tests with a callback that answers with a float
    (an invalid return type) wherever the base class would deny."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        forbidden = (action != sqlite.SQLITE_SELECT
                     or arg2 == 'c2'
                     or arg1 == 't2')
        if forbidden:
            return 0.0
        return sqlite.SQLITE_OK
438 |
|
---|
class AuthorizerLargeIntegerTests(AuthorizerTests):
    """Re-run the authorizer tests with a callback that answers with 2**32
    (an out-of-range integer) wherever the base class would deny."""

    @staticmethod
    def authorizer_cb(action, arg1, arg2, dbname, source):
        forbidden = (action != sqlite.SQLITE_SELECT
                     or arg2 == 'c2'
                     or arg1 == 't2')
        if forbidden:
            return 2**32
        return sqlite.SQLITE_OK
447 |
|
---|
448 |
|
---|
def suite():
    """Collect every test case in this module into one TestSuite.

    The function and aggregate tests use the legacy "Check" method prefix;
    the authorizer tests use the default "test" prefix.
    """
    check_prefixed = [unittest.makeSuite(case, "Check")
                      for case in (FunctionTests, AggregateTests)]
    test_prefixed = [unittest.makeSuite(case)
                     for case in (AuthorizerTests,
                                  AuthorizerRaiseExceptionTests,
                                  AuthorizerIllegalTypeTests,
                                  AuthorizerLargeIntegerTests)]
    return unittest.TestSuite(tuple(check_prefixed + test_prefixed))
461 |
|
---|
def test():
    """Run this module's suite with a plain text runner."""
    unittest.TextTestRunner().run(suite())
465 |
|
---|
if __name__ == '__main__':
    # Run the whole suite when this file is executed as a script.
    test()