/***************************************************************************
 *   Copyright (C) 2004 by Trevor "beltorak" Torrez                        *
 *   beltorak@phreaker.net                                                 *
 *   networking.c part of cidentd version 0.2                               *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

/* This poorly named file handles the base networking layer */

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <stdlib.h>
#include <stdio.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <sys/time.h>
#include <sys/select.h>
#include <string.h>
#include <signal.h>
#include <arpa/inet.h>
#include <errno.h>
#include <time.h>

#include <networking.h>

#include <macros.h>
#include <state.h>
#include <strtotype.h>
#include <logutils.h>
#include <response.h>
#include <serverconf.h>
#include <reutils.h>

/* Local max() until the proper macro is included; any previous definition
is discarded.  Every argument and the comparison are parenthesized so
arguments containing low-precedence operators (e.g. `x & 1`) compare
correctly.  The larger argument is evaluated twice, so never pass
expressions with side effects. */
#undef max
#define max(a,b) (((a) > (b)) ? (a) : (b))

/* Head of the singly linked list of connection contexts.  ref_socket()
pushes new entries at the head; entries whose qsock is -1 are removed by
clean_cc_list(). */
static ConnContext* cc_list;
/* Number of entries in cc_list: incremented by ref_socket(), decremented by
unref_connection_context().  New connections are only accepted while it is
below server.max_queries (see sockloop_prepare()). */
static uint current_connections;
/* Highest descriptor placed in the sets; select() is called with max_fd + 1 */
static int max_fd;			/* Needed for select() */
/* Rebuilt by sockloop_prepare() on every loop iteration.  efds is cleared
there, but no descriptor is ever added to it in this file. */
static fd_set rfds, wfds, efds;		/* read, write, and except file descriptor sets */

/* Parses PORT with strtoport() and stores it in server.authport.
If authport has already been modified and we are not processing the
command line (state.cmdline), the call is ignored -- presumably so that a
command-line value is not overridden later.  The "modified" flag is only
raised when the value actually changes.
Returns 0 on success or when ignored, -1 if PORT cannot be parsed. */
int
set_authport(char* port) {
	if( server.is_modifed.authport && !state.cmdline )
		return 0;

	uint parsed;
	if( strtoport(&parsed, port) )
		return -1;

	if( parsed == server.authport )
		return 0;

	server.authport = parsed;
	server.is_modifed.authport = TRUE;
	return 0;
} /* -- end set_authport() -- */

/* Prints the configured authport together with its service name from
the services database (getservbyport), or a placeholder when none is
registered.  print_default_marker_if() is told whether the value is
still the unmodified default. */
void
print_authport( void ) {
	const struct servent* entry = getservbyport(htons(server.authport), NULL);
	const char* service_name = entry ? entry->s_name : "no name in /etc/services";

	print_default_marker_if(!server.is_modifed.authport);
	println("authport == %u (%s)", server.authport, service_name);
} /* -- end print_authport() -- */

/* Sets the number of maximum connections... note that
there is a hard limit based on the number of open file
descriptors, which is 255.  This function rejects anything
over 250 */
int
set_max_queries( char* max_queries ) {
	ulong q;
	if( strtoulong(&q, max_queries) )
		return -1;

	if( q > 250 ) {
		carp("Cannot set max queries to more than 250");
		return -1;
	}

	if( q != server.max_queries ) {
		server.max_queries = (uint) q;
		server.is_modifed.max_queries = TRUE;
	}

	return 0;
} /* -- end set_max_queries() -- */

/* Prints the configured maximum number of concurrent queries, marked via
print_default_marker_if() when it is still the unmodified default. */
void
print_max_queries( void ) {
	int still_default = !server.is_modifed.max_queries;

	print_default_marker_if(still_default);
	println("Maximum Concurrent Queries: %d", server.max_queries);
} /* -- end print_max_queries() -- */

