diff --git a/src/cache.c b/src/cache.c index bafd307..60828f7 100644 --- a/src/cache.c +++ b/src/cache.c @@ -8,7 +8,7 @@ pthread_rwlock_t cacheLock; -static struct cache *not_http_dst_cache = NULL; +struct cache *not_http_dst_cache; static int check_interval; _Noreturn static void check_cache() { @@ -27,7 +27,6 @@ _Noreturn static void check_cache() { pthread_rwlock_unlock(&cacheLock); - // wait for 1 minute sleep(check_interval); } } @@ -50,11 +49,11 @@ void init_not_http_cache(const int interval) { syslog(LOG_INFO, "Cleanup thread created"); } -bool cache_contains(const char *addr_port) { +bool cache_contains(struct addr_port target) { pthread_rwlock_rdlock(&cacheLock); struct cache *s; - HASH_FIND_STR(not_http_dst_cache, addr_port, s); + HASH_FIND(hh, not_http_dst_cache, &target, sizeof(struct addr_port), s); pthread_rwlock_unlock(&cacheLock); @@ -76,19 +75,18 @@ bool cache_contains(const char *addr_port) { return false; } -void cache_add(const char *addr_port) { +void cache_add(struct addr_port addr_port) { pthread_rwlock_wrlock(&cacheLock); struct cache *s; - HASH_FIND_STR(not_http_dst_cache, addr_port, s); - if (s != NULL) { - s->last_time = time(NULL); - } else { + + HASH_FIND(hh, not_http_dst_cache, &addr_port, sizeof(struct addr_port), s); + if (s == NULL) { s = malloc(sizeof(struct cache)); - strcpy(s->addr_port, addr_port); - s->last_time = time(NULL); - HASH_ADD_STR(not_http_dst_cache, addr_port, s); + memcpy(&s->target.addr, &addr_port, sizeof(struct addr_port)); + HASH_ADD(hh, not_http_dst_cache, target.addr, sizeof(struct addr_port), s); } + s->last_time = time(NULL); pthread_rwlock_unlock(&cacheLock); } diff --git a/src/cache.h b/src/cache.h index 61b1602..3ac783f 100644 --- a/src/cache.h +++ b/src/cache.h @@ -3,24 +3,30 @@ #include #include +#include + +#include "third/nfqueue-mnl.h" #include "third/uthash.h" -#define INET6_ADDRSTRLEN 46 -// 1111:1111:1111:1111:1111:1111:111.111.111.111:65535 -// with null terminator -#define MAX_ADDR_PORT_LENGTH (INET6_ADDRSTRLEN + 7) +struct addr_port { + ip_address_t addr; + uint16_t port; +}; struct cache { - char addr_port[MAX_ADDR_PORT_LENGTH]; + struct addr_port target; time_t last_time; UT_hash_handle hh; }; +extern struct cache *not_http_dst_cache; +extern pthread_rwlock_t cacheLock; + void init_not_http_cache(int interval); // add addr_port to cache, assume it's not a http dst -void cache_add(const char *addr_port); +void cache_add(struct addr_port addr_port); -bool cache_contains(const char *addr_port); +bool cache_contains(struct addr_port addr_port); #endif // UA2F_CACHE_H diff --git a/src/handler.c b/src/handler.c index 5f2331a..0409045 100644 --- a/src/handler.c +++ b/src/handler.c @@ -56,27 +56,6 @@ void init_handler() { syslog(LOG_INFO, "Handler initialized."); } -// should free the ret value -static char *ip_to_str(const ip_address_t *ip, const uint16_t port, const int ip_version) { - ASSERT(ip_version == IPV4 || ip_version == IPV6); - char *ip_buf = malloc(MAX_ADDR_PORT_LENGTH); - memset(ip_buf, 0, MAX_ADDR_PORT_LENGTH); - const char *retval = NULL; - - if (ip_version == IPV4) { - retval = inet_ntop(AF_INET, &ip->in4, ip_buf, INET_ADDRSTRLEN); - } else { - retval = inet_ntop(AF_INET6, &ip->in6, ip_buf, INET6_ADDRSTRLEN); - } - ASSERT(retval != NULL); - - char port_buf[7]; - sprintf(port_buf, ":%d", port); - strcat(ip_buf, port_buf); - - return ip_buf; -} - struct mark_op { bool should_set; uint32_t mark; @@ -89,7 +68,7 @@ static void send_verdict(const struct nf_queue *queue, const struct nf_packet *p syslog(LOG_ERR, "failed to put nfqueue header"); goto end; } - nfq_nlmsg_verdict_put(nlh, pkt->packet_id, NF_ACCEPT); + nfq_nlmsg_verdict_put(nlh, (int)pkt->packet_id, NF_ACCEPT); if (mark.should_set) { struct nlattr *nest = mnl_attr_nest_start_check(nlh, SEND_BUF_LEN, NFQA_CT); @@ -123,9 +102,12 @@ static bool conntrack_info_available = true; static bool cache_initialized = false; static void add_to_cache(const struct nf_packet *pkt) { - char *ip_str = ip_to_str(&pkt->orig.dst, pkt->orig.dst_port, pkt->orig.ip_version); - cache_add(ip_str); - free(ip_str); + struct addr_port target = { + .addr = pkt->orig.dst, + .port = pkt->orig.dst_port, + }; + + cache_add(target); } static struct mark_op get_next_mark(const struct nf_packet *pkt, const bool has_ua) { @@ -167,10 +149,12 @@ static struct mark_op get_next_mark(const struct nf_packet *pkt, const bool has_ bool should_ignore(const struct nf_packet *pkt) { bool retval = false; + struct addr_port target = { + .addr = pkt->orig.dst, + .port = pkt->orig.dst_port, + }; - char *ip_str = ip_to_str(&pkt->orig.dst, pkt->orig.dst_port, pkt->orig.ip_version); - retval = cache_contains(ip_str); - free(ip_str); + retval = cache_contains(target); return retval; } diff --git a/test/cache_test.cc b/test/cache_test.cc index e9e9c17..430bc4c 100644 --- a/test/cache_test.cc +++ b/test/cache_test.cc @@ -4,40 +4,45 @@ extern "C" { #include } -#define CACHE_TIMEOUT 2 - -class CacheTest : public ::testing::Test -{ +class CacheTest : public ::testing::Test { protected: - void SetUp() override - { - init_not_http_cache(CACHE_TIMEOUT); + ip_address_t test_addr{}; + + void SetUp() override { + test_addr.ip4 = 12345; + init_not_http_cache(2); + } + + void TearDown() override { + pthread_rwlock_wrlock(&cacheLock); + // Clear the cache after each test + struct cache *cur, *tmp; + HASH_ITER(hh, not_http_dst_cache, cur, tmp) { + HASH_DEL(not_http_dst_cache, cur); + free(cur); + } + pthread_rwlock_unlock(&cacheLock); } }; - -TEST_F(CacheTest, CacheAddAndContains) -{ - const char* addr_port = "127.0.0.1:2335"; - cache_add(addr_port); - EXPECT_TRUE(cache_contains(addr_port)); +TEST_F(CacheTest, CacheInitiallyEmpty) { + EXPECT_FALSE(cache_contains(test_addr)); } -TEST_F(CacheTest, CacheDoesNotContainAfterTimeout) -{ - const char* addr_port = "127.0.0.1:2334"; - cache_add(addr_port); - sleep(CACHE_TIMEOUT * 2 + 2); - EXPECT_FALSE(cache_contains(addr_port)); +TEST_F(CacheTest, AddToCache) { + cache_add(test_addr); + EXPECT_TRUE(cache_contains(test_addr)); } -TEST_F(CacheTest, CacheContainsAfterRenewal) -{ - const char* addr_port = "127.0.0.1:2333"; - cache_add(addr_port); - EXPECT_TRUE(cache_contains(addr_port)); - sleep(CACHE_TIMEOUT * 2 + 2); - EXPECT_FALSE(cache_contains(addr_port)); - cache_add(addr_port); - EXPECT_TRUE(cache_contains(addr_port)); +TEST_F(CacheTest, AddAndRemoveFromCache) { + cache_add(test_addr); + EXPECT_TRUE(cache_contains(test_addr)); + sleep(5); + EXPECT_FALSE(cache_contains(test_addr)); } + +TEST_F(CacheTest, CacheDoesNotContainNonexistentEntry) { + ip_address_t nonexistent_addr; + nonexistent_addr.ip4 = 54321; // Assign a value different from test_addr + EXPECT_FALSE(cache_contains(nonexistent_addr)); +} \ No newline at end of file