fixed some missing includes
[NeonServV5.git] / src / mysqlConn.c
1 /* mysqlConn.c - NeonServ v5.6
2  * Copyright (C) 2011-2012  Philipp Kreil (pk910)
3  * 
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  * 
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  * 
14  * You should have received a copy of the GNU General Public License 
15  * along with this program. If not, see <http://www.gnu.org/licenses/>. 
16  */
17
18 #include "mysqlConn.h"
19 #include "ConfigParser.h"
20 #include "tools.h"
21 #define DATABASE_VERSION "20"
22
23 static void show_mysql_error();
24
25 struct mysql_conn_struct {
26     unsigned int tid;
27     MYSQL *mysql_conn;
28     struct used_result *used_results;
29     struct escaped_string *escaped_strings;
30     struct mysql_conn_struct *next;
31 };
32
33 struct used_result {
34     MYSQL_RES *result;
35     struct used_result *next;
36 };
37
38 struct escaped_string {
39     char *string;
40     struct escaped_string *next;
41 };
42
43 struct mysql_conn_struct *get_mysql_conn_struct();
44
45 struct mysql_conn_struct *mysql_conns = NULL;
46 static int mysql_serverport;
47 static char *mysql_host, *mysql_user, *mysql_pass, *mysql_base;
48
49 #ifdef HAVE_THREADS
50 static pthread_mutex_t synchronized;
51 #endif
52
53 static void check_mysql() {
54     MYSQL *mysql_conn = get_mysql_conn();
55     int errid;
56     if((errid = mysql_ping(mysql_conn))) {
57         if(mysql_errno(mysql_conn) == CR_SERVER_GONE_ERROR) {
58             if(!mysql_real_connect(mysql_conn, mysql_host, mysql_user, mysql_pass, mysql_base, mysql_serverport, NULL, 0)) {
59                 show_mysql_error();
60             }
61         } else {
62             //mysql error
63             show_mysql_error();
64         }
65     }
66 }
67
68 MYSQL_RES *mysql_use() {
69     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
70     MYSQL_RES *res = mysql_store_result(mysql_conn->mysql_conn);
71     struct used_result *result = malloc(sizeof(*result));
72     if (!result) {
73         mysql_free_result(res);
74         return NULL;
75     }
76     result->result = res;
77     result->next = mysql_conn->used_results;
78     mysql_conn->used_results = result;
79     return res;
80 }
81
82 void mysql_free() {
83     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
84     if(!mysql_conn) return;
85     struct used_result *result, *next_result;
86     for(result = mysql_conn->used_results; result; result = next_result) {
87         next_result = result->next;
88         mysql_free_result(result->result);
89         free(result);
90     }
91     mysql_conn->used_results = NULL;
92     struct escaped_string *escaped, *next_escaped;
93     for(escaped = mysql_conn->escaped_strings; escaped; escaped = next_escaped) {
94         next_escaped = escaped->next;
95         free(escaped->string);
96         free(escaped);
97     }
98     mysql_conn->escaped_strings = NULL;
99 }
100
101 int reload_mysql() {
102     char *new_mysql_host = get_string_field("MySQL.host");
103     char *new_mysql_user = get_string_field("MySQL.user");
104     char *new_mysql_pass = get_string_field("MySQL.pass");
105     char *new_mysql_base = get_string_field("MySQL.base");
106     if(!(new_mysql_host && new_mysql_user && new_mysql_pass && new_mysql_base))
107         return 0;
108     
109     //replace login data
110     if(mysql_host)
111         free(mysql_host);
112     mysql_host = strdup(new_mysql_host);
113     
114     if(mysql_user)
115         free(mysql_user);
116     mysql_user = strdup(new_mysql_user);
117     
118     if(mysql_pass)
119         free(mysql_pass);
120     mysql_pass = strdup(new_mysql_pass);
121     
122     if(mysql_base)
123         free(mysql_base);
124     mysql_base = strdup(new_mysql_base);
125     
126     mysql_serverport = get_int_field("MySQL.port");
127     if(!mysql_serverport)
128         mysql_serverport = 3306;
129     return 1;
130 }
131
132 void init_mysql() {
133     THREAD_MUTEX_INIT(synchronized);
134     
135     MYSQL *mysql_conn = get_mysql_conn();
136     
137     //check database version...
138     int version = 0;
139     if(!mysql_query(mysql_conn, "SELECT `database_version` FROM `version`")) {
140         MYSQL_RES *res = mysql_use();
141         MYSQL_ROW row;
142         if((row = mysql_fetch_row(res))) {
143             version = atoi(row[0]);
144         }
145     }
146     if(!version) {
147         //CREATE DATABASE
148         FILE *f = fopen("database.sql", "r");
149         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_ON);
150         if (f) {
151             char line[512];
152             char query_buffer[8192];
153             int query_buffer_pos = 0;
154             while (fgets(line, sizeof(line), f)) {
155                 query_buffer_pos += sprintf(query_buffer + query_buffer_pos, " %s", line);
156                 if(line[(strlen(line) - 2)] == ';') {
157                     if(mysql_query(mysql_conn, query_buffer))
158                         show_mysql_error();
159                     query_buffer_pos = 0;
160                 }
161             }
162             fclose(f);
163         }
164         f = fopen("database.defaults.sql", "r");
165         if (f) {
166             char line[4096];
167             char query_buffer[131072];
168             int query_buffer_pos = 0;
169             while (fgets(line, sizeof(line), f)) {
170                 query_buffer_pos += sprintf(query_buffer + query_buffer_pos, " %s", line);
171                 if(line[(strlen(line) - 2)] == ';') {
172                     if(mysql_query(mysql_conn, query_buffer))
173                         show_mysql_error();
174                     query_buffer_pos = 0;
175                 }
176             }
177             fclose(f);
178         }
179         do { 
180             MYSQL_RES *res = mysql_store_result(mysql_conn); 
181             mysql_free_result(res); 
182         } while(!mysql_next_result(mysql_conn));
183         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_OFF);
184         mysql_query(mysql_conn, "INSERT INTO `version` (`database_version`) VALUES ('" DATABASE_VERSION "')");
185     }
186     else if(version < atoi(DATABASE_VERSION)) {
187         //UPDATE DATABASE
188         FILE *f = fopen("database.upgrade.sql", "r");
189         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_ON);
190         if (f) {
191             char line[512];
192             char query_buffer[8192];
193             int query_buffer_pos = 0, use_querys = 0;
194             sprintf(query_buffer, "-- version: %d", version);
195             while (fgets(line, sizeof(line), f)) {
196                 if(use_querys) {
197                     query_buffer_pos += sprintf(query_buffer + query_buffer_pos, " %s", line);
198                     if(line[strlen(line) - 1] == ';') {
199                         mysql_query(mysql_conn, query_buffer);
200                         query_buffer_pos = 0;
201                     }
202                 } else if(!stricmplen(query_buffer, line, strlen(query_buffer))) {
203                     use_querys = 1;
204                 }
205             }
206             if(query_buffer_pos) {
207                 if(mysql_query(mysql_conn, query_buffer))
208                     show_mysql_error();
209             }
210             fclose(f);
211         } else
212             perror("database.sql missing!");
213         do { 
214             MYSQL_RES *res = mysql_store_result(mysql_conn); 
215             mysql_free_result(res); 
216         } while(!mysql_next_result(mysql_conn));
217         mysql_set_server_option(mysql_conn, MYSQL_OPTION_MULTI_STATEMENTS_OFF);
218         mysql_query(mysql_conn, "UPDATE `version` SET `database_version` = '" DATABASE_VERSION "'");
219     }
220 }
221
222 void free_mysql() {
223     struct mysql_conn_struct *mysql_conn, *next;
224     for(mysql_conn = mysql_conns; mysql_conn; mysql_conn = next) {
225         next = mysql_conn->next;
226         mysql_close(mysql_conn->mysql_conn);
227         free(mysql_conn);
228     }
229     mysql_conns = NULL;
230 }
231
232 static void show_mysql_error() {
233     MYSQL *mysql_conn = get_mysql_conn();
234     //show mysql_error()
235     putlog(LOGLEVEL_ERROR, "MySQL Error: %s\n", mysql_error(mysql_conn));
236 }
237
238 void printf_mysql_query(const char *text, ...) {
239     MYSQL *mysql_conn = get_mysql_conn();
240     va_list arg_list;
241     char queryBuf[MYSQLMAXLEN];
242     int pos;
243     queryBuf[0] = '\0';
244     va_start(arg_list, text);
245     pos = vsnprintf(queryBuf, MYSQLMAXLEN - 2, text, arg_list);
246     va_end(arg_list);
247     if (pos < 0 || pos > (MYSQLMAXLEN - 2)) pos = MYSQLMAXLEN - 2;
248     queryBuf[pos] = '\0';
249     putlog(LOGLEVEL_MYSQL, "MySQL: %s\n", queryBuf);
250     if(mysql_query(mysql_conn, queryBuf)) {
251         check_mysql();
252         if(mysql_query(mysql_conn, queryBuf)) {
253             show_mysql_error();
254         }
255     }
256 }
257
258 void printf_long_mysql_query(int len, const char *text, ...) {
259     MYSQL *mysql_conn = get_mysql_conn();
260     va_list arg_list;
261     char queryBuf[len];
262     int pos;
263     queryBuf[0] = '\0';
264     va_start(arg_list, text);
265     pos = vsnprintf(queryBuf, len - 2, text, arg_list);
266     va_end(arg_list);
267     if (pos < 0 || pos > (len - 2)) pos = len - 2;
268     queryBuf[pos] = '\0';
269     putlog(LOGLEVEL_MYSQL, "MySQL: %s\n", queryBuf);
270     if(mysql_query(mysql_conn, queryBuf)) {
271         check_mysql();
272         if(mysql_query(mysql_conn, queryBuf)) {
273             show_mysql_error();
274         }
275     }
276 }
277
278 char* escape_string(const char *str) {
279     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
280     struct escaped_string *escapedstr = malloc(sizeof(*escapedstr));
281     if (!escapedstr) {
282         return NULL;
283     }
284     char escaped[strlen(str)*2+1];
285     mysql_real_escape_string(mysql_conn->mysql_conn, escaped, str, strlen(str));
286     escapedstr->string = strdup(escaped);
287     escapedstr->next = mysql_conn->escaped_strings;
288     mysql_conn->escaped_strings = escapedstr;
289     return escapedstr->string;
290 }
291
292 struct mysql_conn_struct *get_mysql_conn_struct() {
293     SYNCHRONIZE(synchronized);
294     struct mysql_conn_struct *mysql_conn;
295     unsigned int tid;
296     #ifdef HAVE_THREADS
297     tid = (unsigned int) pthread_self_tid();
298     #else
299     tid = 1;
300     #endif
301     for(mysql_conn = mysql_conns; mysql_conn; mysql_conn = mysql_conn->next) {
302         if(mysql_conn->tid == tid) {
303             DESYNCHRONIZE(synchronized);
304             return mysql_conn;
305         }
306     }
307     mysql_conn = malloc(sizeof(*mysql_conn));
308     mysql_conn->mysql_conn = mysql_init(NULL);
309     mysql_conn->tid = tid;
310     mysql_conn->used_results = NULL;
311     mysql_conn->escaped_strings = NULL;
312     mysql_conn->next = mysql_conns;
313     mysql_conns = mysql_conn;
314     if (!mysql_real_connect(mysql_conn->mysql_conn, mysql_host, mysql_user, mysql_pass, mysql_base, mysql_serverport, NULL, 0)) {
315         //error
316         show_mysql_error();
317     }
318     DESYNCHRONIZE(synchronized);
319     return mysql_conn;
320 }
321
322 MYSQL *get_mysql_conn() {
323     struct mysql_conn_struct *mysql_conn = get_mysql_conn_struct();
324     return mysql_conn->mysql_conn;
325 }