/* Creates the IPv4 TCP listening socket for ident queries.  It binds to
INADDR_ANY on server.authport, raising process privileges only around
bind(), and starts listening.  It also resets the connection-context list.
Returns 0 on success.  On failure it logs via choke(), closes the socket,
sets server.auth_socket to -1 and returns -1. */
int
start_networking( void ) {
	server.auth_socket = socket(AF_INET, SOCK_STREAM, 0);
	if( server.auth_socket < 0 ) {
		choke("Could not open socket");
		return -1;
	}

	/* Zero the address and fill fields by name.  This does not depend on
	the platform's member order of struct sockaddr_in (e.g. BSD sin_len),
	and every field is in network byte order. */
	struct sockaddr_in authaddr;
	memset(&authaddr, 0, sizeof authaddr);
	authaddr.sin_family = AF_INET;
	authaddr.sin_port = htons(server.authport);
	authaddr.sin_addr.s_addr = htonl(INADDR_ANY);

	/* Privileges are raised only for bind() (presumably needed for ports
	below 1024 such as 113). */
	proc_to_prime_privs();
	int retval = bind(server.auth_socket, (struct sockaddr*) &authaddr, sizeof authaddr);
	int bind_errno = errno;		/* suspend_proc_privs() may overwrite errno */
	suspend_proc_privs();
	if( retval ) {
		errno = bind_errno;	/* in case choke() reports errno */
		choke("Could not bind to auth port %d", server.authport);
		close(server.auth_socket);
		server.auth_socket = -1;
		return -1;
	}

	if( listen(server.auth_socket, 1) ) {
		choke("Could not listen to auth socket");
		close(server.auth_socket);
		server.auth_socket = -1;
		return -1;
	}

	/* Initialize our connection contexts */
	cc_list = NULL;
	return 0;
} /* -- end start_networking() -- */

/* Releases a connection context: closes its socket (kill_connection() is
a no-op when qsock is already negative), frees the stored answer string,
then frees the context itself.  A NULL argument is ignored. */
void
free_connection_context( ConnContext* cc ) {
	if( cc != NULL ) {
		kill_connection(cc);
		if_notnull_free(cc->answer)
		free(cc);
	}
} /* -- end free_connection_context() -- */

/* Shuts down (both directions) and closes the context's socket, then marks
the context as dead by setting qsock to -1.  Does nothing if qsock is
already negative.  CC must not be NULL. */
void
kill_connection(ConnContext* cc) {
	int sock = cc->qsock;

	if( sock >= 0 ) {
		dreport("Killing socket %d (cc %p)", sock, cc);
		shutdown(sock, SHUT_RDWR);
		close(sock);
		cc->qsock = -1;
	}
} /* -- end kill_connection() -- */

/* Resets every field of CC to zero, with its pointers set to NULL and
qsock set to -1 ("no socket"). */
void
clear_connection_context(ConnContext* cc) {
	memset(cc, 0, sizeof *cc);
	/* All-bits-zero is not guaranteed to be a null pointer, so set these explicitly */
	cc->answer = NULL;
	cc->next = NULL;
	cc->qsock = -1;
} /* -- end clear_connection_context() -- */

/* Allocates a ConnContext and clears it with clear_connection_context().
Returns the new context, or NULL (after logging via choke()) if the
allocation fails.  The caller owns the result. */
ConnContext*
new_connection_context( void ) {
	ConnContext* fresh = malloc(sizeof *fresh);

	if( fresh == NULL ) {
		choke("Could not create new ConnContext");
		return NULL;
	}

	clear_connection_context(fresh);
	return fresh;
} /* -- end new_connection_context() -- */

/* Clears the list of dead sockets */
void
clean_cc_list( void ) {
	ConnContext* cc = cc_list;
	while( cc ) {
		if( cc->qsock == -1 )
			unref_connection_context(&cc);
		else
			cc = cc->next;
	}
	return;
} /* -- end clean_cc_list() -- */

