/*
 * $Id$
 *
 * allow_trusted related functions
 *
 * Copyright (C) 2003 Juha Heinanen
 *
 * This file is part of Kamailio, a free SIP server.
 *
 * Kamailio is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version
 *
 * Kamailio is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License 
 * along with this program; if not, write to the Free Software 
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 * History:
 * --------
 *  2004-06-07  updated to the new DB api, moved reload_trusted_table (andrei)
 */

#include <sys/types.h>
#include <regex.h>
#include <string.h>
#include <strings.h>

#include "permissions.h"
#include "hash.h"
#include "../../config.h"
#include "../../lib/srdb1/db.h"
#include "../../ip_addr.h"
#include "../../mem/shm_mem.h"
#include "../../parser/msg_parser.h"
#include "../../parser/parse_from.h"
#include "../../usr_avp.h"

/* Version passed to db_check_table_version() when checking the trusted table */
#define TABLE_VERSION 5

/*
 * Two hash tables are kept; *hash_table designates the one currently in use
 * and reload_trusted_table() fills the other one before switching over.
 */
struct trusted_list ***hash_table;     /* Pointer to current hash table pointer */
struct trusted_list **hash_table_1;   /* Pointer to hash table 1 */
struct trusted_list **hash_table_2;   /* Pointer to hash table 2 */


/* Database connection handle of this file; 0 when no connection is open */
static db1_con_t* db_handle = 0;
/* Database function table, filled in by db_bind_mod() in init_trusted() */
static db_func_t perm_dbf;


/*
 * Reload trusted table to new hash table and when done, make new hash table
 * current one.
 *
 * All rows of the trusted table (source, proto, from pattern, tag columns)
 * are inserted into the hash table that is not currently in use.  Only after
 * every row has been inserted is *hash_table switched to the new table, so
 * a failure in the middle leaves the current table untouched.
 *
 * Returns 1 on success and -1 on failure (no database connection, hash
 * tables not allocated, query failure, unexpected row layout, or insert
 * failure).
 */
int reload_trusted_table(void)
{
	db_key_t cols[4];
	db1_res_t* res = NULL;
	db_row_t* row;
	db_val_t* val;

	struct trusted_list **new_hash_table;
	int i;

	char *pattern, *tag;

	cols[0] = &source_col;
	cols[1] = &proto_col;
	cols[2] = &from_col;
	cols[3] = &tag_col;

	if (db_handle == 0) {
		LM_ERR("no connection to database\n");
		return -1;
	}

	/* The hash tables are only allocated by init_trusted() in cache mode;
	 * without them there is nothing to reload into. */
	if (hash_table == 0) {
		LM_ERR("trusted hash table is not initialized\n");
		return -1;
	}

	if (perm_dbf.use_table(db_handle, &trusted_table) < 0) {
		LM_ERR("failed to use trusted table\n");
		return -1;
	}

	if (perm_dbf.query(db_handle, NULL, 0, NULL, cols, 0, 4, 0, &res) < 0) {
		LM_ERR("failed to query database\n");
		return -1;
	}

	/* Choose new hash table and free its old contents */
	if (*hash_table == hash_table_1) {
		empty_hash_table(hash_table_2);
		new_hash_table = hash_table_2;
	} else {
		empty_hash_table(hash_table_1);
		new_hash_table = hash_table_1;
	}

	row = RES_ROWS(res);

	LM_DBG("number of rows in trusted table: %d\n", RES_ROW_N(res));

	for (i = 0; i < RES_ROW_N(res); i++) {
		val = ROW_VALUES(row + i);
		/* source and proto must be non-NULL strings; from pattern and
		 * tag may be NULL, otherwise they must be strings too */
		if ((ROW_N(row + i) == 4) &&
		    (VAL_TYPE(val) == DB1_STRING) && !VAL_NULL(val) &&
		    (VAL_TYPE(val + 1) == DB1_STRING) && !VAL_NULL(val + 1) &&
		    (VAL_NULL(val + 2) ||
		     ((VAL_TYPE(val + 2) == DB1_STRING) && !VAL_NULL(val + 2))) &&
		    (VAL_NULL(val + 3) ||
		     ((VAL_TYPE(val + 3) == DB1_STRING) && !VAL_NULL(val + 3)))) {
			if (VAL_NULL(val + 2)) {
				pattern = 0;
			} else {
				pattern = (char *)VAL_STRING(val + 2);
			}
			if (VAL_NULL(val + 3)) {
				tag = 0;
			} else {
				tag = (char *)VAL_STRING(val + 3);
			}
			if (hash_table_insert(new_hash_table,
					      (char *)VAL_STRING(val),
					      (char *)VAL_STRING(val + 1),
					      pattern, tag) == -1) {
				LM_ERR("hash table problem\n");
				perm_dbf.free_result(db_handle, res);
				return -1;
			}
			/* pattern and tag may be null pointers; never hand a
			 * null pointer to a %s conversion */
			LM_DBG("tuple <%s, %s, %s, %s> inserted into trusted hash "
			    "table\n", VAL_STRING(val), VAL_STRING(val + 1),
			    pattern ? pattern : "NULL", tag ? tag : "NULL");
		} else {
			LM_ERR("database problem\n");
			perm_dbf.free_result(db_handle, res);
			return -1;
		}
	}

	perm_dbf.free_result(db_handle, res);

	*hash_table = new_hash_table;

	LM_DBG("trusted table reloaded successfully.\n");

	return 1;
}


