/***************************************************************************
 *   Copyright (C) 2004 by Trevor "beltorak" Torrez                        *
 *   beltorak@phreaker.net                                                 *
 *   networking.c part of cidentd version 0.2                               *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

/* This poorly named file handles the base networking layer */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <stdlib.h>
#include <stdio.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <sys/time.h>
#include <sys/select.h>
#include <string.h>
#include <signal.h>
#include <arpa/inet.h>
#include <errno.h>

#include <networking.h>

#include <macros.h>
#include <state.h>
#include <strtotype.h>
#include <logutils.h>
#include <response.h>

/* Server globals */

uint	server_authport;
uint	server_authport_set;

int	server_auth_socket;

/* I really need to recombine the server settings.... */
uint	server_max_queries;
uint	server_max_queries_set;

/* Local max() until the proper macro is included.  Every argument is
parenthesized so that expressions such as max(x & m, y) expand with the
intended precedence.  Note that each argument may be evaluated twice, so
do not pass expressions with side effects. */
#undef max
#define max(a,b) (((a) > (b)) ? (a) : (b))

static ConnContext* cc_list;		/* ConnContexts added by ref_socket() and not yet released */
static uint current_connections;	/* number of entries currently in cc_list */
static int max_fd;			/* Needed for select() */
static fd_set rfds, wfds, efds;		/* read, write, and except file descriptor sets */

/* Parses 'port' into server_authport.
Once the port has been explicitly set, further calls are ignored (and
report success) unless state.cmdline is true.  Presumably this lets
command-line options take precedence; confirm against the option parser.
server_authport_set is raised only when the parsed value differs from
the previous one.  Returns 0 on success, -1 if the port cannot be parsed. */
int
set_authport(char* port) {
	if( server_authport_set && !state.cmdline )
		return 0;

	uint previous = server_authport;
	if( strtoport(&server_authport, port) != 0 )
		return -1;

	if( server_authport != previous )
		server_authport_set = TRUE;
	return 0;
} /* -- end set_authport() -- */

/* Prints the current authport together with its service name from the
services database (getservbyport), or a placeholder if none is found.
print_default_marker_if() is told whether the value is still the
default.  Per the original note, it prefixes the line with '*' when
state.show_settings is set. */
void
print_authport( void ) {
	struct servent* entry = getservbyport(htons(server_authport), NULL);
	const char* service = entry ? entry->s_name : "no name in /etc/services";

	print_default_marker_if(!server_authport_set);
	println("authport == %u (%s)", server_authport, service);
} /* -- end print_authport() -- */

/* Sets the number of maximum connections... note that
there is a hard limit based on the number of open file
descriptors, which is 255.  This function rejects anything
over 250 */
int
set_max_queries( char* max_queries ) {
	ulong tmp;
	if( strtoulong(&tmp, max_queries) )
		return -1;

	if( tmp > 250 ) {
		carp("Cannot set max queries to more than 250");
		return -1;
	}

	if( tmp != server_max_queries ) {
		server_max_queries_set = TRUE;
		server_max_queries = (uint) tmp;
	}

	return 0;
} /* -- end set_max_queries() -- */

/* Displays the maximum number of concurrent queries, preceded by the
default-value marker when it has not been explicitly set. */
void
print_max_queries( void ) {
	print_default_marker_if(!server_max_queries_set);
	/* server_max_queries is unsigned; use %u, as print_authport() does */
	println("Maximum Concurrent Queries: %u", server_max_queries);
} /* -- end print_max_queries() -- */

/* Creates the listening TCP socket for ident queries.
The socket is bound to server_authport on every local IPv4 address
(INADDR_ANY) and then set listening.  bind() runs between
proc_to_prime_privs() and suspend_proc_privs(), presumably because the
auth port is privileged.

On success, the descriptor is left in server_auth_socket, cc_list is
reset and 0 is returned.  On failure, the socket is closed,
server_auth_socket is set to -1 and -1 is returned. */
int
start_networking( void ) {
	server_auth_socket = socket(AF_INET, SOCK_STREAM, 0);
	if( server_auth_socket < 0 ) {
		choke("Could not open socket");
		return -1;
	}

	/* Designated initializers: the member order of struct sockaddr_in is
	not portable (BSD systems have a leading sin_len member) */
	struct sockaddr_in authaddr = {
		.sin_family = AF_INET,
		.sin_port = htons(server_authport),
		.sin_addr.s_addr = htonl(INADDR_ANY)
	};

	proc_to_prime_privs();
	int retval = bind(server_auth_socket, (struct sockaddr*) &authaddr,
			sizeof authaddr);
	/* keep bind()'s errno intact in case choke() reports it */
	int saved_errno = errno;
	suspend_proc_privs();
	errno = saved_errno;
	if( retval ) {
		choke("Could not bind to auth port %u", server_authport);
		goto fail;
	}

	if( listen(server_auth_socket, 1) ) {
		choke("Could not listen to auth socket");
		goto fail;
	}

	/* Initialize our connection contexts */
	cc_list = NULL;
	return 0;

fail:
	/* don't leak the half-configured socket */
	close(server_auth_socket);
	server_auth_socket = -1;
	return -1;
} /* -- end start_networking() -- */