/* Registers the accepted client socket FD in cc_list.  RHOST is the peer's
IPv4 address as passed by sockloop_perform() (already converted with
ntohl()).  The new context is pushed at the head of the list and
current_connections is incremented.
Returns 0 on success, or when FD is already registered (nothing is added).
Returns -1 if FD is negative, if FD is too large to be tracked by select(),
or if the context cannot be allocated. */
int
ref_socket(int fd, ulong rhost) {
	if( fd < 0 ) {
		carp("Invalid socket");
		return -1;
	}
	/* Every referenced socket is later passed to FD_SET(); a descriptor at
	or above FD_SETSIZE would be written outside the fd_set. */
	if( fd >= FD_SETSIZE ) {
		carp("Socket %d is too large for select()", fd);
		return -1;
	}

	ConnContext* existing = get_connection_context(fd);
	if( existing ) {
		dreport("duplicate socket ref %d @ %p", fd, existing);
		return 0;
	}

	ConnContext* cc = new_connection_context();
	if( ! cc )
		return -1;

	dreport("Adding socket reference %d (cc %p)", fd, cc);
	cc->qsock = fd;
	cc->rhost = rhost;
	cc->next = cc_list;
	cc_list = cc;

	++current_connections;
	return 0;
} /* -- end ref_socket() -- */

/* Looks up the context whose socket is FD by walking cc_list.
Returns the first matching context, or NULL if FD is negative (logged via
carp()), the list is empty, or no entry matches. */
ConnContext*
get_connection_context(int fd) {
	if( fd < 0 ) {
		carp("Invalid File Descriptor");
		return NULL;
	}
	if( cc_list == NULL )
		return NULL;

	dreport("Seeking ConnContext with fd %d", fd);
	ConnContext* node = cc_list;
	while( node != NULL && node->qsock != fd )
		node = node->next;
	return node;
} /* -- end get_connection_context() -- */

/* Unlinks *CC_P from cc_list, frees it (and its socket/answer) through
free_connection_context(), and decrements current_connections.
On success *CC_P is set to the entry that followed it, so callers can keep
iterating.  Returns 0 on success.  Returns -1 if *CC_P is NULL, or if the
entry is not in the list; in that case *CC_P is left as NULL. */
int
unref_connection_context(ConnContext** cc_p) {
	ConnContext* target = *cc_p;
	if( target == NULL ) {
		carp("Invalid ConnContext");
		return -1;
	}

	dreport("Unreffing ConnContext @ %p", target);
	*cc_p = NULL;

	/* link points at whichever pointer currently refers to the node under
	inspection: cc_list itself or the previous node's next field. */
	ConnContext** link = &cc_list;
	while( *link != NULL && *link != target )
		link = &(*link)->next;

	if( *link == NULL ) {
		carp("ConnContext (%p) not found!", target);
		return -1;
	}

	*cc_p = target->next;
	*link = target->next;
	free_connection_context(target);
	--current_connections;
	return 0;
} /* -- end unref_connection_context() -- */

