Browse code

- check for database table version

Jan Janak authored on 29/11/2003 00:35:51
Showing 3 changed files
... ...
@@ -47,6 +47,7 @@
47 47
 
48 48
 MODULE_VERSION
49 49
 
50
+#define TABLE_VERSION 3
50 51
 
51 52
 /*
52 53
  * Module destroy function prototype
... ...
@@ -77,18 +78,35 @@ post_auth_f post_auth_func = 0;
77 78
  */
78 79
 int (*sl_reply)(struct sip_msg* _msg, char* _str1, char* _str2);
79 80
 
81
+
82
+#define USER_COL "username"
83
+#define USER_COL_LEN (sizeof(USER_COL) - 1)
84
+
85
+#define DOMAIN_COL "domain"
86
+#define DOMAIN_COL_LEN (sizeof(DOMAIN_COL) - 1)
87
+
88
+#define RPID_COL "rpid"
89
+#define RPID_COL_LEN (sizeof(RPID_COL) - 1)
90
+
91
+#define PASS_COL "ha1"
92
+#define PASS_COL_LEN (sizeof(PASS_COL) - 1)
93
+
94
+#define PASS_COL_2 "ha1b"
95
+#define PASS_COL_2_LEN (sizeof(PASS_COL_2) - 1)
96
+
97
+
80 98
 /*
81 99
  * Module parameter variables
82 100
  */
83
-char* db_url           = DEFAULT_RODB_URL;
84
-char* user_column      = "username";
85
-char* domain_column    = "domain";
86
-char* rpid_column      = "rpid";
87
-char* pass_column      = "ha1";
88
-char* pass_column_2    = "ha1b";
89
-int   calc_ha1         = 0;
90
-int   use_domain       = 1;    /* Use also domain when looking up a table row */
91
-int   use_rpid         = 0;    /* Fetch Remote-Party-ID */
101
+str db_url           = {DEFAULT_RODB_URL, DEFAULT_RODB_URL_LEN};
102
+str user_column      = {USER_COL, USER_COL_LEN};
103
+str domain_column    = {DOMAIN_COL, DOMAIN_COL_LEN};
104
+str rpid_column      = {RPID_COL, RPID_COL_LEN};
105
+str pass_column      = {PASS_COL, PASS_COL_LEN};
106
+str pass_column_2    = {PASS_COL_2, PASS_COL_2_LEN};
107
+int calc_ha1         = 0;
108
+int use_domain       = 1;    /* Use also domain when looking up a table row */
109
+int use_rpid         = 0;    /* Fetch Remote-Party-ID */
92 110
 
93 111
 db_con_t* db_handle;   /* Database connection handle */
94 112
 