/*
 * Initialize data structures
 *
 * Binds the database API when db_url is set.  In ENABLE_CACHE mode it also
 * opens a temporary connection, checks the trusted table version, allocates
 * both hash tables plus the "current table" pointer, loads the table once
 * and closes the connection again.
 *
 * Returns 0 on success (also when db_url is not set and allow_trusted is
 * thereby disabled) and -1 on failure.  On failure everything allocated or
 * opened here is released and the corresponding globals are reset to 0.
 */
int init_trusted(void)
{
	/* Check if hash table needs to be loaded from trusted table */
	if (!db_url.s) {
		LM_INFO("db_url parameter of permissions module not set, "
			"disabling allow_trusted\n");
		return 0;
	} else {
		if (db_bind_mod(&db_url, &perm_dbf) < 0) {
			LM_ERR("load a database support module\n");
			return -1;
		}

		if (!DB_CAPABILITY(perm_dbf, DB_CAP_QUERY)) {
			LM_ERR("database module does not implement 'query' function\n");
			return -1;
		}
	}

	hash_table_1 = hash_table_2 = 0;
	hash_table = 0;

	if (db_mode == ENABLE_CACHE) {
		db_handle = perm_dbf.init(&db_url);
		if (!db_handle) {
			LM_ERR("unable to connect database\n");
			return -1;
		}

		/* From here on the connection is open: every failure must go
		 * through the error label so that it is closed and reset. */
		if(db_check_table_version(&perm_dbf, db_handle, &trusted_table, TABLE_VERSION) < 0) {
			LM_ERR("error during table version check.\n");
			goto error;
		}

		hash_table_1 = new_hash_table();
		if (!hash_table_1) goto error;

		hash_table_2  = new_hash_table();
		if (!hash_table_2) goto error;

		hash_table = (struct trusted_list ***)shm_malloc
			(sizeof(struct trusted_list **));
		if (!hash_table) goto error;

		*hash_table = hash_table_1;

		if (reload_trusted_table() == -1) {
			LM_CRIT("reload of trusted table failed\n");
			goto error;
		}

		/* The connection was only needed for the initial load */
		perm_dbf.close(db_handle);
		db_handle = 0;
	}
	return 0;

error:
	/* Only reached with an open connection (see above) */
	if (hash_table_1) {
		free_hash_table(hash_table_1);
		hash_table_1 = 0;
	}
	if (hash_table_2) {
		free_hash_table(hash_table_2);
		hash_table_2 = 0;
	}
	if (hash_table) {
		shm_free(hash_table);
		hash_table = 0;
	}
	perm_dbf.close(db_handle);
	db_handle = 0;
	return -1;
}