/* Reads whatever is available on cc->qsock and appends it to cc->query.
It sets cc->finished_read once the query contains a CR or LF terminator.

The connection is killed (qsock becomes -1) on:
  - a socket error or EOF;
  - an embedded NUL byte;
  - an accumulated query longer than MSG_SIZE - 2 bytes;
  - any data following the line terminator.

If recv() is interrupted by a signal, the socket stays open and errno is
EINTR on return so the caller can detect it. */
void
read_query(ConnContext* cc) {
	if( ! (cc && (cc->qsock >= 0) ) ) {
		carp("Invalid ConnContext");
		return;
	}

	char buf[MSG_SIZE];
	memset(buf, 0, sizeof buf);
	/* Read at most sizeof buf - 1 bytes so buf is always NUL-terminated */
	ssize_t bytes = recv(cc->qsock, buf, sizeof buf - 1, 0);
	if( bytes < 0 ) {
		int saved_errno = errno;
		if( saved_errno == EINTR ) {
			/* Logging may change errno; restore it because the caller
			tests errno == EINTR after this returns. */
			dreport("recv() interrupted by a signal");
			errno = saved_errno;
			return;
		}
		if( saved_errno == EPIPE )
			dreport("EPIPE: connection teminated");
		else
			choke("Socket error");
		kill_connection(cc);
		return;
	}
	if( bytes == 0 ) {
		dreport("No bytes read");
		kill_connection(cc);
		return;
	}

	dreport("%d bytes read; total bytes read == %d",
			(int) bytes, (int) (bytes + cc->query_len));
	/* strlen() stops at the first NUL; a shorter length means the peer
	sent a '\0' inside the data. */
	if( strlen(buf) < (size_t) bytes ) {
		dreport("Emedded '\\0'");
		kill_connection(cc);
		return;
	}
	/* Keep the accumulated query (plus its NUL) inside the query buffer */
	if( (size_t) bytes + cc->query_len > MSG_SIZE - 2 ) {
		dreport("Hostile intent");
		kill_connection(cc);
		return;
	}
	strcat(cc->query, buf);
	cc->query_len = strlen(cc->query);

	/* Look for the end of the line (CR or LF) */
	size_t end = strcspn(cc->query, "\n\r");
	if( cc->query[end] == '\0' )
		return;		/* no terminator yet; wait for more data */

	/* Only further CR/LF characters may follow the first terminator */
	end += strspn(cc->query + end, "\r\n");
	if( cc->query[end] != '\0' ) {
		kill_connection(cc);
		return;
	}
	cc->finished_read = TRUE;
} /* -- end read_query() -- */

/* Writes the unsent part of cc->response (from written_len up to
response_len) to cc->qsock and advances cc->written_len.  The connection
is closed once the whole response has been sent.  It is also closed on a
socket error or when write() returns 0.  If write() is interrupted by a
signal, the socket stays open and errno is EINTR on return. */
void
write_response(ConnContext* cc) {
	if( ! (cc && (cc->qsock >= 0) ) ) {
		carp("Invalid ConnContext");
		return;
	}

	ssize_t bytes = write(cc->qsock, cc->response + cc->written_len,
			cc->response_len - cc->written_len);
	if( bytes < 0 ) {
		int saved_errno = errno;
		if( saved_errno == EINTR ) {
			/* Logging may change errno; restore it because the caller
			tests errno == EINTR after this returns. */
			dreport("write() interrupted by a signal");
			errno = saved_errno;
			return;
		}
		if( saved_errno == EPIPE )
			dreport("EPIPE: connection terminated");
		else
			choke("Socket Error");
		kill_connection(cc);
		return;
	}
	if( bytes == 0 ) {
		dreport("No bytes written");
		kill_connection(cc);
		return;
	}

	cc->written_len += bytes;
	dreport("%d bytes written; %d total bytes written",
			(int) bytes, (int) cc->written_len);

	if( cc->written_len == cc->response_len ) {
		dreport("Finished write");
		kill_connection(cc);
	}
} /* -- end write_response() -- */

/* Formats the ident reply for CC as "<lport> , <rport> : <answer>\r\n" and
copies it, without a trailing NUL, into cc->response.  The length is
recorded in cc->response_len.  ANSWER must not itself contain the CR LF
pair.  The ANSWER pointer is stored in cc->answer, which
free_connection_context() frees later.

A reply longer than MSG_SIZE - 2 bytes is cut to exactly MSG_SIZE - 2
bytes, and its last two bytes are rewritten to CR LF.

Returns 0 on success.  Returns -1 if CC is NULL, if ANSWER is rejected by
is_nil(), or if the reply cannot be allocated. */
int
set_response(ConnContext* cc, char* answer) {
	if( cc == NULL ) {
		carp("Invalid ConnContext");
		return -1;
	}
	if( is_nil(answer) ) {
		carp("Invalid answer");
		return -1;
	}

	dreport("Setting response");
	cc->answer = answer;

	char* reply = strprintf("%d , %d : %s\r\n", cc->lport, cc->rport, answer);
	if( reply == NULL ) {
		choke("Could not make room for response");
		return -1;
	}

	if( strlen(reply) > MSG_SIZE - 2 ) {
		/* Keep a CR LF terminator at the new end of the reply */
		char* cut = reply + (MSG_SIZE - 4);
		*cut++ = '\r';
		*cut++ = '\n';
		*cut = '\0';
		dreport("response truncated");
	}

	cc->response_len = strlen(reply);
	memcpy(&cc->response, reply, cc->response_len);
	free(reply);
	return 0;
} /* -- end set_response() -- */