... ...
@@ -107,15 +125,15 @@ static cmd_export_t cmds[] = {
107 125
  * Exported parameters
108 126
  */
109 127
 static param_export_t params[] = {
110
-	{"db_url",            STR_PARAM, &db_url       },
111
-	{"user_column",       STR_PARAM, &user_column  },
112
-	{"domain_column",     STR_PARAM, &domain_column},
113
-	{"rpid_column",       STR_PARAM, &rpid_column  },
114
-	{"password_column",   STR_PARAM, &pass_column  },
115
-	{"password_column_2", STR_PARAM, &pass_column_2},
116
-	{"calculate_ha1",     INT_PARAM, &calc_ha1     },
117
-	{"use_domain",        INT_PARAM, &use_domain   },
118
-	{"use_rpid",          INT_PARAM, &use_rpid     },
128
+	{"db_url",            STR_PARAM, &db_url.s       },
129
+	{"user_column",       STR_PARAM, &user_column.s  },
130
+	{"domain_column",     STR_PARAM, &domain_column.s},
131
+	{"rpid_column",       STR_PARAM, &rpid_column.s  },
132
+	{"password_column",   STR_PARAM, &pass_column.s  },
133
+	{"password_column_2", STR_PARAM, &pass_column_2.s},
134
+	{"calculate_ha1",     INT_PARAM, &calc_ha1       },
135
+	{"use_domain",        INT_PARAM, &use_domain     },
136
+	{"use_rpid",          INT_PARAM, &use_rpid       },
119 137
 	{0, 0, 0}
120 138
 };
121 139
 
... ...
@@ -137,13 +155,15 @@ struct module_exports exports = {
137 155
 
138 156
 static int child_init(int rank)
139 157
 {
140
-	db_handle = db_init(db_url);
158
+	     /* Close connection opened in mod_init */
159
+	db_close(db_handle);
160
+	db_handle = db_init(db_url.s);
141 161
 	if (!db_handle) {
142 162
 		LOG(L_ERR, "auth_db:init_child(): Unable to connect database\n");
143 163
 		return -1;
144 164
 	}
145
-	return 0;
146 165
 
166
+	return 0;
147 167
 }
148 168
 
149 169
 
... ...
@@ -151,23 +171,41 @@ static int mod_init(void)
151 171
 {
152 172
 	DBG("auth_db module - initializing\n");
153 173
 	
174
+	db_url.len = strlen(db_url.s);
175
+	user_column.len = strlen(user_column.s);
176
+	domain_column.len = strlen(domain_column.s);
177
+	rpid_column.len = strlen(rpid_column.s);
178
+	pass_column.len = strlen(pass_column.s);
179
+	pass_column_2.len = strlen(pass_column.s);
180
+
154 181
 	     /* Find a database module */
155
-	if (bind_dbmod(db_url)) {
156
-		LOG(L_ERR, "mod_init(): Unable to bind database module\n");
182
+	if (bind_dbmod(db_url.s) < 0) {
183
+		LOG(L_ERR, "auth_db:mod_init(): Unable to bind database module\n");
157 184
 		return -1;
158 185
 	}
159 186
 
187
+	     /* Open database connection in parent */
188
+	db_handle = db_init(db_url.s);
189
+	if (!db_handle) {
190
+		LOG(L_ERR, "auth_db:mod_init(): Error while connecting database\n");
191
+		return -1;
192
+	} else {
193
+		LOG(L_INFO, "auth_db:mod_init(): Database connection opened successfuly\n");
194
+	}
195
+
160 196
 	pre_auth_func = (pre_auth_f)find_export("pre_auth", 0, 0);
161 197
 	post_auth_func = (post_auth_f)find_export("post_auth", 0, 0);
162 198
 
163 199
 	if (!(pre_auth_func && post_auth_func)) {
164 200
 		LOG(L_ERR, "auth_db:mod_init(): This module requires auth module\n");
201
+		db_close(db_handle);
165 202
 		return -2;
166 203
 	}
167 204
 
168 205
 	sl_reply = find_export("sl_send_reply", 2, 0);
169 206
 	if (!sl_reply) {
170 207
 		LOG(L_ERR, "auth_db:mod_init(): This module requires sl module\n");
208
+		db_close(db_handle);
171 209
 		return -2;
172 210
 	}
173 211
 
... ...
@@ -188,6 +226,8 @@ static void destroy(void)
188 226
 static int str_fixup(void** param, int param_no)
189 227
 {
190 228
 	str* s;
229
+	int ver;
230
+	str name;
191 231
 
192 232
 	if (param_no == 1) {
193 233
 		s = (str*)pkg_malloc(sizeof(str));
... ...
@@ -199,6 +239,19 @@ static int str_fixup(void** param, int param_no)
199 239
 		s->s = (char*)*param;
200 240
 		s->len = strlen(s->s);
201 241
 		*param = (void*)s;
242
+	} else if (param_no == 2) {
243
+		name.s = (char*)*param;
244
+		name.len = strlen(name.s);
245
+
246
+		ver = table_version(db_handle, &name);
247
+
248
+		if (ver < 0) {
249
+			LOG(L_ERR, "auth_db:str_fixup(): Error while querying table version\n");
250
+			return -1;
251
+		} else if (ver < TABLE_VERSION) {
252
+			LOG(L_ERR, "auth_db:str_fixup(): Invalid table version (use ser_mysql.sh reinstall)\n");
253
+			return -1;
254
+		}
202 255
 	}
203 256
 
204 257
 	return 0;
... ...
@@ -31,6 +31,7 @@
31 31
 #ifndef AUTHDB_MOD_H
32 32
 #define AUTHDB_MOD_H
33 33
 
34
+#include "../../str.h"
34 35
 #include "../../db/db.h"
35 36
 #include "../auth/api.h"
36 37
 #include "../../parser/msg_parser.h"
... ...
@@ -40,14 +41,14 @@
40 41
  * Module parameters variables
41 42
  */
42 43
 
43
-extern char* db_url;          /* Database URL */
44
-extern char* user_column;     /* 'username' column name */
45
-extern char* domain_column;   /* 'domain' column name */
46
-extern char* rpid_column;     /* 'rpid' column name */
47
-extern char* pass_column;     /* 'password' column name */
48
-extern char* pass_column_2;   /* Column containg HA1 string constructed
49
-			       * of user@domain username
50
-			       */
44
+extern str db_url;          /* Database URL */
45
+extern str user_column;     /* 'username' column name */
46
+extern str domain_column;   /* 'domain' column name */
47
+extern str rpid_column;     /* 'rpid' column name */
48
+extern str pass_column;     /* 'password' column name */
49
+extern str pass_column_2;   /* Column containg HA1 string constructed
50
+			     * of user@domain username
51
+			     */
51 52
 
52 53
 extern int calc_ha1;          /* if set to 1, ha1 is calculated by the server */
53 54
 extern int use_domain;        /* If set to 1 then the domain will be used when selecting a row */
... ...
@@ -60,10 +60,10 @@ static inline int get_ha1(struct username* _username, str* _domain, char* _table
60 60
 	str result;
61 61
 	int n, nc;
62 62
 
63
-	keys[0] = user_column;
64
-	keys[1] = domain_column;
65
-	col[0] = (_username->domain.len && !calc_ha1) ? (pass_column_2) : (pass_column);	
66
-	col[1] = rpid_column;
63
+	keys[0] = user_column.s;
64
+	keys[1] = domain_column.s;
65
+	col[0] = (_username->domain.len && !calc_ha1) ? (pass_column_2.s) : (pass_column.s);	
66
+	col[1] = rpid_column.s;
67 67
 
68 68
 	VAL_TYPE(vals) = VAL_TYPE(vals + 1) = DB_STR;
69 69
 	VAL_NULL(vals) = VAL_NULL(vals + 1) = 0;