/*
 * Open database connections if necessary
 *
 * Does nothing for processes with rank <= 0 other than PROC_RPC and
 * PROC_UNIXSOCK, and nothing when db_url is not set.  Otherwise opens the
 * connection stored in db_handle and verifies the trusted table version.
 *
 * Returns 0 on success (or when nothing had to be done) and -1 on failure;
 * on failure db_handle is left as 0.
 */
int init_child_trusted(int rank)
{
	if ((rank <= 0) && (rank != PROC_RPC) && (rank != PROC_UNIXSOCK))
		return 0;

	if (!db_url.s) {
		return 0;
	}

	db_handle = perm_dbf.init(&db_url);
	if (!db_handle) {
		LM_ERR("unable to connect database\n");
		return -1;
	}

	if (db_check_table_version(&perm_dbf, db_handle, &trusted_table,
				   TABLE_VERSION) < 0) {
		LM_ERR("error during table version check.\n");
		perm_dbf.close(db_handle);
		/* do not leave a closed handle behind: other functions treat
		 * a non-zero db_handle as a usable connection */
		db_handle = 0;
		return -1;
	}

	return 0;
}


/*
 * Open database connection if necessary
 *
 * Nothing is done when db_url is not set or when a connection is already
 * open.  Returns 0 on success and -1 if the connection cannot be opened.
 */
int mi_init_trusted(void)
{
	if (!db_url.s) {
		return 0;
	}
	if (db_handle) {
		return 0;
	}

	db_handle = perm_dbf.init(&db_url);
	if (db_handle) {
		return 0;
	}

	LM_ERR("unable to connect database\n");
	return -1;
}


/*
 * Release memory of both hash tables and of the current-table pointer.
 *
 * The globals are reset to 0 after freeing, as in the error path of
 * init_trusted(), so that no dangling pointers remain and a repeated call
 * is harmless.  The database connection is not touched here.
 */
void clean_trusted(void)
{
	if (hash_table_1) {
		free_hash_table(hash_table_1);
		hash_table_1 = 0;
	}
	if (hash_table_2) {
		free_hash_table(hash_table_2);
		hash_table_2 = 0;
	}
	if (hash_table) {
		shm_free(hash_table);
		hash_table = 0;
	}
}


/*
 * Matches protocol string against the protocol of the request.  Returns 1 on
 * success and 0 on failure.
 *
 * The string "any" matches every protocol.  Otherwise the string has to be
 * the (case insensitive) name of the request protocol.  A request protocol
 * other than UDP, TCP, TLS or SCTP is logged as an error and never matches
 * (unless the string is "any").
 */
static inline int match_proto(const char *proto_string, int proto_int)
{
	const char *wanted;

	if (strcasecmp(proto_string, "any") == 0) return 1;

	/* pick the name that corresponds to the request protocol */
	switch (proto_int) {
	case PROTO_UDP:
		wanted = "udp";
		break;
	case PROTO_TCP:
		wanted = "tcp";
		break;
	case PROTO_TLS:
		wanted = "tls";
		break;
	case PROTO_SCTP:
		wanted = "sctp";
		break;
	default:
		LM_ERR("unknown request protocol\n");
		return 0;
	}

	return (strcasecmp(proto_string, wanted) == 0) ? 1 : 0;
}

/*
 * Matches from uri against patterns returned from database.  Returns 1 when
 * first pattern matches and -1 if none of the patterns match or an error
 * occurs (From header cannot be parsed, From URI too long, tag AVP cannot
 * be added).
 *
 * Each row is expected to hold three columns: proto, from pattern and tag.
 * A row whose pattern is NULL matches any From URI.  When a row matches,
 * the tag AVP is configured and the row's tag is not NULL, the tag is added
 * as a string AVP.
 */