/* Finds the uid owning the established TCP connection described by
(LPORT, RHOST, RPORT).  It scans /proc/net/tcp, which is the Linux IPv4
table only.  RHOST is compared directly with the value get_sockinfo()
parses; NOTE(review): both are presumably in host byte order -- confirm.

On a match, stores the uid in *OWNER and returns 0.  Returns ET_NOUSER if
no matching established entry exists.  Returns ET_UNKNOWN if the file
cannot be opened or read, or if get_sockinfo() fails on a line. */
c_error_t
get_sock_owner(uint lport, ulong rhost, uint rport, uid_t* owner) {
	/* State number for an established connection (TCP_ESTABLISHED in the
	Linux kernel's numbering), as compared against tcpi.state. */
	enum { PROC_TCP_ESTABLISHED = 1 };

	dreport("Seeking: rhost: %.8lx; rport: %u; lport: %u", rhost, rport, lport);

	FILE* fp = fopen("/proc/net/tcp", "r");
	if( ! fp ) {
		choke("Could not open /proc/net/tcp");
		return ET_UNKNOWN;
	}

	/* One getline() buffer is reused for every line and freed once below */
	char* line = NULL;
	size_t len = 0;
	struct procnettcp_info tcpi;
	c_error_t result = ET_NOUSER;

	/* The first line is a column header; read and discard it */
	if( getline(&line, &len, fp) < 0 ) {
		choke("Could not read /proc/net/tcp");
		result = ET_UNKNOWN;
		goto out;
	}

	while( getline(&line, &len, fp) > 0 ) {
		chomp(line);
		if( is_wspace_only(line) )
			continue;
		if( get_sockinfo(&tcpi, line) ) {
			carp("get_sockinfo() failed");
			result = ET_UNKNOWN;
			goto out;
		}

		if( tcpi.state != PROC_TCP_ESTABLISHED )
			continue;

		if( (tcpi.lport == lport) &&
		    (tcpi.rport == rport) &&
		    (tcpi.rhost == rhost) ) {
			*owner = tcpi.owner;
			result = 0;
			goto out;
		}
	}

	dreport("No user");

out:
	free(line);
	fclose(fp);
	return result;
} /* -- end get_sock_owner() -- */

/* Builds the ident response for a connection whose query has been fully
read.  It parses the port pair from cc->query, looks up the owner of the
matching TCP connection, and stores the formatted reply via
set_response().  If the owner lookup fails, the default answer for that
error is used instead.

Returns 0 on success, or if a response was already built for this context.
Returns -1 if:
  - the query is blank or has no valid port pair;
  - no answer string is produced;
  - set_response() fails. */
