From 4fb3290cf6e0307ec4a7cded9598e35bd360f301 Mon Sep 17 00:00:00 2001 From: Flavio Fois Date: Tue, 24 Mar 2026 08:56:05 +0100 Subject: [PATCH] add rate limiting configuration for authenticated and unauthenticated requests --- .env.example | 12 +++ go.mod | 5 +- go.sum | 2 - internal/config/config.go | 41 ++++++++ internal/middleware/ratelimit.ban.go | 150 +++++++++++++++++---------- internal/routes/v1/v1.go | 9 +- internal/routes/v2/v2.go | 9 +- main.go | 11 +- 8 files changed, 155 insertions(+), 84 deletions(-) diff --git a/.env.example b/.env.example index 15f472f..473f591 100644 --- a/.env.example +++ b/.env.example @@ -11,3 +11,15 @@ DATABASE_NAME=emly # API Keys API_KEY=key-one ADMIN_KEY=admin-key-one + +# Rate Limiting (unauthenticated: no X-API-Key / X-Admin-Key) +RL_UNAUTH_MAX_REQS=10 +RL_UNAUTH_WINDOW=5m +RL_UNAUTH_MAX_FAILS=5 +RL_UNAUTH_BAN_DUR=15m + +# Rate Limiting (authenticated: X-API-Key or X-Admin-Key present) +RL_AUTH_MAX_REQS=100 +RL_AUTH_WINDOW=1m +RL_AUTH_MAX_FAILS=20 +RL_AUTH_BAN_DUR=5m diff --git a/go.mod b/go.mod index 367c912..909c0d0 100644 --- a/go.mod +++ b/go.mod @@ -10,10 +10,7 @@ require ( golang.org/x/crypto v0.49.0 ) -require ( - golang.org/x/sys v0.42.0 // indirect - golang.org/x/time v0.15.0 // indirect -) +require golang.org/x/sys v0.42.0 // indirect require ( filippo.io/edwards25519 v1.1.1 // indirect diff --git a/go.sum b/go.sum index 43d1c1f..0822857 100644 --- a/go.sum +++ b/go.sum @@ -23,5 +23,3 @@ golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= -golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= diff --git a/internal/config/config.go b/internal/config/config.go index 1158bea..7222c00 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,8 +5,20 @@ import ( "strconv" "strings" "sync" + "time" ) +type RateLimitConfig struct { + UnauthMaxReqs int + UnauthWindow time.Duration + UnauthMaxFails int + UnauthBanDur time.Duration + AuthMaxReqs int + AuthWindow time.Duration + AuthMaxFails int + AuthBanDur time.Duration +} + type Config struct { Port string DSN string @@ -16,6 +28,7 @@ type Config struct { MaxOpenConns int MaxIdleConns int ConnMaxLifetime int + RateLimit RateLimitConfig } var ( @@ -85,5 +98,33 @@ func load() *Config { MaxOpenConns: maxOpenConns, MaxIdleConns: maxIdleConns, ConnMaxLifetime: connMaxLifetime, + RateLimit: RateLimitConfig{ + UnauthMaxReqs: envInt("RL_UNAUTH_MAX_REQS", 10), + UnauthWindow: envDuration("RL_UNAUTH_WINDOW", 5*time.Minute), + UnauthMaxFails: envInt("RL_UNAUTH_MAX_FAILS", 5), + UnauthBanDur: envDuration("RL_UNAUTH_BAN_DUR", 15*time.Minute), + AuthMaxReqs: envInt("RL_AUTH_MAX_REQS", 100), + AuthWindow: envDuration("RL_AUTH_WINDOW", time.Minute), + AuthMaxFails: envInt("RL_AUTH_MAX_FAILS", 20), + AuthBanDur: envDuration("RL_AUTH_BAN_DUR", 5*time.Minute), + }, } } + +func envInt(key string, fallback int) int { + if s := os.Getenv(key); s != "" { + if n, err := strconv.Atoi(s); err == nil { + return n + } + } + return fallback +} + +func envDuration(key string, fallback time.Duration) time.Duration { + if s := os.Getenv(key); s != "" { + if d, err := time.ParseDuration(s); err == nil { + return d + } + } + return fallback +} diff --git a/internal/middleware/ratelimit.ban.go b/internal/middleware/ratelimit.ban.go index b43c895..497a5fc 100644 --- a/internal/middleware/ratelimit.ban.go +++ b/internal/middleware/ratelimit.ban.go @@ -1,54 +1,70 @@ -// middleware/ratelimit.go package middleware import ( + "log" "net" "net/http" "sync" "time" - "golang.org/x/time/rate" + "emly-api-go/internal/config" ) -type visitor struct { - limiter *rate.Limiter - lastSeen time.Time - failures int +type limitConfig struct { + maxReqs int + window time.Duration + maxFails int + banDur time.Duration +} + +type ipState struct { + count int + windowStart time.Time + failures int + lastSeen time.Time } type RateLimiter struct { - mu sync.Mutex - visitors map[string]*visitor - banned sync.Map // ip -> unban time + mu sync.Mutex + unauthVisitors map[string]*ipState + authVisitors map[string]*ipState + banned sync.Map // ip -> unban time (shared) - // config - rps rate.Limit // richieste/sec normali - burst int - maxFails int // quanti 429 prima del ban - banDur time.Duration // durata ban + unauthCfg limitConfig + authCfg limitConfig cleanEvery time.Duration } -func NewRateLimiter(rps float64, burst, maxFails int, banDur time.Duration) *RateLimiter { +// NewRateLimiter creates a two-tier rate limiter configured from cfg: +// - Unauthenticated (no X-API-Key / X-Admin-Key): RL_UNAUTH_* env vars +// - Authenticated (X-API-Key or X-Admin-Key present): RL_AUTH_* env vars +func NewRateLimiter(cfg *config.Config) *RateLimiter { rl := &RateLimiter{ - visitors: make(map[string]*visitor), - rps: rate.Limit(rps), - burst: burst, - maxFails: maxFails, - banDur: banDur, - cleanEvery: 5 * time.Minute, + unauthVisitors: make(map[string]*ipState), + authVisitors: make(map[string]*ipState), + unauthCfg: limitConfig{ + maxReqs: cfg.RateLimit.UnauthMaxReqs, + window: cfg.RateLimit.UnauthWindow, + maxFails: cfg.RateLimit.UnauthMaxFails, + banDur: cfg.RateLimit.UnauthBanDur, + }, + authCfg: limitConfig{ + maxReqs: cfg.RateLimit.AuthMaxReqs, + window: cfg.RateLimit.AuthWindow, + maxFails: cfg.RateLimit.AuthMaxFails, + banDur: cfg.RateLimit.AuthBanDur, + }, + cleanEvery: 10 * time.Minute, } go rl.cleanupLoop() return rl } func (rl *RateLimiter) getIP(r *http.Request) string { - // Rispetta X-Forwarded-For se dietro Traefik/proxy if ip := r.Header.Get("X-Real-IP"); ip != "" { return ip } if ip := r.Header.Get("X-Forwarded-For"); ip != "" { - // Prendi il primo IP (quello del client originale) if h, _, err := net.SplitHostPort(ip); err == nil { return h } @@ -58,62 +74,84 @@ func (rl *RateLimiter) getIP(r *http.Request) string { return host } -func (rl *RateLimiter) getVisitor(ip string) *visitor { +func (rl *RateLimiter) isAuthenticated(r *http.Request) bool { + return r.Header.Get("X-API-Key") != "" || r.Header.Get("X-Admin-Key") != "" +} + +// record increments the counter for the IP and returns whether the limit was +// exceeded, the current failure count, and whether the IP should be banned. +func (rl *RateLimiter) record(ip string, auth bool) (exceeded bool, failures int, shouldBan bool, banDur time.Duration) { rl.mu.Lock() defer rl.mu.Unlock() - v, ok := rl.visitors[ip] - if !ok { - v = &visitor{ - limiter: rate.NewLimiter(rl.rps, rl.burst), - } - rl.visitors[ip] = v + var visitors map[string]*ipState + var cfg limitConfig + if auth { + visitors = rl.authVisitors + cfg = rl.authCfg + } else { + visitors = rl.unauthVisitors + cfg = rl.unauthCfg } - v.lastSeen = time.Now() - return v + + v, ok := visitors[ip] + if !ok { + v = &ipState{windowStart: time.Now()} + visitors[ip] = v + } + + now := time.Now() + v.lastSeen = now + + // Roll the window if expired + if now.Sub(v.windowStart) >= cfg.window { + v.count = 0 + v.windowStart = now + } + + v.count++ + + if v.count > cfg.maxReqs { + v.failures++ + return true, v.failures, v.failures >= cfg.maxFails, cfg.banDur + } + + // Legitimate request within limit — reset failure streak + v.failures = 0 + return false, 0, false, 0 } func (rl *RateLimiter) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := rl.getIP(r) - // Controlla ban attivo + // Check active ban if unbanAt, banned := rl.banned.Load(ip); banned { if time.Now().Before(unbanAt.(time.Time)) { w.Header().Set("Retry-After", unbanAt.(time.Time).Format(time.RFC1123)) http.Error(w, "too many requests - temporarily banned", http.StatusForbidden) return } - // Ban scaduto rl.banned.Delete(ip) } - v := rl.getVisitor(ip) + auth := rl.isAuthenticated(r) + exceeded, failures, shouldBan, banDur := rl.record(ip, auth) - if !v.limiter.Allow() { - rl.mu.Lock() - v.failures++ - fails := v.failures - rl.mu.Unlock() - - if fails >= rl.maxFails { - unbanAt := time.Now().Add(rl.banDur) + if exceeded { + if shouldBan { + unbanAt := time.Now().Add(banDur) rl.banned.Store(ip, unbanAt) - // Opzionale: loga il ban + log.Printf("[RATE-LIMIT] IP %s banned until %s (path: %s, auth: %v)", ip, unbanAt.Format(time.RFC1123), r.URL.Path, auth) w.Header().Set("Retry-After", unbanAt.Format(time.RFC1123)) http.Error(w, "banned", http.StatusForbidden) return } - + log.Printf("[RATE-LIMIT] IP %s exceeded limit — violation %d (path: %s, auth: %v)", ip, failures, r.URL.Path, auth) http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) return } - // Reset failures su richiesta legittima - rl.mu.Lock() - v.failures = 0 - rl.mu.Unlock() - next.ServeHTTP(w, r) }) } @@ -123,13 +161,17 @@ func (rl *RateLimiter) cleanupLoop() { defer ticker.Stop() for range ticker.C { rl.mu.Lock() - for ip, v := range rl.visitors { - if time.Since(v.lastSeen) > 10*time.Minute { - delete(rl.visitors, ip) + for ip, v := range rl.unauthVisitors { + if time.Since(v.lastSeen) > rl.unauthCfg.window*2 { + delete(rl.unauthVisitors, ip) + } + } + for ip, v := range rl.authVisitors { + if time.Since(v.lastSeen) > rl.authCfg.window*2 { + delete(rl.authVisitors, ip) } } rl.mu.Unlock() - // Pulisci anche i ban scaduti rl.banned.Range(func(k, v any) bool { if time.Now().After(v.(time.Time)) { rl.banned.Delete(k) diff --git a/internal/routes/v1/v1.go b/internal/routes/v1/v1.go index f0af0d3..ec205e2 100644 --- a/internal/routes/v1/v1.go +++ b/internal/routes/v1/v1.go @@ -3,8 +3,8 @@ package v1 import ( emlyMiddleware "emly-api-go/internal/middleware" "net/http" - "time" + "emly-api-go/internal/config" "emly-api-go/internal/handlers" "github.com/go-chi/chi/v5" @@ -17,12 +17,7 @@ import ( func NewRouter(db *sqlx.DB) http.Handler { r := chi.NewRouter() - rl := emlyMiddleware.NewRateLimiter( - 5, // 5 req/sec per IP - 10, // burst fino a 10 - 20, // ban dopo 20 violazioni - 15*time.Minute, // ban di 15 minuti - ) + rl := emlyMiddleware.NewRateLimiter(config.Load()) r.Use(rl.Handler) diff --git a/internal/routes/v2/v2.go b/internal/routes/v2/v2.go index 710a3da..1beefdd 100644 --- a/internal/routes/v2/v2.go +++ b/internal/routes/v2/v2.go @@ -3,8 +3,8 @@ package v2 import ( emlyMiddleware "emly-api-go/internal/middleware" "net/http" - "time" + "emly-api-go/internal/config" "emly-api-go/internal/handlers" "github.com/go-chi/chi/v5" @@ -17,12 +17,7 @@ import ( func NewRouter(db *sqlx.DB) http.Handler { r := chi.NewRouter() - rl := emlyMiddleware.NewRateLimiter( - 5, // 5 req/sec per IP - 10, // burst fino a 10 - 20, // ban dopo 20 violazioni - 15*time.Minute, // ban di 15 minuti - ) + rl := emlyMiddleware.NewRateLimiter(config.Load()) r.Use(rl.Handler) diff --git a/main.go b/main.go index 2777ffb..c77bdf5 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/httprate" "github.com/jmoiron/sqlx" "github.com/joho/godotenv" @@ -51,15 +50,7 @@ func main() { r.Use(middleware.Recoverer) r.Use(middleware.Timeout(30 * time.Second)) - // Global rate limit to 100 requests per minute - r.Use(httprate.LimitByIP(100, time.Minute)) - - rl := emlyMiddleware.NewRateLimiter( - 5, // 5 req/sec per IP - 10, // burst fino a 10 - 20, // ban dopo 20 violazioni - 30*time.Minute, // ban di 15 minuti - ) + rl := emlyMiddleware.NewRateLimiter(cfg) r.Use(rl.Handler)