diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..003fc99 --- /dev/null +++ b/.clang-format @@ -0,0 +1,53 @@ +--- +Language: Cpp +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignOperands: false +AlignTrailingComments: false +AlwaysBreakTemplateDeclarations: Yes +BraceWrapping: + AfterCaseLabel: true + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: false + BeforeCatch: true + BeforeElse: true + BeforeLambdaBody: true + BeforeWhile: true + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBraces: Attach +BreakConstructorInitializers: AfterColon +BreakConstructorInitializersBeforeComma: false +ColumnLimit: 120 +ConstructorInitializerAllOnOneLineOrOnePerLine: false +IncludeCategories: + - Regex: '^<.*' + Priority: 1 + - Regex: '^".*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseBlocks: true +IndentWidth: 4 +InsertNewlineAtEOF: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: All +SpaceInEmptyParentheses: false +SpacesInAngles: false +SpacesInConditionalStatement: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +TabWidth: 4 +... diff --git a/CMakeLists.txt b/CMakeLists.txt index cbfab3c..c01e985 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,10 +45,7 @@ add_compile_definitions(UA2F_GIT_BRANCH="${GIT_BRANCH}") add_compile_definitions(UA2F_GIT_TAG="${GIT_TAG}") add_compile_definitions(UA2F_VERSION="${UA2F_VERSION_STR}") -include(CheckSymbolExists) -check_symbol_exists(__malloc_hook "malloc.h" IS_LIBC_GLIBC) - -if (IS_LIBC_GLIBC) +if (UA2F_ENABLE_ASAN) add_compile_options(-fsanitize=address) add_link_options(-fsanitize=address) else () @@ -97,7 +94,6 @@ if (UA2F_BUILD_TESTS) set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) - cmake_policy(SET CMP0135 NEW) include(FetchContent) FetchContent_Declare( googletest @@ -111,6 +107,7 @@ if (UA2F_BUILD_TESTS) add_executable( ua2f_test test/util_test.cc + test/cache_test.cc src/util.c src/cache.c src/cli.c diff --git a/src/cache.c b/src/cache.c index 53c6f17..bafd307 100644 --- a/src/cache.c +++ b/src/cache.c @@ -2,13 +2,14 @@ #include "third/uthash.h" #include -#include #include +#include #include pthread_rwlock_t cacheLock; -struct cache *not_http_dst_cache = NULL; +static struct cache *not_http_dst_cache = NULL; +static int check_interval; _Noreturn static void check_cache() { while (true) { @@ -18,7 +19,7 @@ _Noreturn static void check_cache() { struct cache *cur, *tmp; HASH_ITER(hh, not_http_dst_cache, cur, tmp) { - if (difftime(now, cur->last_time) > CACHE_TIMEOUT) { + if (difftime(now, cur->last_time) > check_interval * 2) { HASH_DEL(not_http_dst_cache, cur); free(cur); } @@ -27,11 +28,13 @@ _Noreturn static void check_cache() { pthread_rwlock_unlock(&cacheLock); // wait for 1 minute - sleep(CACHE_CHECK_INTERVAL); + sleep(check_interval); } } -void init_not_http_cache() { +void init_not_http_cache(const int interval) { + check_interval = interval; + if (pthread_rwlock_init(&cacheLock, NULL) != 0) { syslog(LOG_ERR, "Failed to init cache lock"); exit(EXIT_FAILURE); @@ -47,7 +50,7 @@ void init_not_http_cache() { syslog(LOG_INFO, "Cleanup thread created"); } -bool cache_contains(const char* addr_port) { +bool cache_contains(const char *addr_port) { pthread_rwlock_rdlock(&cacheLock); struct cache *s; @@ -58,7 +61,7 @@ bool cache_contains(const char* addr_port) { if (s != NULL) { bool ret; pthread_rwlock_wrlock(&cacheLock); - if (difftime(time(NULL), s->last_time) > CACHE_TIMEOUT) { + if (difftime(time(NULL), s->last_time) > check_interval * 2) { HASH_DEL(not_http_dst_cache, s); free(s); ret = false; diff --git a/src/cache.h b/src/cache.h index e70f6c2..61b1602 100644 --- a/src/cache.h +++ b/src/cache.h @@ -5,9 +5,6 @@ #include #include "third/uthash.h" -#define CACHE_TIMEOUT 127 -#define CACHE_CHECK_INTERVAL 128 - #define INET6_ADDRSTRLEN 46 // 1111:1111:1111:1111:1111:1111:111.111.111.111:65535 // with null terminator @@ -19,11 +16,11 @@ struct cache { UT_hash_handle hh; }; -void init_not_http_cache(); +void init_not_http_cache(int interval); // add addr_port to cache, assume it's not a http dst -void cache_add(const char* addr_port); +void cache_add(const char *addr_port); -bool cache_contains(const char* addr_port); +bool cache_contains(const char *addr_port); -#endif //UA2F_CACHE_H +#endif // UA2F_CACHE_H diff --git a/src/cli.c b/src/cli.c index 061d5f3..3957c58 100644 --- a/src/cli.c +++ b/src/cli.c @@ -1,6 +1,6 @@ -#include #include #include +#include #include "cli.h" #include "config.h" @@ -46,4 +46,4 @@ void try_print_info(const int argc, char *argv[]) { printf(" --version\n"); printf(" --help\n"); exit(1); -} \ No newline at end of file +} diff --git a/src/cli.h b/src/cli.h index b8701ad..f28642d 100644 --- a/src/cli.h +++ b/src/cli.h @@ -21,4 +21,4 @@ void try_print_info(int argc, char *argv[]); -#endif //UA2F_CLI_H +#endif // UA2F_CLI_H diff --git a/src/config.c b/src/config.c index e925bce..467d454 100644 --- a/src/config.c +++ b/src/config.c @@ -1,13 +1,13 @@ #ifdef UA2F_ENABLE_UCI -#include #include #include +#include #include "config.h" struct ua2f_config config = { - .use_custom_ua = false, - .custom_ua = NULL, + .use_custom_ua = false, + .custom_ua = NULL, }; void load_config() { @@ -37,7 +37,7 @@ void load_config() { config.custom_ua = strdup(custom_ua); } - cleanup: +cleanup: uci_free_context(ctx); } -#endif \ No newline at end of file +#endif diff --git a/src/config.h b/src/config.h index c5827c8..08bc43c 100644 --- a/src/config.h +++ b/src/config.h @@ -1,3 +1,5 @@ +#pragma once + #ifdef UA2F_ENABLE_UCI #ifndef UA2F_CONFIG_H #define UA2F_CONFIG_H @@ -13,5 +15,5 @@ void load_config(); extern struct ua2f_config config; -#endif //UA2F_CONFIG_H +#endif // UA2F_CONFIG_H #endif diff --git a/src/handler.c b/src/handler.c index bb9a6f6..5f2331a 100644 --- a/src/handler.c +++ b/src/handler.c @@ -1,18 +1,18 @@ -#include #include "handler.h" #include "cache.h" -#include "util.h" -#include "statistics.h" #include "custom.h" +#include "statistics.h" +#include "util.h" +#include #ifdef UA2F_ENABLE_UCI #include "config.h" #endif -#include -#include #include #include +#include +#include #define MAX_USER_AGENT_LENGTH (0xffff + (MNL_SOCKET_BUFFER_SIZE / 2)) static char *replacement_user_agent_string = NULL; @@ -82,11 +82,8 @@ struct mark_op { uint32_t mark; }; -static void send_verdict( - const struct nf_queue *queue, - const struct nf_packet *pkt, - const struct mark_op mark, - struct pkt_buff *mangled_pkt_buff) { +static void send_verdict(const struct nf_queue *queue, const struct nf_packet *pkt, const struct mark_op mark, + struct pkt_buff *mangled_pkt_buff) { struct nlmsghdr *nlh = nfqueue_put_header(pkt->queue_num, NFQNL_MSG_VERDICT); if (nlh == NULL) { syslog(LOG_ERR, "failed to put nfqueue header"); @@ -116,7 +113,7 @@ static void send_verdict( syslog(LOG_ERR, "failed to send verdict: %s", strerror(errno)); } - end: +end: if (nlh != NULL) { free(nlh); } @@ -133,39 +130,39 @@ static void add_to_cache(const struct nf_packet *pkt) { static struct mark_op get_next_mark(const struct nf_packet *pkt, const bool has_ua) { if (!conntrack_info_available) { - return (struct mark_op) {false, 0}; + return (struct mark_op){false, 0}; } // I didn't think this will happen, but just in case // firewall should already have a rule to return all marked with CONNMARK_NOT_HTTP packets if (pkt->conn_mark == CONNMARK_NOT_HTTP) { syslog(LOG_WARNING, "Packet has already been marked as not http. Maybe firewall rules are wrong?"); - return (struct mark_op) {false, 0}; + return (struct mark_op){false, 0}; } if (pkt->conn_mark == CONNMARK_HTTP) { - return (struct mark_op) {false, 0}; + return (struct mark_op){false, 0}; } if (has_ua) { - return (struct mark_op) {true, CONNMARK_HTTP}; + return (struct mark_op){true, CONNMARK_HTTP}; } if (!pkt->has_connmark || pkt->conn_mark == 0) { - return (struct mark_op) {true, CONNMARK_ESTIMATE_LOWER}; + return (struct mark_op){true, CONNMARK_ESTIMATE_LOWER}; } if (pkt->conn_mark == CONNMARK_ESTIMATE_VERDICT) { add_to_cache(pkt); - return (struct mark_op) {true, CONNMARK_NOT_HTTP}; + return (struct mark_op){true, CONNMARK_NOT_HTTP}; } if (pkt->conn_mark >= CONNMARK_ESTIMATE_LOWER && pkt->conn_mark <= CONNMARK_ESTIMATE_UPPER) { - return (struct mark_op) {true, pkt->conn_mark + 1}; + return (struct mark_op){true, pkt->conn_mark + 1}; } syslog(LOG_WARNING, "Unexpected connmark value: %d, Maybe other program has changed connmark?", pkt->conn_mark); - return (struct mark_op) {true, pkt->conn_mark + 1}; + return (struct mark_op){true, pkt->conn_mark + 1}; } bool should_ignore(const struct nf_packet *pkt) { @@ -186,7 +183,7 @@ void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt) { syslog(LOG_WARNING, "Note that this may lead to performance degradation. Especially on low-end routers."); } else { if (!cache_initialized) { - init_not_http_cache(); + init_not_http_cache(60); cache_initialized = true; } } @@ -194,7 +191,7 @@ void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt) { struct pkt_buff *pkt_buff = NULL; if (conntrack_info_available && should_ignore(pkt)) { - send_verdict(queue, pkt, (struct mark_op) {true, CONNMARK_NOT_HTTP}, NULL); + send_verdict(queue, pkt, (struct mark_op){true, CONNMARK_NOT_HTTP}, NULL); goto end; } @@ -238,7 +235,7 @@ void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt) { const __auto_type tcp_hdr = nfq_tcp_get_hdr(pkt_buff); if (tcp_hdr == NULL) { // This packet is not tcp, pass it - send_verdict(queue, pkt, (struct mark_op) {false, 0}, NULL); + send_verdict(queue, pkt, (struct mark_op){false, 0}, NULL); syslog(LOG_WARNING, "Received non-tcp packet. You may set wrong firewall rules."); goto end; } @@ -259,13 +256,13 @@ void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt) { goto end; } -// FIXME: can lead to false positive, -// should also get CTA_COUNTERS_ORIG to check if this packet is a initial tcp packet + // FIXME: can lead to false positive, + // should also get CTA_COUNTERS_ORIG to check if this packet is a initial tcp packet -// if (!is_http_protocol(tcp_payload, tcp_payload_len)) { -// send_verdict(queue, pkt, get_next_mark(pkt, false), NULL); -// goto end; -// } + // if (!is_http_protocol(tcp_payload, tcp_payload_len)) { + // send_verdict(queue, pkt, get_next_mark(pkt, false), NULL); + // goto end; + // } count_http_packet(); const void *search_start = tcp_payload; @@ -288,7 +285,7 @@ void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt) { void *ua_start = ua_pos + USER_AGENT_MATCH_LENGTH; // for non-standard user-agent like User-Agent:XXX with no space after colon - if (*(char *) ua_start == ' ') { + if (*(char *)ua_start == ' ') { ua_start++; } @@ -318,7 +315,7 @@ void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt) { send_verdict(queue, pkt, get_next_mark(pkt, has_ua), pkt_buff); - end: +end: free(pkt->payload); if (pkt_buff != NULL) { pktb_free(pkt_buff); diff --git a/src/handler.h b/src/handler.h index 4af7f3f..6b2ec79 100644 --- a/src/handler.h +++ b/src/handler.h @@ -7,4 +7,4 @@ void init_handler(); void handle_packet(const struct nf_queue *queue, const struct nf_packet *pkt); -#endif //UA2F_HANDLER_H +#endif // UA2F_HANDLER_H diff --git a/src/statistics.c b/src/statistics.c index 6c19539..07a10cb 100644 --- a/src/statistics.c +++ b/src/statistics.c @@ -1,8 +1,8 @@ +#include "statistics.h" #include #include -#include #include -#include "statistics.h" +#include static long long user_agent_packet_count = 0; static long long http_packet_count = 0; @@ -19,30 +19,20 @@ void init_statistics() { syslog(LOG_INFO, "Statistics initialized."); } -void count_user_agent_packet() { - user_agent_packet_count++; -} +void count_user_agent_packet() { user_agent_packet_count++; } -void count_tcp_packet() { - tcp_packet_count++; -} +void count_tcp_packet() { tcp_packet_count++; } -void count_http_packet() { - http_packet_count++; -} +void count_http_packet() { http_packet_count++; } -void count_ipv4_packet() { - ipv4_packet_count++; -} +void count_ipv4_packet() { ipv4_packet_count++; } -void count_ipv6_packet() { - ipv6_packet_count++; -} +void count_ipv6_packet() { ipv6_packet_count++; } static char time_string_buffer[100]; char *fill_time_string(const double sec) { - const int s = (int) sec; + const int s = (int)sec; memset(time_string_buffer, 0, sizeof(time_string_buffer)); if (s <= 60) { sprintf(time_string_buffer, "%d seconds", s); @@ -52,8 +42,7 @@ char *fill_time_string(const double sec) { sprintf(time_string_buffer, "%d hours, %d minutes and %d seconds", s / 3600, s % 3600 / 60, s % 60); } else { sprintf(time_string_buffer, "%d days, %d hours, %d minutes and %d seconds", s / 86400, s % 86400 / 3600, - s % 3600 / 60, - s % 60); + s % 3600 / 60, s % 60); } return time_string_buffer; } @@ -62,17 +51,8 @@ void try_print_statistics() { if (user_agent_packet_count / last_report_count == 2 || user_agent_packet_count - last_report_count >= 8192) { last_report_count = user_agent_packet_count; const time_t current_t = time(NULL); - syslog( - LOG_INFO, - "UA2F has handled %lld ua http, %lld http, %lld tcp. %lld ipv4, %lld ipv6 packets in %s.", - user_agent_packet_count, - http_packet_count, - tcp_packet_count, - ipv4_packet_count, - ipv6_packet_count, - fill_time_string(difftime(current_t, start_t)) - ); + syslog(LOG_INFO, "UA2F has handled %lld ua http, %lld http, %lld tcp. %lld ipv4, %lld ipv6 packets in %s.", + user_agent_packet_count, http_packet_count, tcp_packet_count, ipv4_packet_count, ipv6_packet_count, + fill_time_string(difftime(current_t, start_t))); } } - - diff --git a/src/ua2f.c b/src/ua2f.c index f6cb01a..9b8c2f6 100644 --- a/src/ua2f.c +++ b/src/ua2f.c @@ -1,17 +1,17 @@ -#include "statistics.h" -#include "handler.h" -#include "util.h" #include "cli.h" +#include "handler.h" +#include "statistics.h" #include "third/nfqueue-mnl.h" +#include "util.h" #ifdef UA2F_ENABLE_UCI #include "config.h" #endif +#include #include #include #include -#include #pragma clang diagnostic push #pragma ide diagnostic ignored "EndlessLoop" @@ -65,4 +65,4 @@ int main(const int argc, char *argv[]) { nfqueue_close(queue); } -#pragma clang diagnostic pop \ No newline at end of file +#pragma clang diagnostic pop diff --git a/src/util.c b/src/util.c index 36f46ca..3f07850 100644 --- a/src/util.c +++ b/src/util.c @@ -1,7 +1,7 @@ -#include -#include #include #include +#include +#include void *memncasemem(const void *l, size_t l_len, const void *s, const size_t s_len) { register char *cur, *last; @@ -18,15 +18,15 @@ void *memncasemem(const void *l, size_t l_len, const void *s, const size_t s_len /* special case where s_len == 1 */ if (s_len == 1) { - for (cur = (char *) cl; l_len--; cur++) + for (cur = (char *)cl; l_len--; cur++) if (tolower(cur[0]) == tolower(cs[0])) return cur; } /* the last position where its possible to find "s" in "l" */ - last = (char *) cl + l_len - s_len; + last = (char *)cl + l_len - s_len; - for (cur = (char *) cl; cur <= last; cur++) + for (cur = (char *)cl; cur <= last; cur++) if (tolower(cur[0]) == tolower(cs[0])) { if (strncasecmp(cur, cs, s_len) == 0) { return cur; @@ -47,7 +47,9 @@ static bool probe_http_method(const char *p, const int len, const char *opt) { bool is_http_protocol(const char *p, const unsigned int len) { bool pass = false; -#define PROBE_HTTP_METHOD(opt) if ((pass = probe_http_method(p, len, opt)) != false) return pass +#define PROBE_HTTP_METHOD(opt) \ + if ((pass = probe_http_method(p, len, opt)) != false) \ + return pass PROBE_HTTP_METHOD("GET"); PROBE_HTTP_METHOD("POST"); @@ -60,4 +62,4 @@ bool is_http_protocol(const char *p, const unsigned int len) { #undef PROBE_HTTP_METHOD return false; -} \ No newline at end of file +} diff --git a/src/util.h b/src/util.h index 6f80df2..e1e9d50 100644 --- a/src/util.h +++ b/src/util.h @@ -1,12 +1,12 @@ #ifndef UA2F_UTIL_H #define UA2F_UTIL_H -#include #include +#include #define QUEUE_NUM 10010 void *memncasemem(const void *l, size_t l_len, const void *s, size_t s_len); bool is_http_protocol(const char *p, unsigned int len); -#endif //UA2F_UTIL_H +#endif // UA2F_UTIL_H diff --git a/test/cache_test.cc b/test/cache_test.cc new file mode 100644 index 0000000..e9e9c17 --- /dev/null +++ b/test/cache_test.cc @@ -0,0 +1,43 @@ +#include + +extern "C" { +#include +} + +#define CACHE_TIMEOUT 2 + +class CacheTest : public ::testing::Test +{ +protected: + void SetUp() override + { + init_not_http_cache(CACHE_TIMEOUT); + } +}; + + +TEST_F(CacheTest, CacheAddAndContains) +{ + const char* addr_port = "127.0.0.1:2335"; + cache_add(addr_port); + EXPECT_TRUE(cache_contains(addr_port)); +} + +TEST_F(CacheTest, CacheDoesNotContainAfterTimeout) +{ + const char* addr_port = "127.0.0.1:2334"; + cache_add(addr_port); + sleep(CACHE_TIMEOUT * 2 + 2); + EXPECT_FALSE(cache_contains(addr_port)); +} + +TEST_F(CacheTest, CacheContainsAfterRenewal) +{ + const char* addr_port = "127.0.0.1:2333"; + cache_add(addr_port); + EXPECT_TRUE(cache_contains(addr_port)); + sleep(CACHE_TIMEOUT * 2 + 2); + EXPECT_FALSE(cache_contains(addr_port)); + cache_add(addr_port); + EXPECT_TRUE(cache_contains(addr_port)); +}