/* Releases a ConnContext.
If it owns a socket, the socket is shut down and closed.  Its answer
string is freed, and then the structure itself.  A NULL cc is ignored.
This does not touch cc_list: callers unlink cc before calling it. */
void
free_connection_context( ConnContext* cc ) {
	if(!cc)
		return;
	/* Descriptor 0 is a valid socket (e.g. if stdin was closed).  Only a
	negative value (-1, set by clear_connection_context) means "no socket". */
	if( cc->qsock >= 0 ) {
		shutdown(cc->qsock, SHUT_RDWR);
		close(cc->qsock);
	}
	if_notnull_free(cc->answer)
	free(cc);
} /* -- end free_connection_context() -- */

/* Resets *cc to a blank state.
All bytes are zeroed, then it is marked as having no socket
(qsock == -1) and no next/answer pointers. */
void
clear_connection_context(ConnContext* cc) {
	memset(cc, 0, sizeof *cc);
	cc->qsock = -1;
	cc->answer = NULL;
	cc->next = NULL;
} /* -- end clear_connection_context() -- */

/* Allocates a ConnContext and clears it with clear_connection_context().
Returns the new context, or NULL (after reporting via choke) when the
allocation fails.  The caller owns the result. */
ConnContext*
new_connection_context( void ) {
	ConnContext* fresh = malloc(sizeof *fresh);
	if( fresh == NULL ) {
		choke("Could not create new ConnContext");
		return NULL;
	}

	clear_connection_context(fresh);
	return fresh;
} /* -- end new_connection_context() -- */

/* Clears the list of '-1' sockets */
void
clean_cc_list( void ) {
	if( ! cc_list )
		return;
	ConnContext* cc = cc_list;
	ConnContext* last = NULL;
	while( cc ) {
		if( (cc->qsock == -1) || (cc->finished_write) ) {
			ConnContext* n = cc->next;
			free_connection_context(cc);
			if(last)
				last->next = n;
			cc = n;
			continue;
		}
		last = cc;
		cc = cc->next;
	}
	return;
} /* -- end clean_cc_list() -- */

/* Registers socket 'fd' (from remote host 'rhost') by pushing a new
ConnContext onto the front of cc_list and counting it in
current_connections.  Registering an fd that is already present is
reported and treated as success.  Returns 0 on success, -1 for an
invalid fd or when no ConnContext could be allocated. */
int
ref_socket(int fd, ulong rhost) {
	if( fd < 0 ) {
		carp("Invalid socket");
		return -1;
	}
	dreport("Adding socket reference %d", fd);

	ConnContext* existing = get_connection_context(fd);
	if( existing != NULL ) {
		dreport("duplicate socket ref %d @ %p", fd, existing);
		return 0;
	}

	ConnContext* fresh = new_connection_context();
	if( fresh == NULL )
		return -1;

	fresh->qsock = fd;
	fresh->rhost = rhost;
	fresh->next = cc_list;
	cc_list = fresh;

	current_connections++;
	return 0;
} /* -- end ref_socket() -- */

/* Looks up the ConnContext whose socket is 'fd' by walking cc_list.
Returns the first match, or NULL when none exists.  A negative fd is
reported as invalid and yields NULL. */
ConnContext*
get_connection_context(int fd) {
	if( fd < 0 ) {
		carp("Invalid File Descriptor");
		return NULL;
	}

	ConnContext* walk = cc_list;
	while( (walk != NULL) && (walk->qsock != fd) )
		walk = walk->next;
	return walk;
} /* -- end get_connection_context() -- */