int
prepare_for_write(ConnContext* cc) {
	/* sockloop_prepare() calls this on every pass while finished_read is
	set, including after a partial write.  Rebuilding the response would
	overwrite cc->answer without freeing it and rescan /proc/net/tcp.  It
	could also change the text while written_len still indexes the old
	text. */
	if( cc->response_len != 0 )
		return 0;

	dreport("Preparing socket %d (cc %p) for write", cc->qsock, (void*) cc);
	/* decode the query */
	chomp(cc->query);
	if( is_wspace_only(cc->query) ) {
		dreport("Query was blank");
		return -1;
	}

	if( extract_port_pair(&(cc->lport), &(cc->rport), cc->query) )
		return -1;

	uid_t owner = (uid_t) -1;
	char* answer;
	c_error_t reterr = get_sock_owner(cc->lport, cc->rhost, cc->rport, &owner);
	if( reterr ) {
		/* owner was not filled in; use the error's default answer */
		dreport("Getting default answer");
		answer = get_default_answer_string(reterr);
	} else {
		answer = get_answer_string(owner);

		/* This duplicates some of get_answer_string(), but if the uid
		has no name the sysadmin should be told: it may indicate a
		security breach.  Only done when owner is valid.
		NOTE(review): if get_usernam_r() allocates *username, it is never
		freed here -- confirm its ownership contract. */
		char* username;
		if( get_usernam_r(&username, owner) )
			carp("Could not retrieve username for uid %d", owner);
		else if( ! username ) {
			carp("Warning! Socket owner has no username!");
			struct in_addr remote_host;
			remote_host.s_addr = htonl(cc->rhost);
			carp("uid: %d, local port: %d, remote-host: %s:%d",
					owner, cc->lport, inet_ntoa(remote_host), cc->rport);
		}
	}

	if( ! answer )
		return -1;

	/* NOTE(review): if set_response() rejects ANSWER via is_nil(), it is not
	stored in cc->answer, so a heap-allocated answer would leak -- confirm. */
	if( set_response(cc, answer) )
		return -1;

	return 0;
} /* -- end prepare_for_write() -- */

/******
 * The socket loop checks all referenced sockets
 * for data available for read or write.
 ******/

/* Rebuilds the select() descriptor sets for one loop iteration.
- Records the current time in last_clean on the first call.
- Drops dead contexts (before and after the scan).
- Sockets still reading go into rfds.  Sockets whose query is complete are
  passed to prepare_for_write(): they go into wfds on success and are
  killed on failure.
- The listening socket is watched only while current_connections is below
  server.max_queries.
Always returns 0. */
int
sockloop_prepare( void ) {
	if( last_clean.tv_sec == 0 )
		last_clean.tv_sec = time(NULL);

	FD_ZERO(&rfds);
	FD_ZERO(&wfds);
	FD_ZERO(&efds);
	max_fd = 0;

	clean_cc_list();
	for( ConnContext* node = cc_list; node != NULL; node = node->next ) {
		if( ! node->finished_read ) {
			FD_SET(node->qsock, &rfds);
		} else if( prepare_for_write(node) == 0 ) {
			FD_SET(node->qsock, &wfds);
		} else {
			kill_connection(node);
		}
		max_fd = max(max_fd, node->qsock);
	}
	clean_cc_list();

	if( current_connections < server.max_queries ) {
		FD_SET(server.auth_socket, &rfds);
		max_fd = max(server.auth_socket, max_fd);
	}

	return 0;
} /* -- end sockloop_prepare() -- */

/* Waits for a socket event on the sets filled in by sockloop_prepare().
The timeout is the time left until the next cache clean, which is due at
last_clean.tv_sec + server.cache_age; if that moment has passed, the call
does not block.  When server.cache_age is 0, the wait is unbounded.

Returns 0 if the timeout was reached, or the number of ready descriptors.
Returns -1 if select() failed.  If select() was interrupted by a signal,
the result is -1 with errno == EINTR, and nothing is logged. */
int
sockloop_select( void ) {
	struct timeval timeout = { 0, 0 };
	struct timeval* timeout_p = NULL;

	/* select() takes a relative interval, so convert the absolute
	cleaning deadline into the time remaining from now. */
	if( server.cache_age > 0 ) {
		time_t remaining = (time_t) (last_clean.tv_sec + server.cache_age) - time(NULL);
		timeout.tv_sec = (remaining > 0) ? remaining : 0;
		timeout_p = &timeout;
	}

	errno = 0;
	int retval = select(max_fd + 1, &rfds, &wfds, &efds, timeout_p);
	if( (retval < 0) && (errno != EINTR) ) {
		choke("select error");
		return -1;
	}
	return retval;
} /* -- end sockloop_select() -- */

