added stats module for neonserv.krypton-bouncer.de stats
[NeonServV5.git] / src / mysqlConn.c
1 /* mysqlConn.c - NeonServ v5.4
2  * Copyright (C) 2011-2012  Philipp Kreil (pk910)
3  * 
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  * 
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  * 
14  * You should have received a copy of the GNU General Public License 
15  * along with this program. If not, see <http://www.gnu.org/licenses/>. 
16  */
17
18 #include "mysqlConn.h"
19 #define DATABASE_VERSION "16"
20
21 static void show_mysql_error();
22
23 struct mysql_conn_struct {
24     unsigned int tid;
25     MYSQL *mysql_conn;
26     struct used_result *used_results;
27     struct escaped_string *escaped_strings;
28     struct mysql_conn_struct *next;
29 };
30
31 struct used_result {
32     MYSQL_RES *result;
33     struct used_result *next;
34 };
35
36 struct escaped_string {
37     char *string;
38     struct escaped_string *next;
39 };
40
41 struct mysql_conn_struct *get_mysql_conn_struct();
42
43 struct mysql_conn_struct *mysql_conns = NULL;
44 static int mysql_serverport;
45 static char *mysql_host, *mysql_user, *mysql_pass, *mysql_base;
46
47 #ifdef HAVE_THREADS
48 static pthread_mutex_t synchronized;
49 #endif
50
51 static void check_mysql() {
52     MYSQL *mysql_conn = get_mysql_conn();
53     int errid;
54     if((errid = mysql_ping(mysql_conn))) {
55         if(mysql_errno(mysql_conn) == CR_SERVER_GONE_ERROR) {
56             if(!mysql_real_connect(mysql_conn, mysql_host, mysql_user, mysql_pass, mysql_base, mysql_serverport, NULL, 0)) {
57                 show_mysql_error();
58             }
59         } else {
60             //mysql error
61             show_mysql_error();
62         }
63     }
64 }
65
66 MYSQL_RES *mysql_use() {
67     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
68     MYSQL_RES *res = mysql_store_result(mysql_conn->mysql_conn);
69     struct used_result *result = malloc(sizeof(*result));
70     if (!result) {
71         mysql_free_result(res);
72         return NULL;
73     }
74     result->result = res;
75     result->next = mysql_conn->used_results;
76     mysql_conn->used_results = result;
77     return res;
78 }
79
80 void mysql_free() {
81     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
82     if(!mysql_conn) return;
83     struct used_result *result, *next_result;
84     for(result = mysql_conn->used_results; result; result = next_result) {
85         next_result = result->next;
86         mysql_free_result(result->result);
87         free(result);
88     }
89     mysql_conn->used_results = NULL;
90     struct escaped_string *escaped, *next_escaped;
91     for(escaped = mysql_conn->escaped_strings; escaped; escaped = next_escaped) {
92         next_escaped = escaped->next;
93         free(escaped->string);
94         free(escaped);
95     }
96     mysql_conn->escaped_strings = NULL;
97 }
98
99 void init_mysql(char *host, int port, char *user, char *pass, char *base) {
100     THREAD_MUTEX_INIT(synchronized);
101     mysql_host = strdup(host);
102     mysql_serverport = port;
103     mysql_user = strdup(user);
104     mysql_pass = strdup(pass);
105     mysql_base = strdup(base);
106     
107     
108     MYSQL *mysql_conn = get_mysql_conn();
109     
110     //check database version...
111     int version = 0;
112     if(!mysql_query(mysql_conn, "SELECT `database_version` FROM `version`")) {
113         MYSQL_RES *res = mysql_use();
114         MYSQL_ROW row;
115         if((row = mysql_fetch_row(res))) {
116             version = atoi(row[0]);
117         }
118     }
119     if(!version) {
120         //CREATE DATABASE
121         FILE *f = fopen("database.sql", "r");
122         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_ON);
123         if (f) {
124             char line[512];
125             char query_buffer[8192];
126             int query_buffer_pos = 0;
127             while (fgets(line, sizeof(line), f)) {
128                 query_buffer_pos += sprintf(query_buffer + query_buffer_pos, " %s", line);
129                 if(line[(strlen(line) - 2)] == ';') {
130                     if(mysql_query(mysql_conn, query_buffer))
131                         show_mysql_error();
132                     query_buffer_pos = 0;
133                 }
134             }
135             fclose(f);
136         }
137         f = fopen("database.defaults.sql", "r");
138         if (f) {
139             char line[4096];
140             char query_buffer[131072];
141             int query_buffer_pos = 0;
142             while (fgets(line, sizeof(line), f)) {
143                 query_buffer_pos += sprintf(query_buffer + query_buffer_pos, " %s", line);
144                 if(line[(strlen(line) - 2)] == ';') {
145                     if(mysql_query(mysql_conn, query_buffer))
146                         show_mysql_error();
147                     query_buffer_pos = 0;
148                 }
149             }
150             fclose(f);
151         }
152         do { 
153             MYSQL_RES *res = mysql_store_result(mysql_conn); 
154             mysql_free_result(res); 
155         } while(!mysql_next_result(mysql_conn));
156         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_OFF);
157         mysql_query(mysql_conn, "INSERT INTO `version` (`database_version`) VALUES ('" DATABASE_VERSION "')");
158     }
159     else if(version < atoi(DATABASE_VERSION)) {
160         //UPDATE DATABASE
161         FILE *f = fopen("database.upgrade.sql", "r");
162         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_ON);
163         if (f) {
164             char line[512];
165             char query_buffer[8192];
166             int query_buffer_pos = 0, use_querys = 0;
167             sprintf(query_buffer, "-- version: %d", version);
168             while (fgets(line, sizeof(line), f)) {
169                 if(use_querys) {
170                     query_buffer_pos += sprintf(query_buffer + query_buffer_pos, " %s", line);
171                     if(line[strlen(line) - 1] == ';') {
172                         mysql_query(mysql_conn, query_buffer);
173                         query_buffer_pos = 0;
174                     }
175                 } else if(!stricmplen(query_buffer, line, strlen(query_buffer))) {
176                     use_querys = 1;
177                 }
178             }
179             if(query_buffer_pos) {
180                 if(mysql_query(mysql_conn, query_buffer))
181                     show_mysql_error();
182             }
183             fclose(f);
184         } else
185             perror("database.sql missing!");
186         do { 
187             MYSQL_RES *res = mysql_store_result(mysql_conn); 
188             mysql_free_result(res); 
189         } while(!mysql_next_result(mysql_conn));
190         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_OFF);
191         mysql_query(mysql_conn, "UPDATE `version` SET `database_version` = '" DATABASE_VERSION "'");
192     }
193 }
194
195 void free_mysql() {
196     struct mysql_conn_struct *mysql_conn, *next;
197     for(mysql_conn = mysql_conns; mysql_conn; mysql_conn = next) {
198         next = mysql_conn->next;
199         mysql_close(mysql_conn->mysql_conn);
200         free(mysql_conn);
201     }
202     mysql_conns = NULL;
203 }
204
205 static void show_mysql_error() {
206     MYSQL *mysql_conn = get_mysql_conn();
207     //show mysql_error()
208     putlog(LOGLEVEL_ERROR, "MySQL Error: %s\n", mysql_error(mysql_conn));
209 }
210
211 void printf_mysql_query(const char *text, ...) {
212     MYSQL *mysql_conn = get_mysql_conn();
213     va_list arg_list;
214     char queryBuf[MYSQLMAXLEN];
215     int pos;
216     queryBuf[0] = '\0';
217     va_start(arg_list, text);
218     pos = vsnprintf(queryBuf, MYSQLMAXLEN - 2, text, arg_list);
219     va_end(arg_list);
220     if (pos < 0 || pos > (MYSQLMAXLEN - 2)) pos = MYSQLMAXLEN - 2;
221     queryBuf[pos] = '\0';
222     putlog(LOGLEVEL_MYSQL, "MySQL: %s\n", queryBuf);
223     if(mysql_query(mysql_conn, queryBuf)) {
224         check_mysql();
225         if(mysql_query(mysql_conn, queryBuf)) {
226             show_mysql_error();
227         }
228     }
229 }
230
231 void printf_long_mysql_query(int len, const char *text, ...) {
232     MYSQL *mysql_conn = get_mysql_conn();
233     va_list arg_list;
234     char queryBuf[len];
235     int pos;
236     queryBuf[0] = '\0';
237     va_start(arg_list, text);
238     pos = vsnprintf(queryBuf, len - 2, text, arg_list);
239     va_end(arg_list);
240     if (pos < 0 || pos > (len - 2)) pos = len - 2;
241     queryBuf[pos] = '\0';
242     putlog(LOGLEVEL_MYSQL, "MySQL: %s\n", queryBuf);
243     if(mysql_query(mysql_conn, queryBuf)) {
244         check_mysql();
245         if(mysql_query(mysql_conn, queryBuf)) {
246             show_mysql_error();
247         }
248     }
249 }
250
251 char* escape_string(const char *str) {
252     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
253     struct escaped_string *escapedstr = malloc(sizeof(*escapedstr));
254     if (!escapedstr) {
255         return NULL;
256     }
257     char escaped[strlen(str)*2+1];
258     mysql_real_escape_string(mysql_conn->mysql_conn, escaped, str, strlen(str));
259     escapedstr->string = strdup(escaped);
260     escapedstr->next = mysql_conn->escaped_strings;
261     mysql_conn->escaped_strings = escapedstr;
262     return escapedstr->string;
263 }
264
265 struct mysql_conn_struct *get_mysql_conn_struct() {
266     SYNCHRONIZE(synchronized);
267     struct mysql_conn_struct *mysql_conn;
268     unsigned int tid;
269     #ifdef HAVE_THREADS
270     tid = (unsigned int) pthread_self_tid();
271     #else
272     tid = 1;
273     #endif
274     for(mysql_conn = mysql_conns; mysql_conn; mysql_conn = mysql_conn->next) {
275         if(mysql_conn->tid == tid) {
276             DESYNCHRONIZE(synchronized);
277             return mysql_conn;
278         }
279     }
280     mysql_conn = malloc(sizeof(*mysql_conn));
281     mysql_conn->mysql_conn = mysql_init(NULL);
282     mysql_conn->tid = tid;
283     mysql_conn->used_results = NULL;
284     mysql_conn->escaped_strings = NULL;
285     mysql_conn->next = mysql_conns;
286     mysql_conns = mysql_conn;
287     if (!mysql_real_connect(mysql_conn->mysql_conn, mysql_host, mysql_user, mysql_pass, mysql_base, mysql_serverport, NULL, 0)) {
288         //error
289         show_mysql_error();
290     }
291     DESYNCHRONIZE(synchronized);
292     return mysql_conn;
293 }
294
295 MYSQL *get_mysql_conn() {
296     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
297     return mysql_conn->mysql_conn;
298 }