/* Unlinks cc from cc_list, releases it with free_connection_context()
and decrements current_connections.  Returns 0 on success, or -1 if cc
is NULL or is not in the list; in that case nothing is freed. */
int
unref_connection_context(ConnContext* cc) {
	if( ! cc ) {
		carp("Invalid ConnContext");
		return -1;
	}

	dreport("Unreffing ConnContext @ %p", (void*) cc);
	ConnContext* last = NULL;
	ConnContext* tmp = cc_list;
	while( tmp && (tmp != cc) ) {
		last = tmp;
		tmp = tmp->next;
		/* tmp becomes NULL at the end of the list; log it instead of
		dereferencing it */
		dreport("last %p; next %p", (void*) last, (void*) tmp);
	}

	if(! tmp ) {
		carp("ConnContext (%p) not found!", (void*) cc);
		return -1;
	}

	if(last)
		last->next = tmp->next;
	else
		cc_list = tmp->next;
	free_connection_context(tmp);
	--current_connections;

	return 0;
} /* -- end unref_connection_context() -- */

/* Reads available data from cc's socket and appends it to cc->query.
Pass TRUE for oob to read urgent (MSG_OOB) data.  When the accumulated
query ends in '\n', cc->finished_read is set.

Returns 0 on success and -1 on error.  cc is released (unreffed) and
must not be used afterwards in these cases:
 - when -1 is returned;
 - when 0 is returned because the peer went away (EPIPE, or
   end-of-file on a normal read).
Callers iterating cc_list must therefore save cc->next before calling. */
int
read_query(ConnContext* cc, uint oob) {
	if( ! cc ) {
		carp("Invalid ConnContext");
		return -1;
	}

	dreport("Reading from socket %d (cc %p) (OOB == %d)", cc->qsock, cc, oob);
	char buf[MSG_SIZE];
	memset(buf, 0, sizeof buf);
	int rflags = oob ? MSG_OOB : 0;
	/* read at most MSG_SIZE - 1 bytes so buf stays NUL-terminated */
	ssize_t bytes = recv(cc->qsock, (void*)buf, MSG_SIZE - 1, rflags);
	if( bytes < 0 ) {
		if( errno == EINTR ) {
			dreport("Signal caught");
			return 0;
		}
		if( errno == EPIPE ) {
			dreport("EPIPE: connection teminated");
			unref_connection_context(cc);
			return 0;
		}
		choke("Socket error");
		dreport("(cc: %p) socket: %d", cc, cc->qsock);
		unref_connection_context(cc);
		return -1;
	}
	if( bytes == 0 ) {
		dreport("No bytes read");
		/* On a normal read, 0 bytes means end-of-file: the peer closed
		before sending a full query.  Keeping the socket would leave it
		permanently readable and spin select(), so release it.  (The
		urgent-data read leaves this to the normal read.) */
		if( ! oob )
			unref_connection_context(cc);
		return 0;
	}

	dreport("%d bytes read; buf_len == %d", (int) bytes, (int) (bytes + cc->query_len));
	if(bytes + cc->query_len > MSG_SIZE - 2) {
		carp("Hostile intent");
		unref_connection_context(cc);
		return -1;
	}
	/* strcat stops at the first NUL, so embedded NUL bytes are dropped and
	query_len may grow by less than 'bytes', or not at all */
	strcat(cc->query, buf);
	cc->query_len = strlen(cc->query);

	/* We should be a little more explicit concerning the
	terminating character, but i don't think any
	ident implementation is quite so unforgiving.
	query_len can be 0 here (e.g. only NUL bytes were received);
	query[-1] would then be out of bounds. */
	if( (cc->query_len > 0) && (cc->query[cc->query_len - 1] == '\n') ) {
		dreport("read finished, \\n found at index %d", cc->query_len - 1);
		cc->finished_read = TRUE;
	}
	return 0;
} /* -- end read_query() -- */

/* Writes the not-yet-sent part of cc->response to cc's socket and
advances cc->written_len.  When the whole response has been written,
cc->finished_write is set.

Returns 0 on success and -1 on error.  cc is released (unreffed) and
must not be used afterwards in these cases:
 - when -1 is returned;
 - when 0 is returned after EPIPE. */