static int match_res(struct sip_msg* msg, int proto, db1_res_t* _r)
{
	char from_uri[MAX_URI_SIZE+1];
	str uri;
	db_row_t *rows;
	db_val_t *cols = 0;
	regex_t re;
	int_str tag_avp, avp_val;
	int tag_avp_type;
	int n, rc, hit;

	if (parse_from_header(msg) < 0) return -1;

	/* regexec() needs a zero terminated copy of the From URI */
	uri = get_from(msg)->uri;
	if (uri.len > MAX_URI_SIZE) {
		LM_ERR("message has From URI too large\n");
		return -1;
	}
	memcpy(from_uri, uri.s, uri.len);
	from_uri[uri.len] = (char)0;

	rows = RES_ROWS(_r);
	hit = 0;

	for (n = 0; n < RES_ROW_N(_r); n++) {
		cols = ROW_VALUES(rows + n);

		/* skip rows of unexpected shape or of another protocol */
		if (ROW_N(rows + n) != 3) continue;
		if ((VAL_TYPE(cols) != DB1_STRING) || VAL_NULL(cols)) continue;
		if (!match_proto(VAL_STRING(cols), proto)) continue;
		if (!VAL_NULL(cols + 1) && (VAL_TYPE(cols + 1) != DB1_STRING))
			continue;
		if (!VAL_NULL(cols + 2) && (VAL_TYPE(cols + 2) != DB1_STRING))
			continue;

		/* no pattern at all: every From URI is accepted */
		if (VAL_NULL(cols + 1)) {
			hit = 1;
			break;
		}

		if (regcomp(&re, (char *)VAL_STRING(cols + 1), REG_NOSUB)) {
			LM_ERR("invalid regular expression\n");
			continue;
		}
		rc = regexec(&re, from_uri, 0, (regmatch_t *)0, 0);
		regfree(&re);
		if (rc == 0) {
			hit = 1;
			break;
		}
	}

	if (!hit) return -1;

	/* cols still refers to the matching row */
	get_tag_avp(&tag_avp, &tag_avp_type);
	if (!tag_avp.n || VAL_NULL(cols + 2)) return 1;

	avp_val.s.s = (char *)VAL_STRING(cols + 2);
	avp_val.s.len = strlen(avp_val.s.s);
	if (add_avp(tag_avp_type|AVP_VAL_STR, tag_avp, avp_val) != 0) {
		LM_ERR("failed to set of tag_avp failed\n");
		return -1;
	}
	return 1;
}


/*
 * Checks based on given source IP address and protocol, and From URI
 * of request if request can be trusted without authentication.
 *
 * In DISABLE_CACHE mode the trusted table is queried for rows whose source
 * column equals src_ip and the rows are matched with match_res().  In any
 * other mode the lookup is done in the current in-memory hash table.
 *
 * Returns the result of the matching function (1 on match), or -1 when
 * nothing matches or on error (no database connection, query failure,
 * hash table not initialized).
 */
int allow_trusted(struct sip_msg* msg, char *src_ip, int proto) 
{
	int result;
	db1_res_t* res = NULL;
	
	db_key_t keys[1];
	db_val_t vals[1];
	db_key_t cols[3];

	if (db_mode == DISABLE_CACHE) {
	
		if (db_handle == 0) {
			LM_ERR("no connection to database\n");
			return -1;
		}

		keys[0] = &source_col;
		cols[0] = &proto_col;
		cols[1] = &from_col;
		cols[2] = &tag_col;

		if (perm_dbf.use_table(db_handle, &trusted_table) < 0) {
			LM_ERR("failed to use trusted table\n");
			return -1;
		}
		
		VAL_TYPE(vals) = DB1_STRING;
		VAL_NULL(vals) = 0;
		VAL_STRING(vals) = src_ip;

		if (perm_dbf.query(db_handle, keys, 0, vals, cols, 1, 3, 0,
				   &res) < 0){
			LM_ERR("failed to query database\n");
			return -1;
		}

		if (RES_ROW_N(res) == 0) {
			perm_dbf.free_result(db_handle, res);
			return -1;
		}
		
		result = match_res(msg, proto, res);
		perm_dbf.free_result(db_handle, res);
		return result;
	} else {
		/* init_trusted() leaves hash_table as 0 when db_url is not
		 * set (or when it fails); do not dereference it then */
		if (hash_table == 0) {
			LM_ERR("trusted hash table is not initialized\n");
			return -1;
		}
		return match_hash_table(*hash_table, msg, src_ip, proto);
	}
}


