1 | /*
|
---|
2 | * A type which wraps a socket
|
---|
3 | *
|
---|
4 | * socket_connection.c
|
---|
5 | *
|
---|
6 | * Copyright (c) 2006-2008, R Oudkerk --- see COPYING.txt
|
---|
7 | */
|
---|
8 |
|
---|
9 | #include "multiprocessing.h"
|
---|
10 |
|
---|
11 | #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
|
---|
12 | # include "poll.h"
|
---|
13 | #endif
|
---|
14 |
|
---|
/*
 * Platform-neutral I/O primitives for a connection handle.
 *
 * On Windows the handle is cast to SOCKET and the Winsock calls
 * send()/recv()/closesocket() are used; elsewhere the handle is passed
 * straight to write()/read()/close().
 *
 * Every macro argument is parenthesized so that passing an expression
 * (for example `base + offset`) cannot be re-associated by the cast or
 * by the surrounding call.
 */
#ifdef MS_WINDOWS
#  define WRITE(h, buffer, length) send((SOCKET)(h), (buffer), (length), 0)
#  define READ(h, buffer, length) recv((SOCKET)(h), (buffer), (length), 0)
#  define CLOSE(h) closesocket((SOCKET)(h))
#else
#  define WRITE(h, buffer, length) write((h), (buffer), (length))
#  define READ(h, buffer, length) read((h), (buffer), (length))
#  define CLOSE(h) close((h))
#endif
|
---|
24 |
|
---|
25 | /*
|
---|
26 | * Wrapper for PyErr_CheckSignals() which can be called without the GIL
|
---|
27 | */
|
---|
28 |
|
---|
29 | static int
|
---|
30 | check_signals(void)
|
---|
31 | {
|
---|
32 | PyGILState_STATE state;
|
---|
33 | int res;
|
---|
34 | state = PyGILState_Ensure();
|
---|
35 | res = PyErr_CheckSignals();
|
---|
36 | PyGILState_Release(state);
|
---|
37 | return res;
|
---|
38 | }
|
---|
39 |
|
---|
40 | /*
|
---|
41 | * Send string to file descriptor
|
---|
42 | */
|
---|
43 |
|
---|
44 | static Py_ssize_t
|
---|
45 | _conn_sendall(HANDLE h, char *string, size_t length)
|
---|
46 | {
|
---|
47 | char *p = string;
|
---|
48 | Py_ssize_t res;
|
---|
49 |
|
---|
50 | while (length > 0) {
|
---|
51 | res = WRITE(h, p, length);
|
---|
52 | if (res < 0) {
|
---|
53 | if (errno == EINTR) {
|
---|
54 | if (check_signals() < 0)
|
---|
55 | return MP_EXCEPTION_HAS_BEEN_SET;
|
---|
56 | continue;
|
---|
57 | }
|
---|
58 | return MP_SOCKET_ERROR;
|
---|
59 | }
|
---|
60 | length -= res;
|
---|
61 | p += res;
|
---|
62 | }
|
---|
63 |
|
---|
64 | return MP_SUCCESS;
|
---|
65 | }
|
---|
66 |
|
---|
67 | /*
|
---|
68 | * Receive string of exact length from file descriptor
|
---|
69 | */
|
---|
70 |
|
---|
71 | static Py_ssize_t
|
---|
72 | _conn_recvall(HANDLE h, char *buffer, size_t length)
|
---|
73 | {
|
---|
74 | size_t remaining = length;
|
---|
75 | Py_ssize_t temp;
|
---|
76 | char *p = buffer;
|
---|
77 |
|
---|
78 | while (remaining > 0) {
|
---|
79 | temp = READ(h, p, remaining);
|
---|
80 | if (temp < 0) {
|
---|
81 | if (errno == EINTR) {
|
---|
82 | if (check_signals() < 0)
|
---|
83 | return MP_EXCEPTION_HAS_BEEN_SET;
|
---|
84 | continue;
|
---|
85 | }
|
---|
86 | return temp;
|
---|
87 | }
|
---|
88 | else if (temp == 0) {
|
---|
89 | return remaining == length ? MP_END_OF_FILE : MP_EARLY_END_OF_FILE;
|
---|
90 | }
|
---|
91 | remaining -= temp;
|
---|
92 | p += temp;
|
---|
93 | }
|
---|
94 |
|
---|
95 | return MP_SUCCESS;
|
---|
96 | }
|
---|
97 |
|
---|
98 | /*
|
---|
99 | * Send a string prepended by the string length in network byte order
|
---|
100 | */
|
---|
101 |
|
---|
102 | static Py_ssize_t
|
---|
103 | conn_send_string(ConnectionObject *conn, char *string, size_t length)
|
---|
104 | {
|
---|
105 | Py_ssize_t res;
|
---|
106 | /* The "header" of the message is a 32 bit unsigned number (in
|
---|
107 | network order) which specifies the length of the "body". If
|
---|
108 | the message is shorter than about 16kb then it is quicker to
|
---|
109 | combine the "header" and the "body" of the message and send
|
---|
110 | them at once. */
|
---|
111 | if (length < (16*1024)) {
|
---|
112 | char *message;
|
---|
113 |
|
---|
114 | message = PyMem_Malloc(length+4);
|
---|
115 | if (message == NULL)
|
---|
116 | return MP_MEMORY_ERROR;
|
---|
117 |
|
---|
118 | *(UINT32*)message = htonl((UINT32)length);
|
---|
119 | memcpy(message+4, string, length);
|
---|
120 | Py_BEGIN_ALLOW_THREADS
|
---|
121 | res = _conn_sendall(conn->handle, message, length+4);
|
---|
122 | Py_END_ALLOW_THREADS
|
---|
123 | PyMem_Free(message);
|
---|
124 | } else {
|
---|
125 | UINT32 lenbuff;
|
---|
126 |
|
---|
127 | if (length > MAX_MESSAGE_LENGTH)
|
---|
128 | return MP_BAD_MESSAGE_LENGTH;
|
---|
129 |
|
---|
130 | lenbuff = htonl((UINT32)length);
|
---|
131 | Py_BEGIN_ALLOW_THREADS
|
---|
132 | res = _conn_sendall(conn->handle, (char*)&lenbuff, 4) ||
|
---|
133 | _conn_sendall(conn->handle, string, length);
|
---|
134 | Py_END_ALLOW_THREADS
|
---|
135 | }
|
---|
136 | return res;
|
---|
137 | }
|
---|
138 |
|
---|
139 | /*
|
---|
140 | * Attempts to read into buffer, or failing that into *newbuffer
|
---|
141 | *
|
---|
142 | * Returns number of bytes read.
|
---|
143 | */
|
---|
144 |
|
---|
145 | static Py_ssize_t
|
---|
146 | conn_recv_string(ConnectionObject *conn, char *buffer,
|
---|
147 | size_t buflength, char **newbuffer, size_t maxlength)
|
---|
148 | {
|
---|
149 | Py_ssize_t res;
|
---|
150 | UINT32 ulength;
|
---|
151 |
|
---|
152 | *newbuffer = NULL;
|
---|
153 |
|
---|
154 | Py_BEGIN_ALLOW_THREADS
|
---|
155 | res = _conn_recvall(conn->handle, (char*)&ulength, 4);
|
---|
156 | Py_END_ALLOW_THREADS
|
---|
157 | if (res < 0)
|
---|
158 | return res;
|
---|
159 |
|
---|
160 | ulength = ntohl(ulength);
|
---|
161 | if (ulength > maxlength)
|
---|
162 | return MP_BAD_MESSAGE_LENGTH;
|
---|
163 |
|
---|
164 | if (ulength > buflength) {
|
---|
165 | *newbuffer = buffer = PyMem_Malloc((size_t)ulength);
|
---|
166 | if (buffer == NULL)
|
---|
167 | return MP_MEMORY_ERROR;
|
---|
168 | }
|
---|
169 |
|
---|
170 | Py_BEGIN_ALLOW_THREADS
|
---|
171 | res = _conn_recvall(conn->handle, buffer, (size_t)ulength);
|
---|
172 | Py_END_ALLOW_THREADS
|
---|
173 |
|
---|
174 | if (res >= 0) {
|
---|
175 | res = (Py_ssize_t)ulength;
|
---|
176 | } else if (*newbuffer != NULL) {
|
---|
177 | PyMem_Free(*newbuffer);
|
---|
178 | *newbuffer = NULL;
|
---|
179 | }
|
---|
180 | return res;
|
---|
181 | }
|
---|
182 |
|
---|
183 | /*
|
---|
184 | * Check whether any data is available for reading -- neg timeout blocks
|
---|
185 | */
|
---|
186 |
|
---|
187 | static int
|
---|
188 | conn_poll(ConnectionObject *conn, double timeout, PyThreadState *_save)
|
---|
189 | {
|
---|
190 | #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
|
---|
191 | int res;
|
---|
192 | struct pollfd p;
|
---|
193 |
|
---|
194 | p.fd = (int)conn->handle;
|
---|
195 | p.events = POLLIN | POLLPRI;
|
---|
196 | p.revents = 0;
|
---|
197 |
|
---|
198 | if (timeout < 0) {
|
---|
199 | do {
|
---|
200 | res = poll(&p, 1, -1);
|
---|
201 | } while (res < 0 && errno == EINTR);
|
---|
202 | } else {
|
---|
203 | res = poll(&p, 1, (int)(timeout * 1000 + 0.5));
|
---|
204 | if (res < 0 && errno == EINTR) {
|
---|
205 | /* We were interrupted by a signal. Just indicate a
|
---|
206 | timeout even though we are early. */
|
---|
207 | return FALSE;
|
---|
208 | }
|
---|
209 | }
|
---|
210 |
|
---|
211 | if (res < 0) {
|
---|
212 | return MP_SOCKET_ERROR;
|
---|
213 | } else if (p.revents & (POLLNVAL|POLLERR)) {
|
---|
214 | Py_BLOCK_THREADS
|
---|
215 | PyErr_SetString(PyExc_IOError, "poll() gave POLLNVAL or POLLERR");
|
---|
216 | Py_UNBLOCK_THREADS
|
---|
217 | return MP_EXCEPTION_HAS_BEEN_SET;
|
---|
218 | } else if (p.revents != 0) {
|
---|
219 | return TRUE;
|
---|
220 | } else {
|
---|
221 | assert(res == 0);
|
---|
222 | return FALSE;
|
---|
223 | }
|
---|
224 | #else
|
---|
225 | int res;
|
---|
226 | fd_set rfds;
|
---|
227 |
|
---|
228 | /*
|
---|
229 | * Verify the handle, issue 3321. Not required for windows.
|
---|
230 | */
|
---|
231 | #ifndef MS_WINDOWS
|
---|
232 | if (((int)conn->handle) < 0 || ((int)conn->handle) >= FD_SETSIZE) {
|
---|
233 | Py_BLOCK_THREADS
|
---|
234 | PyErr_SetString(PyExc_IOError, "handle out of range in select()");
|
---|
235 | Py_UNBLOCK_THREADS
|
---|
236 | return MP_EXCEPTION_HAS_BEEN_SET;
|
---|
237 | }
|
---|
238 | #endif
|
---|
239 |
|
---|
240 | FD_ZERO(&rfds);
|
---|
241 | FD_SET((SOCKET)conn->handle, &rfds);
|
---|
242 |
|
---|
243 | if (timeout < 0.0) {
|
---|
244 | do {
|
---|
245 | res = select((int)conn->handle+1, &rfds, NULL, NULL, NULL);
|
---|
246 | } while (res < 0 && errno == EINTR);
|
---|
247 | } else {
|
---|
248 | struct timeval tv;
|
---|
249 | tv.tv_sec = (long)timeout;
|
---|
250 | tv.tv_usec = (long)((timeout - tv.tv_sec) * 1e6 + 0.5);
|
---|
251 | res = select((int)conn->handle+1, &rfds, NULL, NULL, &tv);
|
---|
252 | if (res < 0 && errno == EINTR) {
|
---|
253 | /* We were interrupted by a signal. Just indicate a
|
---|
254 | timeout even though we are early. */
|
---|
255 | return FALSE;
|
---|
256 | }
|
---|
257 | }
|
---|
258 |
|
---|
259 | if (res < 0) {
|
---|
260 | return MP_SOCKET_ERROR;
|
---|
261 | } else if (FD_ISSET(conn->handle, &rfds)) {
|
---|
262 | return TRUE;
|
---|
263 | } else {
|
---|
264 | assert(res == 0);
|
---|
265 | return FALSE;
|
---|
266 | }
|
---|
267 | #endif
|
---|
268 | }
|
---|
269 |
|
---|
270 | /*
|
---|
271 | * "connection.h" defines the Connection type using defs above
|
---|
272 | */
|
---|
273 |
|
---|
274 | #define CONNECTION_NAME "Connection"
|
---|
275 | #define CONNECTION_TYPE ConnectionType
|
---|
276 |
|
---|
277 | #include "connection.h"
|
---|