int
write_response(ConnContext* cc) {
	if( ! cc ) {
		carp("Invalid ConnContext");
		return -1;
	}

	dreport("Writing to socket %d (cc %p)", cc->qsock, cc);
	dreport("response length: %d, bytes written: %d", cc->response_len, cc->written_len);
	ssize_t bytes = write(cc->qsock, cc->response + cc->written_len,
			cc->response_len - cc->written_len);
	if( bytes < 0 ) {
		if( errno == EINTR ) {
			dreport("Signal caught");
			return 0;
		}
		/* NOTE(review): write() only fails with EPIPE if SIGPIPE is
		ignored or handled elsewhere; otherwise the signal ends the
		process. Confirm. */
		if( errno == EPIPE ) {
			dreport("EPIPE: connection terminated");
			unref_connection_context(cc);
			return 0;
		}

		choke("Socket Error");
		unref_connection_context(cc);
		return -1;
	}
	if( bytes == 0 ) {
		dreport("No bytes written");
		return 0;
	}

	cc->written_len += bytes;
	dreport("%d bytes written; %d total bytes written", (int) bytes, cc->written_len);

	if( cc->written_len == cc->response_len ) {
		dreport("Finished write");
		cc->finished_write = TRUE;
	}
	return 0;
} /* -- end write_response() -- */

/* Builds the ident reply "<lport> , <rport> : <answer>\r\n" into
cc->response and sets cc->response_len.

Ownership of 'answer' passes to cc, which frees it in
free_connection_context(), so it must be heap-allocated.  A different
answer stored by an earlier call is freed here.  'answer' must not
contain the CRLF pair.  A reply longer than MSG_SIZE - 2 bytes is cut
to that length while keeping a trailing CRLF.

Returns 0 on success, -1 on error. */
int
set_response(ConnContext* cc, char* answer) {
	if( !cc ) {
		carp("Invalid ConnContext");
		return -1;
	}
	if( is_nil(answer) ) {
		carp("Invalid answer");
		return -1;
	}

	dreport("Setting response");
	/* don't leak an answer stored by an earlier call */
	if( cc->answer && (cc->answer != answer) )
		free(cc->answer);
	cc->answer = answer;

	char* buf = strprintf("%d , %d : %s\r\n",
			cc->lport, cc->rport, answer);
	if( ! buf ) {
		choke("Could not make room for response");
		return -1;
	}

	size_t len = strlen(buf);
	if(len > MSG_SIZE - 2) {
		/* truncate but keep the line terminated with CRLF */
		buf[MSG_SIZE - 4] = '\r';
		buf[MSG_SIZE - 3] = '\n';
		buf[MSG_SIZE - 2] = 0;
		len = MSG_SIZE - 2;
		dreport("response truncated");
	}
	cc->response_len = len;
	memcpy(cc->response, buf, len);

	free(buf);
	return 0;
} /* -- end set_response() -- */

/******
 * The socket loop checks all referenced sockets
 * for data available for read, write, or exceptions.
 ******/

/* Initializes the socket loop */
int
sockloop_prepare( void ) {
	FD_ZERO(&rfds);
	FD_ZERO(&wfds);
	FD_ZERO(&efds);
	max_fd = 0;

	clean_cc_list();
	ConnContext* cc = cc_list;

	while(cc) {
		if( cc->finished_read )
			FD_SET(cc->qsock, &wfds);
		else {
			FD_SET(cc->qsock, &rfds);
			/* URG data in an ident query? */
			FD_SET(cc->qsock, &efds);
		}
		max_fd = max(max_fd, cc->qsock);
		cc = cc->next;
	}

	if( current_connections < server_max_queries ) {
		FD_SET(server_auth_socket, &rfds);
		max_fd = max(server_auth_socket, max_fd);
	}

	return 0;
} /* -- end sockloop_prepare() -- */

/* Blocks in select() until a descriptor gathered by sockloop_prepare()
is ready, and returns the number of ready descriptors.  It returns -1
in two cases:
 - on a real error, which is reported via choke();
 - when interrupted by a signal, in which case errno is EINTR.
No timeout is passed, so 0 is not expected yet. */
int
sockloop_select( void ) {
	errno = 0;
	int retval = select(max_fd + 1, &rfds, &wfds, &efds, NULL);
	if( (retval < 0) && (errno != EINTR) ) {
		choke("select error");
		return -1;
	}
	/* TODO: when a cache_timeout is implemented, pass a struct timeval to
	select() and clean the cache when retval == 0 */
	return retval;
} /* -- end sockloop_select() -- */

