gnutls-cli: added --starttls-proto option
[gnutls:gnutls.git] / src / socket.c
1 /*
2  * Copyright (C) 2000-2012 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuTLS.
5  *
6  * GnuTLS is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuTLS is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #include <config.h>
21
22 #if HAVE_SYS_SOCKET_H
23 #include <sys/socket.h>
24 #elif HAVE_WS2TCPIP_H
25 #include <ws2tcpip.h>
26 #endif
27 #include <netdb.h>
28 #include <string.h>
29 #include <errno.h>
30 #include <sys/select.h>
31 #include <sys/types.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <unistd.h>
35 #ifndef _WIN32
36 #include <arpa/inet.h>
37 #include <signal.h>
38 #endif
39 #include <socket.h>
40 #include "sockets.h"
41
42 #ifdef HAVE_LIBIDN
43 #include <idna.h>
44 #include <idn-free.h>
45 #endif
46
47 #define MAX_BUF 4096
48
49 /* Functions to manipulate sockets
50  */
51
52 ssize_t
53 socket_recv(const socket_st * socket, void *buffer, int buffer_size)
54 {
55         int ret;
56
57         if (socket->secure) {
58                 do {
59                         ret =
60                             gnutls_record_recv(socket->session, buffer,
61                                                buffer_size);
62                         if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED)
63                                 gnutls_heartbeat_pong(socket->session, 0);
64                 }
65                 while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN
66                        || ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED);
67
68         } else
69                 do {
70                         ret = recv(socket->fd, buffer, buffer_size, 0);
71                 }
72                 while (ret == -1 && errno == EINTR);
73
74         return ret;
75 }
76
77 ssize_t
78 socket_send(const socket_st * socket, const void *buffer, int buffer_size)
79 {
80         return socket_send_range(socket, buffer, buffer_size, NULL);
81 }
82
83
84 ssize_t
85 socket_send_range(const socket_st * socket, const void *buffer,
86                   int buffer_size, gnutls_range_st * range)
87 {
88         int ret;
89
90         if (socket->secure)
91                 do {
92                         if (range == NULL)
93                                 ret =
94                                     gnutls_record_send(socket->session,
95                                                        buffer,
96                                                        buffer_size);
97                         else
98                                 ret =
99                                     gnutls_record_send_range(socket->
100                                                              session,
101                                                              buffer,
102                                                              buffer_size,
103                                                              range);
104                 }
105                 while (ret == GNUTLS_E_AGAIN
106                        || ret == GNUTLS_E_INTERRUPTED);
107         else
108                 do {
109                         ret = send(socket->fd, buffer, buffer_size, 0);
110                 }
111                 while (ret == -1 && errno == EINTR);
112
113         if (ret > 0 && ret != buffer_size && socket->verbose)
114                 fprintf(stderr,
115                         "*** Only sent %d bytes instead of %d.\n", ret,
116                         buffer_size);
117
118         return ret;
119 }
120
121 static
122 ssize_t send_line(int fd, const char *txt)
123 {
124         int len = strlen(txt);
125         int ret;
126
127         ret = send(fd, txt, len, 0);
128
129         if (ret == -1) {
130                 fprintf(stderr, "error sending %s\n", txt);
131                 exit(1);
132         }
133
134         return ret;
135 }
136
137 static
138 ssize_t wait_for_text(int fd, const char *txt, unsigned txt_size)
139 {
140         char buf[512];
141         char *p;
142         int ret;
143         fd_set read_fds;
144         struct timeval tv;
145
146         do {
147                 FD_ZERO(&read_fds);
148                 FD_SET(fd, &read_fds);
149                 tv.tv_sec = 10;
150                 tv.tv_usec = 0;
151                 ret = select(fd + 1, &read_fds, NULL, NULL, &tv);
152                 if (ret <= 0)
153                         ret = -1;
154                 else
155                         ret = recv(fd, buf, sizeof(buf)-1, 0);
156                 if (ret == -1) {
157                         fprintf(stderr, "error receiving %s\n", txt);
158                         exit(1);
159                 }
160                 buf[ret] = 0;
161
162                 p = memmem(buf, ret, txt, txt_size);
163                 if (p != NULL && p != buf) {
164                         p--;
165                         if (*p == '\n')
166                                 break;
167                 }
168         } while(ret < (int)txt_size || strncmp(buf, txt, txt_size) != 0);
169
170         return ret;
171 }
172
173 void
174 socket_starttls(socket_st * socket, const char *app_proto)
175 {
176         if (socket->secure)
177                 return;
178
179         if (app_proto == NULL || strcasecmp(app_proto, "https") == 0)
180                 return;
181
182         if (strcasecmp(app_proto, "smtp") == 0 || strcasecmp(app_proto, "submission") == 0) {
183                 if (socket->verbose)
184                         printf("Negotiating SMTP STARTTLS\n");
185
186                 wait_for_text(socket->fd, "220 ", 4);
187                 send_line(socket->fd, "EHLO mail.example.com\n");
188                 wait_for_text(socket->fd, "250 ", 4);
189                 send_line(socket->fd, "STARTTLS\n");
190                 wait_for_text(socket->fd, "220 ", 4);
191         } else if (strcasecmp(app_proto, "imap") == 0 || strcasecmp(app_proto, "imap2") == 0) {
192                 if (socket->verbose)
193                         printf("Negotiating IMAP STARTTLS\n");
194
195                 send_line(socket->fd, "a CAPABILITY\r\n");
196                 wait_for_text(socket->fd, "a OK", 4);
197                 send_line(socket->fd, "a STARTTLS\r\n");
198                 wait_for_text(socket->fd, "a OK", 4);
199         } else {
200                 if (socket->verbose)
201                         fprintf(stderr, "unknown protocol %s\n", app_proto);
202         }
203
204         return;
205 }
206
207 void socket_bye(socket_st * socket)
208 {
209         int ret;
210         if (socket->secure) {
211                 do
212                         ret = gnutls_bye(socket->session, GNUTLS_SHUT_WR);
213                 while (ret == GNUTLS_E_INTERRUPTED
214                        || ret == GNUTLS_E_AGAIN);
215                 if (ret < 0)
216                         fprintf(stderr, "*** gnutls_bye() error: %s\n",
217                                 gnutls_strerror(ret));
218                 gnutls_deinit(socket->session);
219                 socket->session = NULL;
220         }
221
222         freeaddrinfo(socket->addr_info);
223         socket->addr_info = socket->ptr = NULL;
224
225         free(socket->ip);
226         free(socket->hostname);
227         free(socket->service);
228
229         shutdown(socket->fd, SHUT_RDWR);        /* no more receptions */
230         close(socket->fd);
231
232         socket->fd = -1;
233         socket->secure = 0;
234 }
235
236 void
237 socket_open(socket_st * hd, const char *hostname, const char *service,
238             int udp, const char *msg)
239 {
240         struct addrinfo hints, *res, *ptr;
241         int sd, err = 0;
242         char buffer[MAX_BUF + 1];
243         char portname[16] = { 0 };
244         char *a_hostname = (char*)hostname;
245
246 #ifdef HAVE_LIBIDN
247         err = idna_to_ascii_8z(hostname, &a_hostname, IDNA_ALLOW_UNASSIGNED);
248         if (err != IDNA_SUCCESS) {
249                 fprintf(stderr, "Cannot convert %s to IDNA: %s\n", hostname,
250                         idna_strerror(err));
251                 exit(1);
252         }
253 #endif
254
255         if (msg != NULL)
256                 printf("Resolving '%s'...\n", a_hostname);
257
258         /* get server name */
259         memset(&hints, 0, sizeof(hints));
260         hints.ai_socktype = udp ? SOCK_DGRAM : SOCK_STREAM;
261         if ((err = getaddrinfo(a_hostname, service, &hints, &res))) {
262                 fprintf(stderr, "Cannot resolve %s:%s: %s\n", hostname,
263                         service, gai_strerror(err));
264                 exit(1);
265         }
266
267         sd = -1;
268         for (ptr = res; ptr != NULL; ptr = ptr->ai_next) {
269                 sd = socket(ptr->ai_family, ptr->ai_socktype,
270                             ptr->ai_protocol);
271                 if (sd == -1)
272                         continue;
273
274                 if ((err =
275                      getnameinfo(ptr->ai_addr, ptr->ai_addrlen, buffer,
276                                  MAX_BUF, portname, sizeof(portname),
277                                  NI_NUMERICHOST | NI_NUMERICSERV)) != 0) {
278                         fprintf(stderr, "getnameinfo(): %s\n",
279                                 gai_strerror(err));
280                         continue;
281                 }
282
283                 if (hints.ai_socktype == SOCK_DGRAM) {
284 #if defined(IP_DONTFRAG)
285                         int yes = 1;
286                         if (setsockopt(sd, IPPROTO_IP, IP_DONTFRAG,
287                                        (const void *) &yes,
288                                        sizeof(yes)) < 0)
289                                 perror("setsockopt(IP_DF) failed");
290 #elif defined(IP_MTU_DISCOVER)
291                         int yes = IP_PMTUDISC_DO;
292                         if (setsockopt(sd, IPPROTO_IP, IP_MTU_DISCOVER,
293                                        (const void *) &yes,
294                                        sizeof(yes)) < 0)
295                                 perror("setsockopt(IP_DF) failed");
296 #endif
297                 }
298
299
300                 if (msg)
301                         printf("%s '%s:%s'...\n", msg, buffer, portname);
302
303                 err = connect(sd, ptr->ai_addr, ptr->ai_addrlen);
304                 if (err < 0) {
305                         continue;
306                 }
307                 break;
308         }
309
310         if (err != 0) {
311                 int e = errno;
312                 fprintf(stderr, "Could not connect to %s:%s: %s\n",
313                                 buffer, portname, strerror(e));
314                 exit(1);
315         }
316
317         if (sd == -1) {
318                 fprintf(stderr, "Could not find a supported socket\n");
319                 exit(1);
320         }
321
322         hd->secure = 0;
323         hd->fd = sd;
324         hd->hostname = strdup(hostname);
325         hd->ip = strdup(buffer);
326         hd->service = strdup(portname);
327         hd->ptr = ptr;
328         hd->addr_info = res;
329 #ifdef HAVE_LIBIDN
330         idn_free(a_hostname);
331 #endif
332         return;
333 }
334
335 void sockets_init(void)
336 {
337 #ifdef _WIN32
338         WORD wVersionRequested;
339         WSADATA wsaData;
340
341         wVersionRequested = MAKEWORD(1, 1);
342         if (WSAStartup(wVersionRequested, &wsaData) != 0) {
343                 perror("WSA_STARTUP_ERROR");
344         }
345 #else
346         signal(SIGPIPE, SIG_IGN);
347 #endif
348
349 }
350
351 /* converts a textual service or port to
352  * a service.
353  */
354 const char *port_to_service(const char *sport, const char *proto)
355 {
356         unsigned int port;
357         struct servent *sr;
358
359         port = atoi(sport);
360         if (port == 0)
361                 return sport;
362
363         port = htons(port);
364
365         sr = getservbyport(port, proto);
366         if (sr == NULL) {
367                 fprintf(stderr,
368                         "Warning: getservbyport() failed. Using port number as service.\n");
369                 return sport;
370         }
371
372         return sr->s_name;
373 }
374
375 int service_to_port(const char *service, const char *proto)
376 {
377         unsigned int port;
378         struct servent *sr;
379
380         port = atoi(service);
381         if (port != 0)
382                 return port;
383
384         sr = getservbyname(service, proto);
385         if (sr == NULL) {
386                 fprintf(stderr, "Warning: getservbyname() failed.\n");
387                 exit(1);
388         }
389
390         return ntohs(sr->s_port);
391 }