/*
 * Checks based on request's source address, protocol, and From URI
 * if request can be trusted without authentication.
 *
 * str1 and str2 are not used.  The source address of the received request
 * is converted to a string and handed, together with the receive protocol,
 * to allow_trusted(), whose result is returned.
 */
int allow_trusted_0(struct sip_msg* _msg, char* str1, char* str2) 
{
	char *src_ip_str;

	src_ip_str = ip_addr2a(&(_msg->rcv.src_ip));

	return allow_trusted(_msg, src_ip_str, _msg->rcv.proto);
}


/*
 * Checks based on source address and protocol given in pvar arguments and
 * request's From URI, if request can be trusted without authentication.
 *
 * _src_ip_sp and _proto_sp are pv_spec_t pointers passed as char pointers.
 * Both pvars must yield string values.  The protocol value is compared
 * case insensitively against "udp", "tcp", "tls" and "sctp" using its
 * length, so the pvar value is neither modified nor required to be zero
 * terminated.
 *
 * Returns the result of allow_trusted(), or -1 if a pvar has no string
 * value or the protocol is unknown.
 */
int allow_trusted_2(struct sip_msg* _msg, char* _src_ip_sp, char* _proto_sp) 
{
	pv_spec_t *src_ip_sp, *proto_sp;
	pv_value_t pv_val;
	char *src_ip;
	int proto_int;

	src_ip_sp = (pv_spec_t *)_src_ip_sp;
	proto_sp = (pv_spec_t *)_proto_sp;

	if (src_ip_sp && (pv_get_spec_value(_msg, src_ip_sp, &pv_val) == 0)) {
		if (pv_val.flags & PV_VAL_STR) {
			/* NOTE(review): src_ip is later used as a C string by
			 * allow_trusted(); this assumes the pvar value is zero
			 * terminated - confirm for the pvars used here */
			src_ip = pv_val.rs.s;
		} else {
			LM_ERR("src_ip pvar value is not string\n");
			return -1;
		}
	} else {
		LM_ERR("src_ip pvar does not exist or has no value\n");
		return -1;
	}

	if (proto_sp && (pv_get_spec_value(_msg, proto_sp, &pv_val) == 0)) {
		if (!(pv_val.flags & PV_VAL_STR)) {
			LM_ERR("proto pvar value is not string\n");
			return -1;
		}
	} else {
		LM_ERR("proto pvar does not exist or has no value\n");
		return -1;
	}

	/* Compare by length + case insensitively instead of lower-casing the
	 * pvar's own buffer in place and relying on zero termination */
	if ((pv_val.rs.len == 3) &&
	    (strncasecmp(pv_val.rs.s, "udp", 3) == 0)) {
		proto_int = PROTO_UDP;
	} else if ((pv_val.rs.len == 3) &&
		   (strncasecmp(pv_val.rs.s, "tcp", 3) == 0)) {
		proto_int = PROTO_TCP;
	} else if ((pv_val.rs.len == 3) &&
		   (strncasecmp(pv_val.rs.s, "tls", 3) == 0)) {
		proto_int = PROTO_TLS;
	} else if ((pv_val.rs.len == 4) &&
		   (strncasecmp(pv_val.rs.s, "sctp", 4) == 0)) {
		proto_int = PROTO_SCTP;
	} else {
		LM_ERR("unknown protocol %.*s\n",
		       (pv_val.rs.len > 0) ? pv_val.rs.len : 0,
		       pv_val.rs.s ? pv_val.rs.s : "");
		return -1;
	}

	return allow_trusted(_msg, src_ip, proto_int);
}