/* Performs socket actions on each socket.
0 on success, -1 on error.  If anything was interrupted by a signal,
we return with 0 and errno == EINTR */
int
sockloop_perform( void ) {
	if( cc_list ) {
		/* This is so ugly -- we have to iterate thru cc_list three times! */
		ConnContext* cc;
		int k;

		/* We first check all exception'd fd's */
		cc = cc_list;
		while( cc ) {
			if( FD_ISSET(cc->qsock, &efds) ) {
				if( read_query(cc, TRUE) )
					continue;
				if( errno == EINTR )
					return 0;
			}
			cc = cc->next;
		}

		/* next, we check writes */
		cc = cc_list;
		while( cc ) {
			ConnContext* next = cc->next;
			if( FD_ISSET(cc->qsock, &wfds) ) {
				if( write_response(cc) )
					continue;
				if( errno == EINTR )
					return 0;
				if( cc->finished_write ) {
					unref_connection_context(cc);
				}
			}
			cc = next;
		}

		/* And now we check for reads */
		cc = cc_list;
		while( cc ) {
			ConnContext* next = cc->next;
			if( FD_ISSET(cc->qsock, &rfds) ) {
				if( read_query(cc, FALSE) )
					continue;
				if( errno == EINTR )
					return 0;
			}
			cc = next;
		}
	} /* end checking cc_list */

	/* Finally we check for new connections */
	if( FD_ISSET(server_auth_socket, &rfds) ) {
		dreport("New connection reported");
		struct sockaddr_in saddr;
		socklen_t saddr_len = sizeof(saddr);
		int newcon = accept(server_auth_socket, (struct sockaddr*) &saddr, &saddr_len);
		if( newcon < 0 ) {
			choke("Error accepting new connections");
			return -1;
		}
		if( ref_socket(newcon, ntohl(saddr.sin_addr.s_addr)) ) {
			carp("Could not ref new connection");
			return -1;
		}
	}

	return 0;
} /* -- end sockloop_perform() -- */

/* Performs one iteration of the socket loop: prepare, select, perform.
Returns 0 on success, or when a step was interrupted by a signal (errno
is then EINTR).  Returns -1 on error. */
int
sockloop_iter( void ) {
	dreport("-");
	/* Clear any EINTR left over from a previous iteration.  Otherwise the
	checks below would keep returning early without reaching select(). */
	errno = 0;

	if( sockloop_prepare() < 0 )
		return -1;
	if( errno == EINTR )
		return 0;

	if( sockloop_select() < 0 ) {
		/* sockloop_select() reports an interrupted select() as -1 with
		errno == EINTR.  That is a caught signal, not a loop failure. */
		return (errno == EINTR) ? 0 : -1;
	}
	if( errno == EINTR )
		return 0;

	if( sockloop_perform() < 0 )
		return -1;

	return 0;
} /* -- end sockloop_iter() -- */

/* Continuously loops thru the socket loop; returns
0 on success, -1 on error, or if an unhandled signal was caught. */
int
sockloop_do( void ) {
	while(TRUE) {
		if( sockloop_iter() )
			return -1;
		/* So far all of our signals are handled;
		if we needed to check for a global flag and
		act accordingly on a signal, this is where
		to put the code */
		if(errno == EINTR)
			;

		/* If a socket has finished writing, we can close it;
		if a socket is finished reading, we need to process
		the query */
		ConnContext* cc = cc_list;
		while( cc ) {
			dreport("Updating ConnContext list");
			/* This might just go inside write_response() */
			if( cc->finished_write ) {
				ConnContext* tmp = cc->next;
				unref_connection_context(cc);
				cc = tmp;
				continue;
			}

			if(cc->finished_read) {
				char* answer = get_answer_string(
						cc->query, cc->rhost,
						&(cc->rport), &(cc->lport)
				);
				if( ! answer ) {
					carp("Failed to get answer");
					ConnContext* tmp = cc->next;
					unref_connection_context(cc);
					cc = tmp;
					continue;
				}
				if( set_response(cc, answer) ) {
					carp("Failed to set response");
					ConnContext* tmp = cc->next;
					unref_connection_context(cc);
					cc = tmp;
					continue;
				}
			}
		cc = cc->next;
		}
	} /* do while TRUE */

	/* never reached, just to shut up the compiler */
	return 0;
} /* -- end sockloop_do() -- */


/* -- end networking.c -- */