/* Services the descriptors select() reported as ready.  It reads queries
on readable client sockets, writes responses on writable ones, and
accepts one new connection if the listening socket is readable.

Returns 0 on success and -1 on a fatal error.  If a read, write or
accept() was interrupted by a signal, returns 0 with errno == EINTR. */
int
sockloop_perform( void ) {
	if( cc_list ) {
		/* Two passes over cc_list: reads first, then writes */
		ConnContext* cc;

		dreport("Checking reads");
		for( cc = cc_list; cc; cc = cc->next ) {
			if( (cc->qsock >= 0) && FD_ISSET(cc->qsock, &rfds) ) {
				/* Reset errno so a stale value cannot look like EINTR */
				errno = 0;
				read_query(cc);
				if( errno == EINTR )
					return 0;
			}
		}

		dreport("Checking writes");
		for( cc = cc_list; cc; cc = cc->next ) {
			if( (cc->qsock >= 0) && FD_ISSET(cc->qsock, &wfds) ) {
				errno = 0;
				write_response(cc);
				if( errno == EINTR )
					return 0;
			}
		}
	} /* end checking cc_list */
	clean_cc_list();

	/* Finally we check for new connections */
	if( FD_ISSET(server.auth_socket, &rfds) ) {
		dreport("New connection reported");
		struct sockaddr_in saddr;
		socklen_t saddr_len = sizeof saddr;
		int newcon = accept(server.auth_socket, (struct sockaddr*) &saddr, &saddr_len);
		if( newcon < 0 ) {
			if( errno == EINTR )
				return 0;
			/* The client went away between select() and accept(); this
			is not a reason to stop serving. */
			if( errno == ECONNABORTED || errno == EAGAIN || errno == EWOULDBLOCK ) {
				dreport("Connection went away before accept");
				return 0;
			}
			choke("Error accepting new connections");
			return -1;
		}

		/* Place all OOB data into the normal stream.  SO_OOBINLINE is a
		socket-level option and must be set at SOL_SOCKET; at the TCP level
		the same number selects an unrelated TCP option. */
		int yes = TRUE;
		if( setsockopt(newcon, SOL_SOCKET, SO_OOBINLINE, &yes, sizeof yes) ) {
			choke("Could not set OOBINLINE socket option");
			close(newcon);
			return -1;
		}
		if( ref_socket(newcon, ntohl(saddr.sin_addr.s_addr)) ) {
			carp("Could not ref new connection");
			close(newcon);
			return -1;
		}
	}

	return 0;
} /* -- end sockloop_perform() -- */

/* Performs one iteration of the socket loop: prepare the sets, wait, then
either clean the cache (on timeout) or service the ready sockets.
Returns 0 on success, including when interrupted by a signal (EINTR).
Returns -1 on error. */
int
sockloop_iter( void ) {
	dreport("-");

	if( sockloop_prepare() < 0 )
		return -1;

	int retval = sockloop_select();
	if( retval < 0 ) {
		/* A signal interrupting select() is not an error */
		return (errno == EINTR) ? 0 : -1;
	}
	if( retval == 0 ) {
		clean_cache();
		/* Restart the cleaning interval used by sockloop_select().  This
		is redundant if clean_cache() already updates last_clean, which is
		not visible here. */
		last_clean.tv_sec = time(NULL);
		return 0;
	}

	if( sockloop_perform() < 0 )
		return -1;

	return 0;
} /* -- end sockloop_iter() -- */

/* Runs sockloop_iter() repeatedly and returns -1 as soon as an iteration
returns non-zero.  It never returns 0; the trailing return only satisfies
the compiler. */
int
sockloop_do( void ) {
	for( ;; ) {
		if( sockloop_iter() != 0 )
			return -1;
		/* Hook for signal-driven actions: if a handler ever needs to set a
		global flag to be acted on after errno == EINTR, check it here.
		All current signals are handled elsewhere. */
	}

	/* not reached */
	return 0;
} /* -- end sockloop_do() -- */


/* -- end networking.c -- */
