Compare commits
7 Commits
e7678bc1d4
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fa1f65baf7 | ||
|
|
e6d663f4f2 | ||
|
|
09a760e025 | ||
|
|
858b0642d9 | ||
|
|
4fb3290cf6 | ||
|
|
9d4a1b7ef3 | ||
|
|
69b3a917d3 |
35
.env.example
35
.env.example
@@ -1,13 +1,42 @@
|
||||
# Server Settings
|
||||
PORT=8080
|
||||
|
||||
# Infrastruttura Docker (Traefik + MySQL)
|
||||
API_DOMAIN=api.esempio.com
|
||||
ACME_EMAIL=tua@email.com
|
||||
MYSQL_ROOT_PASSWORD=password-sicura
|
||||
|
||||
# DB Settings
|
||||
# DB_DRIVER: "mysql" (default) o "sqlite"
|
||||
DB_DRIVER=mysql
|
||||
|
||||
# MySQL
|
||||
DB_DSN=root:secret@tcp(127.0.0.1:3306)/emly?parseTime=true&loc=UTC
|
||||
MAX_OPEN_CONNS=25
|
||||
MAX_IDLE_CONNS=5
|
||||
CONN_MAX_LIFETIME=5m
|
||||
DB_MAX_OPEN_CONNS=25
|
||||
DB_MAX_IDLE_CONNS=5
|
||||
DB_CONN_MAX_LIFETIME=5
|
||||
DATABASE_NAME=emly
|
||||
|
||||
# SQLite (usare invece di MySQL: DB_DRIVER=sqlite, DB_DSN=./data.db, DATABASE_NAME non necessario)
|
||||
# DB_DSN=./data.db
|
||||
|
||||
# API Keys
|
||||
API_KEY=key-one
|
||||
ADMIN_KEY=admin-key-one
|
||||
|
||||
# Rate Limiting — Traefik edge (condiviso tra repliche)
|
||||
TRAEFIK_RL_AVERAGE=30
|
||||
TRAEFIK_RL_BURST=10
|
||||
TRAEFIK_RL_PERIOD=1m
|
||||
|
||||
# Rate Limiting — App (unauthenticated: no X-API-Key / X-Admin-Key)
|
||||
RL_UNAUTH_MAX_REQS=10
|
||||
RL_UNAUTH_WINDOW=5m
|
||||
RL_UNAUTH_MAX_FAILS=5
|
||||
RL_UNAUTH_BAN_DUR=15m
|
||||
|
||||
# Rate Limiting — App (authenticated: X-API-Key or X-Admin-Key present)
|
||||
RL_AUTH_MAX_REQS=100
|
||||
RL_AUTH_WINDOW=1m
|
||||
RL_AUTH_MAX_FAILS=20
|
||||
RL_AUTH_BAN_DUR=5m
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -38,4 +38,7 @@ go.work.sum
|
||||
|
||||
tmp/
|
||||
|
||||
build/
|
||||
build/
|
||||
|
||||
# Database files
|
||||
*.db
|
||||
154
docker-compose-prod.yml
Normal file
154
docker-compose-prod.yml
Normal file
@@ -0,0 +1,154 @@
|
||||
networks:
|
||||
traefik_public:
|
||||
driver: bridge
|
||||
internal:
|
||||
driver: bridge
|
||||
internal: true
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
traefik_certs:
|
||||
logs:
|
||||
|
||||
# ── Anchor: variabili d'ambiente comuni a tutte le repliche ─────────────────
|
||||
x-api-env: &api-env
|
||||
PORT: "8080"
|
||||
DB_DSN: "root:${MYSQL_ROOT_PASSWORD}@tcp(mysql:3306)/${DATABASE_NAME}?parseTime=true&loc=UTC"
|
||||
DATABASE_NAME: ${DATABASE_NAME}
|
||||
API_KEY: ${API_KEY}
|
||||
ADMIN_KEY: ${ADMIN_KEY}
|
||||
DB_MAX_OPEN_CONNS: ${DB_MAX_OPEN_CONNS:-25}
|
||||
DB_MAX_IDLE_CONNS: ${DB_MAX_IDLE_CONNS:-5}
|
||||
DB_CONN_MAX_LIFETIME: ${DB_CONN_MAX_LIFETIME:-5}
|
||||
RL_UNAUTH_MAX_REQS: ${RL_UNAUTH_MAX_REQS:-10}
|
||||
RL_UNAUTH_WINDOW: ${RL_UNAUTH_WINDOW:-5m}
|
||||
RL_UNAUTH_MAX_FAILS: ${RL_UNAUTH_MAX_FAILS:-5}
|
||||
RL_UNAUTH_BAN_DUR: ${RL_UNAUTH_BAN_DUR:-15m}
|
||||
RL_AUTH_MAX_REQS: ${RL_AUTH_MAX_REQS:-100}
|
||||
RL_AUTH_WINDOW: ${RL_AUTH_WINDOW:-1m}
|
||||
RL_AUTH_MAX_FAILS: ${RL_AUTH_MAX_FAILS:-20}
|
||||
RL_AUTH_BAN_DUR: ${RL_AUTH_BAN_DUR:-5m}
|
||||
|
||||
# ── Anchor: configurazione base del servizio API ────────────────────────────
|
||||
x-api-base: &api-base
|
||||
build: .
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- traefik_public
|
||||
- internal
|
||||
volumes:
|
||||
- logs:/logs
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
max-size: "10m"
|
||||
max-file: "5"
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
labels:
|
||||
# Traefik: abilita il container e definisce il router HTTPS
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.emly-api.rule=Host(`${API_DOMAIN}`)"
|
||||
- "traefik.http.routers.emly-api.entrypoints=websecure"
|
||||
- "traefik.http.routers.emly-api.tls.certresolver=letsencrypt"
|
||||
- "traefik.http.routers.emly-api.middlewares=rl,hsts"
|
||||
# Load balancer: tutte le repliche condividono lo stesso service name
|
||||
- "traefik.http.services.emly-api.loadbalancer.server.port=8080"
|
||||
- "traefik.http.services.emly-api.loadbalancer.healthcheck.path=/v1/health"
|
||||
- "traefik.http.services.emly-api.loadbalancer.healthcheck.interval=10s"
|
||||
- "traefik.http.services.emly-api.loadbalancer.healthcheck.timeout=3s"
|
||||
# Rate limiting edge (condiviso tra repliche, applicato prima del LB)
|
||||
- "traefik.http.middlewares.rl.ratelimit.average=${TRAEFIK_RL_AVERAGE:-30}"
|
||||
- "traefik.http.middlewares.rl.ratelimit.burst=${TRAEFIK_RL_BURST:-10}"
|
||||
- "traefik.http.middlewares.rl.ratelimit.period=${TRAEFIK_RL_PERIOD:-1m}"
|
||||
# HSTS
|
||||
- "traefik.http.middlewares.hsts.headers.stsSeconds=31536000"
|
||||
- "traefik.http.middlewares.hsts.headers.stsIncludeSubdomains=true"
|
||||
- "traefik.http.middlewares.hsts.headers.forceSTSHeader=true"
|
||||
# Watchtower: aggiorna automaticamente questa immagine
|
||||
- "com.centurylinklabs.watchtower.enable=true"
|
||||
|
||||
# ── Servizi ─────────────────────────────────────────────────────────────────
|
||||
services:
|
||||
|
||||
# ── Traefik ──────────────────────────────────────────────────────────────
|
||||
traefik:
|
||||
image: traefik:v3
|
||||
restart: unless-stopped
|
||||
command:
|
||||
- "--api=false"
|
||||
# Entry points
|
||||
- "--entrypoints.web.address=:80"
|
||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||
- "--entrypoints.web.http.redirections.entrypoint.scheme=https"
|
||||
- "--entrypoints.websecure.address=:443"
|
||||
# Docker provider
|
||||
- "--providers.docker=true"
|
||||
- "--providers.docker.exposedbydefault=false"
|
||||
- "--providers.docker.network=traefik_public"
|
||||
# ACME / Let's Encrypt (TLS challenge)
|
||||
- "--certificatesresolvers.letsencrypt.acme.email=${ACME_EMAIL}"
|
||||
- "--certificatesresolvers.letsencrypt.acme.storage=/certificates/acme.json"
|
||||
- "--certificatesresolvers.letsencrypt.acme.tlschallenge=true"
|
||||
- "--log.level=INFO"
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- traefik_certs:/certificates
|
||||
networks:
|
||||
- traefik_public
|
||||
|
||||
# ── MySQL ────────────────────────────────────────────────────────────────
|
||||
mysql:
|
||||
image: mysql:8
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD}
|
||||
MYSQL_DATABASE: ${DATABASE_NAME}
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
networks:
|
||||
- internal
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "mysqladmin ping -h localhost -p${MYSQL_ROOT_PASSWORD} --silent"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
|
||||
# ── API replica 1 ────────────────────────────────────────────────────────
|
||||
api-1:
|
||||
<<: *api-base
|
||||
environment:
|
||||
<<: *api-env
|
||||
INSTANCE_NAME: api-1
|
||||
|
||||
# ── API replica 2 ────────────────────────────────────────────────────────
|
||||
api-2:
|
||||
<<: *api-base
|
||||
environment:
|
||||
<<: *api-env
|
||||
INSTANCE_NAME: api-2
|
||||
|
||||
# ── API replica 3 ────────────────────────────────────────────────────────
|
||||
api-3:
|
||||
<<: *api-base
|
||||
environment:
|
||||
<<: *api-env
|
||||
INSTANCE_NAME: api-3
|
||||
|
||||
# ── Watchtower ───────────────────────────────────────────────────────────
|
||||
watchtower:
|
||||
image: containrrr/watchtower
|
||||
restart: unless-stopped
|
||||
command:
|
||||
- "--cleanup"
|
||||
- "--schedule"
|
||||
- "0 0 * * * *"
|
||||
environment:
|
||||
WATCHTOWER_LABEL_ENABLE: "true"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
10
go.mod
10
go.mod
@@ -11,8 +11,16 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
modernc.org/libc v1.70.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.47.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
21
go.sum
21
go.sum
@@ -2,6 +2,8 @@ filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||
filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-chi/chi/v5 v5.2.4 h1:WtFKPHwlywe8Srng8j2BhOD9312j9cGUxG1SP4V2cR4=
|
||||
github.com/go-chi/chi/v5 v5.2.4/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-chi/httprate v0.14.1 h1:EKZHYEZ58Cg6hWcYzoZILsv7ppb46Wt4uQ738IRtpZs=
|
||||
@@ -9,19 +11,34 @@ github.com/go-chi/httprate v0.14.1/go.mod h1:TUepLXaz/pCjmCtf/obgOQJ2Sz6rC8fSf5c
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
|
||||
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
|
||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
|
||||
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
||||
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
|
||||
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
|
||||
@@ -5,10 +5,23 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RateLimitConfig struct {
|
||||
UnauthMaxReqs int
|
||||
UnauthWindow time.Duration
|
||||
UnauthMaxFails int
|
||||
UnauthBanDur time.Duration
|
||||
AuthMaxReqs int
|
||||
AuthWindow time.Duration
|
||||
AuthMaxFails int
|
||||
AuthBanDur time.Duration
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
Driver string
|
||||
DSN string
|
||||
Database string
|
||||
APIKey string
|
||||
@@ -16,6 +29,7 @@ type Config struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime int
|
||||
RateLimit RateLimitConfig
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -67,9 +81,19 @@ func load() *Config {
|
||||
connMaxLifetime = 5
|
||||
}
|
||||
|
||||
dbName := os.Getenv("DATABASE_NAME")
|
||||
if dbName == "" {
|
||||
panic("DATABASE_NAME environment variable is required")
|
||||
driver := os.Getenv("DB_DRIVER")
|
||||
if driver == "" {
|
||||
driver = "mysql"
|
||||
}
|
||||
|
||||
var dbName string
|
||||
if driver == "sqlite" {
|
||||
dbName = "main"
|
||||
} else {
|
||||
dbName = os.Getenv("DATABASE_NAME")
|
||||
if dbName == "" {
|
||||
panic("DATABASE_NAME environment variable is required")
|
||||
}
|
||||
}
|
||||
|
||||
if os.Getenv("DB_DSN") == "" {
|
||||
@@ -78,6 +102,7 @@ func load() *Config {
|
||||
|
||||
return &Config{
|
||||
Port: port,
|
||||
Driver: driver,
|
||||
DSN: os.Getenv("DB_DSN"),
|
||||
Database: dbName,
|
||||
APIKey: apiKey,
|
||||
@@ -85,5 +110,33 @@ func load() *Config {
|
||||
MaxOpenConns: maxOpenConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
ConnMaxLifetime: connMaxLifetime,
|
||||
RateLimit: RateLimitConfig{
|
||||
UnauthMaxReqs: envInt("RL_UNAUTH_MAX_REQS", 10),
|
||||
UnauthWindow: envDuration("RL_UNAUTH_WINDOW", 5*time.Minute),
|
||||
UnauthMaxFails: envInt("RL_UNAUTH_MAX_FAILS", 5),
|
||||
UnauthBanDur: envDuration("RL_UNAUTH_BAN_DUR", 15*time.Minute),
|
||||
AuthMaxReqs: envInt("RL_AUTH_MAX_REQS", 100),
|
||||
AuthWindow: envDuration("RL_AUTH_WINDOW", time.Minute),
|
||||
AuthMaxFails: envInt("RL_AUTH_MAX_FAILS", 20),
|
||||
AuthBanDur: envDuration("RL_AUTH_BAN_DUR", 5*time.Minute),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func envInt(key string, fallback int) int {
|
||||
if s := os.Getenv(key); s != "" {
|
||||
if n, err := strconv.Atoi(s); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func envDuration(key string, fallback time.Duration) time.Duration {
|
||||
if s := os.Getenv(key); s != "" {
|
||||
if d, err := time.ParseDuration(s); err == nil {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
@@ -1,23 +1,41 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"emly-api-go/internal/config"
|
||||
)
|
||||
|
||||
func Connect(cfg *config.Config) (*sqlx.DB, error) {
|
||||
db, err := sqlx.Connect("mysql", cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var db *sqlx.DB
|
||||
var err error
|
||||
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Minute)
|
||||
switch cfg.Driver {
|
||||
case "sqlite":
|
||||
db, err = sqlx.Connect("sqlite", cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Enable foreign key support (disabled by default in SQLite)
|
||||
if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, fmt.Errorf("sqlite: enable foreign_keys: %w", err)
|
||||
}
|
||||
case "mysql":
|
||||
db, err = sqlx.Connect("mysql", cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Minute)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported DB_DRIVER %q: must be mysql or sqlite", cfg.Driver)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
//go:embed init.sql migrations/*.json migrations/*.sql
|
||||
//go:embed mysql sqlite
|
||||
var migrationsFS embed.FS
|
||||
|
||||
type taskFile struct {
|
||||
@@ -31,62 +31,43 @@ type condition struct {
|
||||
Index string `json:"index,omitempty"`
|
||||
}
|
||||
|
||||
// Migrate reads migrations/tasks.json and executes every task whose
|
||||
// conditions are ALL satisfied (i.e. logical AND).
|
||||
func Migrate(db *sqlx.DB, dbName string) error {
|
||||
// If the database has no tables at all, bootstrap with init.sql.
|
||||
empty, err := schemaIsEmpty(db, dbName)
|
||||
// Migrate reads the driver-specific migrations and applies them.
|
||||
func Migrate(db *sqlx.DB, dbName string, driver string) error {
|
||||
empty, err := schemaIsEmpty(db, dbName, driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: check empty: %w", err)
|
||||
}
|
||||
if empty {
|
||||
log.Println("[migrate] empty schema detected – running init.sql")
|
||||
initSQL, err := migrationsFS.ReadFile("init.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: read init.sql: %w", err)
|
||||
if err := runInitSQL(db, driver); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range splitStatements(string(initSQL)) {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
return fmt.Errorf("schema: exec init.sql: %w\nSQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
log.Println("[migrate] init.sql applied – base schema created")
|
||||
} else {
|
||||
log.Println("[migrate] checking if tables exist")
|
||||
// Check if the tables are there or not
|
||||
var tableNames []string
|
||||
tableNames := []string{"bug_reports", "bug_report_files", "rate_limit_hwid", "user", "session"}
|
||||
var foundTables []string
|
||||
tableNames = append(tableNames, "bug_reports", "bug_report_files", "rate_limit_hwid", "user", "session")
|
||||
for _, tableName := range tableNames {
|
||||
found, err := tableExists(db, dbName, tableName)
|
||||
found, err := tableExists(db, dbName, tableName, driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: check table %s: %w", tableName, err)
|
||||
}
|
||||
if !found {
|
||||
log.Printf("[migrate] warning: expected table %s not found – schema may be in an inconsistent state", tableName)
|
||||
log.Printf("[migrate] warning: expected table %s not found", tableName)
|
||||
continue
|
||||
}
|
||||
foundTables = append(foundTables, tableName)
|
||||
}
|
||||
if len(foundTables) != len(tableNames) {
|
||||
log.Printf("[migrate] warning: expected %d tables, found %d", len(tableNames), len(foundTables))
|
||||
log.Printf("[migrate] info: running init.sql")
|
||||
initSQL, err := migrationsFS.ReadFile("init.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: read init.sql: %w", err)
|
||||
log.Printf("[migrate] warning: expected %d tables, found %d – running init.sql", len(tableNames), len(foundTables))
|
||||
if err := runInitSQL(db, driver); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range splitStatements(string(initSQL)) {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
return fmt.Errorf("schema: exec init.sql: %w\nSQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
log.Println("[migrate] init.sql applied – base schema created")
|
||||
} else {
|
||||
log.Println("[migrate] all expected tables found – skipping init.sql")
|
||||
}
|
||||
}
|
||||
|
||||
raw, err := migrationsFS.ReadFile("migrations/tasks.json")
|
||||
raw, err := migrationsFS.ReadFile(driver + "/migrations/tasks.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: read tasks.json: %w", err)
|
||||
}
|
||||
@@ -97,7 +78,7 @@ func Migrate(db *sqlx.DB, dbName string) error {
|
||||
}
|
||||
|
||||
for _, t := range tf.Tasks {
|
||||
needed, err := shouldRun(db, dbName, t.Conditions)
|
||||
needed, err := shouldRun(db, dbName, t.Conditions, driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: evaluate conditions for %s: %w", t.ID, err)
|
||||
}
|
||||
@@ -106,7 +87,7 @@ func Migrate(db *sqlx.DB, dbName string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
sqlBytes, err := migrationsFS.ReadFile("migrations/" + t.SQLFile)
|
||||
sqlBytes, err := migrationsFS.ReadFile(driver + "/migrations/" + t.SQLFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: read %s: %w", t.SQLFile, err)
|
||||
}
|
||||
@@ -122,11 +103,25 @@ func Migrate(db *sqlx.DB, dbName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runInitSQL(db *sqlx.DB, driver string) error {
|
||||
initSQL, err := migrationsFS.ReadFile(driver + "/init.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("schema: read init.sql: %w", err)
|
||||
}
|
||||
for _, stmt := range splitStatements(string(initSQL)) {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
return fmt.Errorf("schema: exec init.sql: %w\nSQL: %s", err, stmt)
|
||||
}
|
||||
}
|
||||
log.Println("[migrate] init.sql applied – base schema created")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------- Condition evaluator ----------
|
||||
|
||||
func shouldRun(db *sqlx.DB, dbName string, conds []condition) (bool, error) {
|
||||
func shouldRun(db *sqlx.DB, dbName string, conds []condition, driver string) (bool, error) {
|
||||
for _, c := range conds {
|
||||
met, err := evaluate(db, dbName, c)
|
||||
met, err := evaluate(db, dbName, c, driver)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -137,81 +132,186 @@ func shouldRun(db *sqlx.DB, dbName string, conds []condition) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func evaluate(db *sqlx.DB, dbName string, c condition) (bool, error) {
|
||||
func evaluate(db *sqlx.DB, dbName string, c condition, driver string) (bool, error) {
|
||||
switch c.Type {
|
||||
case "column_not_exists":
|
||||
exists, err := columnExists(db, dbName, c.Table, c.Column)
|
||||
exists, err := columnExists(db, dbName, c.Table, c.Column, driver)
|
||||
return !exists, err
|
||||
|
||||
case "column_exists":
|
||||
return columnExists(db, dbName, c.Table, c.Column)
|
||||
return columnExists(db, dbName, c.Table, c.Column, driver)
|
||||
|
||||
case "index_not_exists":
|
||||
exists, err := indexExists(db, dbName, c.Table, c.Index)
|
||||
exists, err := indexExists(db, dbName, c.Table, c.Index, driver)
|
||||
return !exists, err
|
||||
|
||||
case "index_exists":
|
||||
return indexExists(db, dbName, c.Table, c.Index)
|
||||
return indexExists(db, dbName, c.Table, c.Index, driver)
|
||||
|
||||
case "table_not_exists":
|
||||
exists, err := tableExists(db, dbName, c.Table)
|
||||
exists, err := tableExists(db, dbName, c.Table, driver)
|
||||
return !exists, err
|
||||
|
||||
case "table_exists":
|
||||
return tableExists(db, dbName, c.Table)
|
||||
return tableExists(db, dbName, c.Table, driver)
|
||||
|
||||
default:
|
||||
return false, fmt.Errorf("unknown condition type: %s", c.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func columnExists(db *sqlx.DB, dbName, table, column string) (bool, error) {
|
||||
// ---------- MySQL condition checks ----------
|
||||
|
||||
func columnExistsMySQL(db *sqlx.DB, dbName, table, column string) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count,
|
||||
`SELECT COUNT(*) FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
AND TABLE_NAME = ?
|
||||
AND COLUMN_NAME = ?`, dbName, table, column)
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?`,
|
||||
dbName, table, column)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func indexExists(db *sqlx.DB, dbName, table, index string) (bool, error) {
|
||||
func indexExistsMySQL(db *sqlx.DB, dbName, table, index string) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count,
|
||||
`SELECT COUNT(*) FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
AND TABLE_NAME = ?
|
||||
AND INDEX_NAME = ?`, dbName, table, index)
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND INDEX_NAME = ?`,
|
||||
dbName, table, index)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func tableExists(db *sqlx.DB, dbName, table string) (bool, error) {
|
||||
func tableExistsMySQL(db *sqlx.DB, dbName, table string) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count,
|
||||
`SELECT COUNT(*) FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
AND TABLE_NAME = ?`, dbName, table)
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?`,
|
||||
dbName, table)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func schemaIsEmpty(db *sqlx.DB, dbName string) (bool, error) {
|
||||
func schemaIsEmptyMySQL(db *sqlx.DB, dbName string) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count,
|
||||
`SELECT COUNT(*) FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = ?`, dbName)
|
||||
`SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = ?`, dbName)
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
// splitStatements splits a SQL blob on ";" respecting only top-level
|
||||
// semicolons (good enough for simple ALTER / CREATE statements).
|
||||
// ---------- SQLite condition checks ----------
|
||||
|
||||
func columnExistsSQLite(db *sqlx.DB, table, column string) (bool, error) {
|
||||
var count int
|
||||
// pragma_table_info is a table-valued function available since SQLite 3.16.0
|
||||
err := db.Get(&count,
|
||||
fmt.Sprintf("SELECT COUNT(*) FROM pragma_table_info('%s') WHERE name = ?", table),
|
||||
column)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func indexExistsSQLite(db *sqlx.DB, table, index string) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count,
|
||||
`SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND tbl_name=? AND name=?`,
|
||||
table, index)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func tableExistsSQLite(db *sqlx.DB, table string) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count,
|
||||
`SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`, table)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func schemaIsEmptySQLite(db *sqlx.DB) (bool, error) {
|
||||
var count int
|
||||
err := db.Get(&count, `SELECT COUNT(*) FROM sqlite_master WHERE type='table'`)
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
// ---------- Driver-dispatched wrappers ----------
|
||||
|
||||
func columnExists(db *sqlx.DB, dbName, table, column, driver string) (bool, error) {
|
||||
if driver == "sqlite" {
|
||||
return columnExistsSQLite(db, table, column)
|
||||
}
|
||||
return columnExistsMySQL(db, dbName, table, column)
|
||||
}
|
||||
|
||||
func indexExists(db *sqlx.DB, dbName, table, index, driver string) (bool, error) {
|
||||
if driver == "sqlite" {
|
||||
return indexExistsSQLite(db, table, index)
|
||||
}
|
||||
return indexExistsMySQL(db, dbName, table, index)
|
||||
}
|
||||
|
||||
func tableExists(db *sqlx.DB, dbName, table, driver string) (bool, error) {
|
||||
if driver == "sqlite" {
|
||||
return tableExistsSQLite(db, table)
|
||||
}
|
||||
return tableExistsMySQL(db, dbName, table)
|
||||
}
|
||||
|
||||
func schemaIsEmpty(db *sqlx.DB, dbName, driver string) (bool, error) {
|
||||
if driver == "sqlite" {
|
||||
return schemaIsEmptySQLite(db)
|
||||
}
|
||||
return schemaIsEmptyMySQL(db, dbName)
|
||||
}
|
||||
|
||||
// splitStatements splits a SQL blob on top-level ";" only, respecting
|
||||
// BEGIN...END blocks (e.g. triggers) so their inner semicolons are not split.
|
||||
func splitStatements(sql string) []string {
|
||||
raw := strings.Split(sql, ";")
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, s := range raw {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
var out []string
|
||||
var buf strings.Builder
|
||||
depth := 0
|
||||
n := len(sql)
|
||||
|
||||
for i := 0; i < n; {
|
||||
c := sql[i]
|
||||
|
||||
// Collect whole identifier tokens to detect BEGIN / END keywords.
|
||||
if isIdentStart(c) {
|
||||
j := i
|
||||
for j < n && isIdentChar(sql[j]) {
|
||||
j++
|
||||
}
|
||||
word := strings.ToUpper(sql[i:j])
|
||||
switch word {
|
||||
case "BEGIN":
|
||||
depth++
|
||||
case "END":
|
||||
if depth > 0 {
|
||||
depth--
|
||||
}
|
||||
}
|
||||
buf.WriteString(sql[i:j])
|
||||
i = j
|
||||
continue
|
||||
}
|
||||
|
||||
if c == ';' && depth == 0 {
|
||||
if stmt := strings.TrimSpace(buf.String()); stmt != "" {
|
||||
out = append(out, stmt)
|
||||
}
|
||||
buf.Reset()
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
buf.WriteByte(c)
|
||||
i++
|
||||
}
|
||||
|
||||
if stmt := strings.TrimSpace(buf.String()); stmt != "" {
|
||||
out = append(out, stmt)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isIdentStart(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
|
||||
}
|
||||
|
||||
func isIdentChar(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_'
|
||||
}
|
||||
|
||||
65
internal/database/schema/sqlite/init.sql
Normal file
65
internal/database/schema/sqlite/init.sql
Normal file
@@ -0,0 +1,65 @@
|
||||
CREATE TABLE IF NOT EXISTS bug_reports (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
description TEXT NOT NULL,
|
||||
hwid TEXT NOT NULL DEFAULT '',
|
||||
hostname TEXT NOT NULL DEFAULT '',
|
||||
os_user TEXT NOT NULL DEFAULT '',
|
||||
submitter_ip TEXT NOT NULL DEFAULT '',
|
||||
system_info TEXT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'new' CHECK(status IN ('new','in_review','resolved','closed')),
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_status ON bug_reports(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_hwid ON bug_reports(hwid);
|
||||
CREATE INDEX IF NOT EXISTS idx_hostname ON bug_reports(hostname);
|
||||
CREATE INDEX IF NOT EXISTS idx_os_user ON bug_reports(os_user);
|
||||
CREATE INDEX IF NOT EXISTS idx_created_at ON bug_reports(created_at);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS trg_bug_reports_updated_at
|
||||
AFTER UPDATE ON bug_reports
|
||||
FOR EACH ROW
|
||||
WHEN NEW.updated_at = OLD.updated_at
|
||||
BEGIN
|
||||
UPDATE bug_reports SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS bug_report_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
report_id INTEGER NOT NULL,
|
||||
file_role TEXT NOT NULL CHECK(file_role IN ('screenshot','mail_file','localstorage','config','system_info')),
|
||||
filename TEXT NOT NULL,
|
||||
mime_type TEXT NOT NULL DEFAULT 'application/octet-stream',
|
||||
file_size INTEGER NOT NULL DEFAULT 0,
|
||||
data BLOB NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (report_id) REFERENCES bug_reports(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_report_id ON bug_report_files(report_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS rate_limit_hwid (
|
||||
hwid TEXT PRIMARY KEY,
|
||||
window_start DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT NOT NULL DEFAULT 'user' CHECK(role IN ('admin','user')),
|
||||
enabled INTEGER NOT NULL DEFAULT 1,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
displayname TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS session (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES user(id) ON DELETE CASCADE
|
||||
);
|
||||
3
internal/database/schema/sqlite/migrations/tasks.json
Normal file
3
internal/database/schema/sqlite/migrations/tasks.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"tasks": []
|
||||
}
|
||||
@@ -41,15 +41,13 @@ var fileRoles = []struct {
|
||||
{"config", models.FileRoleConfig, "application/json"},
|
||||
}
|
||||
|
||||
func CreateBugReport(db *sqlx.DB) http.HandlerFunc {
|
||||
func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
jsonError(w, http.StatusBadRequest, "invalid multipart form: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("Req form value", r.Form)
|
||||
|
||||
name := r.FormValue("name")
|
||||
email := r.FormValue("email")
|
||||
description := r.FormValue("description")
|
||||
@@ -79,7 +77,7 @@ func CreateBugReport(db *sqlx.DB) http.HandlerFunc {
|
||||
log.Printf("[BUGREPORT] Received from name=%s hwid=%s ip=%s", name, hwid, submitterIP)
|
||||
|
||||
result, err := db.ExecContext(r.Context(),
|
||||
"INSERT INTO emly_bugreports_dev.bug_reports (name, email, description, hwid, hostname, os_user, submitter_ip, system_info, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
fmt.Sprintf("INSERT INTO %s.bug_reports (name, email, description, hwid, hostname, os_user, submitter_ip, system_info, status) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", dbName),
|
||||
name, email, description, hwid, hostname, osUser, submitterIP, systemInfo, models.BugReportStatusNew,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -94,9 +92,7 @@ func CreateBugReport(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
for _, fr := range fileRoles {
|
||||
log.Println("Processing file role", fr.field)
|
||||
file, header, err := r.FormFile(fr.field)
|
||||
log.Printf("FormFile for field %s returned error: %v", fr.field, err)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -125,7 +121,7 @@ func CreateBugReport(db *sqlx.DB) http.HandlerFunc {
|
||||
log.Printf("[BUGREPORT] File uploaded: role=%s size=%d bytes", fr.role, len(data))
|
||||
|
||||
_, err = db.ExecContext(r.Context(),
|
||||
"INSERT INTO emly_bugreports_dev.bug_report_files (report_id, file_role, filename, mime_type, file_size, data) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
fmt.Sprintf("INSERT INTO %s.bug_report_files (report_id, file_role, filename, mime_type, file_size, data) VALUES (?, ?, ?, ?, ?, ?)", dbName),
|
||||
reportID, fr.role, filename, mimeType, len(data), data,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -144,7 +140,7 @@ func CreateBugReport(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllBugReports(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetAllBugReports(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
page, pageSize := 1, 20
|
||||
if p := r.URL.Query().Get("page"); p != "" {
|
||||
@@ -176,17 +172,17 @@ func GetAllBugReports(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var total int
|
||||
countQuery := "SELECT COUNT(*) FROM emly_bugreports_dev.bug_reports br " + whereClause
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s.bug_reports br ", dbName) + whereClause
|
||||
if err := db.GetContext(r.Context(), &total, countQuery, params...); err != nil {
|
||||
jsonError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
mainQuery := `
|
||||
mainQuery := fmt.Sprintf(`
|
||||
SELECT br.*, COUNT(bf.id) as file_count
|
||||
FROM emly_bugreports_dev.bug_reports br
|
||||
LEFT JOIN emly_bugreports_dev.bug_report_files bf ON bf.report_id = br.id
|
||||
` + whereClause + `
|
||||
FROM %s.bug_reports br
|
||||
LEFT JOIN %s.bug_report_files bf ON bf.report_id = br.id
|
||||
`, dbName, dbName) + whereClause + `
|
||||
GROUP BY br.id
|
||||
ORDER BY br.created_at DESC
|
||||
LIMIT ? OFFSET ?`
|
||||
@@ -208,7 +204,7 @@ func GetAllBugReports(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetBugReportByID(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetBugReportByID(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
if id == "" {
|
||||
@@ -217,7 +213,7 @@ func GetBugReportByID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var report models.BugReport
|
||||
reportErr := db.GetContext(r.Context(), &report, "SELECT * FROM emly_bugreports_dev.bug_reports WHERE id = ?", id)
|
||||
reportErr := db.GetContext(r.Context(), &report, fmt.Sprintf("SELECT * FROM %s.bug_reports WHERE id = ?", dbName), id)
|
||||
if errors.Is(reportErr, sql.ErrNoRows) {
|
||||
jsonError(w, http.StatusNotFound, "bug report not found")
|
||||
return
|
||||
@@ -239,11 +235,11 @@ func GetBugReportByID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetReportsCount(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetReportsCount(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
rawStatus := r.URL.Query().Get("status")
|
||||
|
||||
query := "SELECT COUNT(*) FROM emly_bugreports_dev.bug_reports"
|
||||
query := fmt.Sprintf("SELECT COUNT(*) FROM %s.bug_reports", dbName)
|
||||
var args []interface{}
|
||||
|
||||
if strings.TrimSpace(rawStatus) != "" {
|
||||
@@ -266,7 +262,7 @@ func GetReportsCount(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetReportFilesByReportID(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetReportFilesByReportID(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
if id == "" {
|
||||
@@ -275,7 +271,7 @@ func GetReportFilesByReportID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var files []models.BugReportFile
|
||||
if err := db.SelectContext(r.Context(), &files, "SELECT * FROM emly_bugreports_dev.bug_report_files WHERE report_id = ?", id); err != nil {
|
||||
if err := db.SelectContext(r.Context(), &files, fmt.Sprintf("SELECT * FROM %s.bug_report_files WHERE report_id = ?", dbName), id); err != nil {
|
||||
jsonError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -284,7 +280,7 @@ func GetReportFilesByReportID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetBugReportZipById(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
if id == "" {
|
||||
@@ -293,7 +289,7 @@ func GetBugReportZipById(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var report models.BugReport
|
||||
err := db.GetContext(r.Context(), &report, "SELECT * FROM emly_bugreports_dev.bug_reports WHERE id = ?", id)
|
||||
err := db.GetContext(r.Context(), &report, fmt.Sprintf("SELECT * FROM %s.bug_reports WHERE id = ?", dbName), id)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
jsonError(w, http.StatusNotFound, "bug report not found")
|
||||
return
|
||||
@@ -304,7 +300,7 @@ func GetBugReportZipById(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var files []models.BugReportFile
|
||||
if err := db.SelectContext(r.Context(), &files, "SELECT * FROM emly_bugreports_dev.bug_report_files WHERE report_id = ?", id); err != nil {
|
||||
if err := db.SelectContext(r.Context(), &files, fmt.Sprintf("SELECT * FROM %s.bug_report_files WHERE report_id = ?", dbName), id); err != nil {
|
||||
jsonError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -368,7 +364,7 @@ func GetBugReportZipById(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetReportFileByFileID(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reportId := chi.URLParam(r, "id")
|
||||
if reportId == "" {
|
||||
@@ -382,7 +378,7 @@ func GetReportFileByFileID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var file models.BugReportFile
|
||||
err := db.GetContext(r.Context(), &file, "SELECT filename, mime_type, data FROM emly_bugreports_dev.bug_report_files WHERE report_id = ? AND id = ?", reportId, fileId)
|
||||
err := db.GetContext(r.Context(), &file, fmt.Sprintf("SELECT filename, mime_type, data FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName), reportId, fileId)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
jsonError(w, http.StatusNotFound, "file not found")
|
||||
return
|
||||
@@ -405,7 +401,7 @@ func GetReportFileByFileID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func GetReportStatusByID(db *sqlx.DB) http.HandlerFunc {
|
||||
func GetReportStatusByID(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reportId := chi.URLParam(r, "id")
|
||||
if reportId == "" {
|
||||
@@ -414,7 +410,7 @@ func GetReportStatusByID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
|
||||
var reportStatus models.BugReportStatus
|
||||
if err := db.GetContext(r.Context(), &reportStatus, "SELECT status FROM emly_bugreports_dev.bug_reports WHERE id = ?", reportId); err != nil {
|
||||
if err := db.GetContext(r.Context(), &reportStatus, fmt.Sprintf("SELECT status FROM %s.bug_reports WHERE id = ?", dbName), reportId); err != nil {
|
||||
jsonError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -423,7 +419,7 @@ func GetReportStatusByID(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func PatchBugReportStatus(db *sqlx.DB) http.HandlerFunc {
|
||||
func PatchBugReportStatus(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reportId := chi.URLParam(r, "id")
|
||||
if reportId == "" {
|
||||
@@ -438,7 +434,7 @@ func PatchBugReportStatus(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
reportStatus := models.BugReportStatus(body)
|
||||
|
||||
result, err := db.ExecContext(r.Context(), "UPDATE emly_bugreports_dev.bug_reports SET status = ? WHERE id = ?", reportStatus, reportId)
|
||||
result, err := db.ExecContext(r.Context(), fmt.Sprintf("UPDATE %s.bug_reports SET status = ? WHERE id = ?", dbName), reportStatus, reportId)
|
||||
if err != nil {
|
||||
jsonError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
@@ -457,7 +453,7 @@ func PatchBugReportStatus(db *sqlx.DB) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteBugReportByID(db *sqlx.DB) http.HandlerFunc {
|
||||
func DeleteBugReportByID(db *sqlx.DB, dbName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
reportId := chi.URLParam(r, "id")
|
||||
if reportId == "" {
|
||||
@@ -465,7 +461,7 @@ func DeleteBugReportByID(db *sqlx.DB) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.ExecContext(r.Context(), "DELETE FROM emly_bugreports_dev.bug_reports WHERE id = ?", reportId)
|
||||
result, err := db.ExecContext(r.Context(), fmt.Sprintf("DELETE FROM %s.bug_reports WHERE id = ?", dbName), reportId)
|
||||
if err != nil {
|
||||
jsonError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
|
||||
@@ -1,54 +1,70 @@
|
||||
// middleware/ratelimit.go
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"emly-api-go/internal/config"
|
||||
)
|
||||
|
||||
type visitor struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
failures int
|
||||
type limitConfig struct {
|
||||
maxReqs int
|
||||
window time.Duration
|
||||
maxFails int
|
||||
banDur time.Duration
|
||||
}
|
||||
|
||||
type ipState struct {
|
||||
count int
|
||||
windowStart time.Time
|
||||
failures int
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
visitors map[string]*visitor
|
||||
banned sync.Map // ip -> unban time
|
||||
mu sync.Mutex
|
||||
unauthVisitors map[string]*ipState
|
||||
authVisitors map[string]*ipState
|
||||
banned sync.Map // ip -> unban time (shared)
|
||||
|
||||
// config
|
||||
rps rate.Limit // richieste/sec normali
|
||||
burst int
|
||||
maxFails int // quanti 429 prima del ban
|
||||
banDur time.Duration // durata ban
|
||||
unauthCfg limitConfig
|
||||
authCfg limitConfig
|
||||
cleanEvery time.Duration
|
||||
}
|
||||
|
||||
func NewRateLimiter(rps float64, burst, maxFails int, banDur time.Duration) *RateLimiter {
|
||||
// NewRateLimiter creates a two-tier rate limiter configured from cfg:
|
||||
// - Unauthenticated (no X-API-Key / X-Admin-Key): RL_UNAUTH_* env vars
|
||||
// - Authenticated (X-API-Key or X-Admin-Key present): RL_AUTH_* env vars
|
||||
func NewRateLimiter(cfg *config.Config) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
visitors: make(map[string]*visitor),
|
||||
rps: rate.Limit(rps),
|
||||
burst: burst,
|
||||
maxFails: maxFails,
|
||||
banDur: banDur,
|
||||
cleanEvery: 5 * time.Minute,
|
||||
unauthVisitors: make(map[string]*ipState),
|
||||
authVisitors: make(map[string]*ipState),
|
||||
unauthCfg: limitConfig{
|
||||
maxReqs: cfg.RateLimit.UnauthMaxReqs,
|
||||
window: cfg.RateLimit.UnauthWindow,
|
||||
maxFails: cfg.RateLimit.UnauthMaxFails,
|
||||
banDur: cfg.RateLimit.UnauthBanDur,
|
||||
},
|
||||
authCfg: limitConfig{
|
||||
maxReqs: cfg.RateLimit.AuthMaxReqs,
|
||||
window: cfg.RateLimit.AuthWindow,
|
||||
maxFails: cfg.RateLimit.AuthMaxFails,
|
||||
banDur: cfg.RateLimit.AuthBanDur,
|
||||
},
|
||||
cleanEvery: 10 * time.Minute,
|
||||
}
|
||||
go rl.cleanupLoop()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) getIP(r *http.Request) string {
|
||||
// Rispetta X-Forwarded-For se dietro Traefik/proxy
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// Prendi il primo IP (quello del client originale)
|
||||
if h, _, err := net.SplitHostPort(ip); err == nil {
|
||||
return h
|
||||
}
|
||||
@@ -58,62 +74,83 @@ func (rl *RateLimiter) getIP(r *http.Request) string {
|
||||
return host
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) getVisitor(ip string) *visitor {
|
||||
func (rl *RateLimiter) isAuthenticated(r *http.Request) bool {
|
||||
return r.Header.Get("X-API-Key") != "" || r.Header.Get("X-Admin-Key") != ""
|
||||
}
|
||||
|
||||
// record increments the counter for the IP and returns whether the limit was
|
||||
// exceeded, the current failure count, and whether the IP should be banned.
|
||||
func (rl *RateLimiter) record(ip string, auth bool) (exceeded bool, failures int, shouldBan bool, banDur time.Duration) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
v, ok := rl.visitors[ip]
|
||||
if !ok {
|
||||
v = &visitor{
|
||||
limiter: rate.NewLimiter(rl.rps, rl.burst),
|
||||
}
|
||||
rl.visitors[ip] = v
|
||||
var visitors map[string]*ipState
|
||||
var cfg limitConfig
|
||||
if auth {
|
||||
visitors = rl.authVisitors
|
||||
cfg = rl.authCfg
|
||||
} else {
|
||||
visitors = rl.unauthVisitors
|
||||
cfg = rl.unauthCfg
|
||||
}
|
||||
v.lastSeen = time.Now()
|
||||
return v
|
||||
|
||||
v, ok := visitors[ip]
|
||||
if !ok {
|
||||
v = &ipState{windowStart: time.Now()}
|
||||
visitors[ip] = v
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
v.lastSeen = now
|
||||
|
||||
// Roll the window if expired
|
||||
if now.Sub(v.windowStart) >= cfg.window {
|
||||
v.count = 0
|
||||
v.windowStart = now
|
||||
}
|
||||
|
||||
v.count++
|
||||
|
||||
if v.count > cfg.maxReqs {
|
||||
v.failures++
|
||||
return true, v.failures, v.failures >= cfg.maxFails, cfg.banDur
|
||||
}
|
||||
|
||||
// Legitimate request within limit — reset failure streak
|
||||
v.failures = 0
|
||||
return false, 0, false, 0
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := rl.getIP(r)
|
||||
|
||||
// Controlla ban attivo
|
||||
// Drop connection silently if IP is banned
|
||||
if unbanAt, banned := rl.banned.Load(ip); banned {
|
||||
if time.Now().Before(unbanAt.(time.Time)) {
|
||||
w.Header().Set("Retry-After", unbanAt.(time.Time).Format(time.RFC1123))
|
||||
http.Error(w, "too many requests - temporarily banned", http.StatusForbidden)
|
||||
return
|
||||
log.Printf("[RATE-LIMIT] IP %s dropped (banned until %s, path: %s)", ip, unbanAt.(time.Time).Format(time.RFC1123), r.URL.Path)
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
// Ban scaduto
|
||||
rl.banned.Delete(ip)
|
||||
}
|
||||
|
||||
v := rl.getVisitor(ip)
|
||||
auth := rl.isAuthenticated(r)
|
||||
exceeded, failures, shouldBan, banDur := rl.record(ip, auth)
|
||||
|
||||
if !v.limiter.Allow() {
|
||||
rl.mu.Lock()
|
||||
v.failures++
|
||||
fails := v.failures
|
||||
rl.mu.Unlock()
|
||||
|
||||
if fails >= rl.maxFails {
|
||||
unbanAt := time.Now().Add(rl.banDur)
|
||||
if exceeded {
|
||||
if shouldBan {
|
||||
unbanAt := time.Now().Add(banDur)
|
||||
rl.banned.Store(ip, unbanAt)
|
||||
// Opzionale: loga il ban
|
||||
log.Printf("[RATE-LIMIT] IP %s banned until %s (path: %s, auth: %v)", ip, unbanAt.Format(time.RFC1123), r.URL.Path, auth)
|
||||
w.Header().Set("Retry-After", unbanAt.Format(time.RFC1123))
|
||||
http.Error(w, "banned", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[RATE-LIMIT] IP %s exceeded limit — violation %d (path: %s, auth: %v)", ip, failures, r.URL.Path, auth)
|
||||
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset failures su richiesta legittima
|
||||
rl.mu.Lock()
|
||||
v.failures = 0
|
||||
rl.mu.Unlock()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -123,13 +160,17 @@ func (rl *RateLimiter) cleanupLoop() {
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
for ip, v := range rl.visitors {
|
||||
if time.Since(v.lastSeen) > 10*time.Minute {
|
||||
delete(rl.visitors, ip)
|
||||
for ip, v := range rl.unauthVisitors {
|
||||
if time.Since(v.lastSeen) > rl.unauthCfg.window*2 {
|
||||
delete(rl.unauthVisitors, ip)
|
||||
}
|
||||
}
|
||||
for ip, v := range rl.authVisitors {
|
||||
if time.Since(v.lastSeen) > rl.authCfg.window*2 {
|
||||
delete(rl.authVisitors, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
// Pulisci anche i ban scaduti
|
||||
rl.banned.Range(func(k, v any) bool {
|
||||
if time.Now().After(v.(time.Time)) {
|
||||
rl.banned.Delete(k)
|
||||
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func registerBugReports(r chi.Router, db *sqlx.DB) {
|
||||
func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
|
||||
r.Route("/bug-reports", func(r chi.Router) {
|
||||
// API key only: submit a report and check count
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(apimw.APIKeyAuth(db))
|
||||
r.Use(httprate.LimitByIP(30, time.Minute))
|
||||
|
||||
r.Get("/count", handlers.GetReportsCount(db))
|
||||
r.Post("/", handlers.CreateBugReport(db))
|
||||
r.Get("/count", handlers.GetReportsCount(db, dbName))
|
||||
r.Post("/", handlers.CreateBugReport(db, dbName))
|
||||
})
|
||||
|
||||
// API key + admin key: full read/write access
|
||||
@@ -28,14 +28,14 @@ func registerBugReports(r chi.Router, db *sqlx.DB) {
|
||||
r.Use(apimw.AdminKeyAuth(db))
|
||||
r.Use(httprate.LimitByIP(30, time.Minute))
|
||||
|
||||
r.Get("/", handlers.GetAllBugReports(db))
|
||||
r.Get("/{id}", handlers.GetBugReportByID(db))
|
||||
r.Get("/{id}/status", handlers.GetReportStatusByID(db))
|
||||
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db))
|
||||
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db))
|
||||
r.Get("/{id}/download", handlers.GetBugReportZipById(db))
|
||||
r.Patch("/{id}/status", handlers.PatchBugReportStatus(db))
|
||||
r.Delete("/{id}", handlers.DeleteBugReportByID(db))
|
||||
r.Get("/", handlers.GetAllBugReports(db, dbName))
|
||||
r.Get("/{id}", handlers.GetBugReportByID(db, dbName))
|
||||
r.Get("/{id}/status", handlers.GetReportStatusByID(db, dbName))
|
||||
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db, dbName))
|
||||
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName))
|
||||
r.Get("/{id}/download", handlers.GetBugReportZipById(db, dbName))
|
||||
r.Patch("/{id}/status", handlers.PatchBugReportStatus(db, dbName))
|
||||
r.Delete("/{id}", handlers.DeleteBugReportByID(db, dbName))
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -3,8 +3,8 @@ package v1
|
||||
import (
|
||||
emlyMiddleware "emly-api-go/internal/middleware"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"emly-api-go/internal/config"
|
||||
"emly-api-go/internal/handlers"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -17,12 +17,7 @@ import (
|
||||
func NewRouter(db *sqlx.DB) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
|
||||
rl := emlyMiddleware.NewRateLimiter(
|
||||
5, // 5 req/sec per IP
|
||||
10, // burst fino a 10
|
||||
20, // ban dopo 20 violazioni
|
||||
15*time.Minute, // ban di 15 minuti
|
||||
)
|
||||
rl := emlyMiddleware.NewRateLimiter(config.Load())
|
||||
|
||||
r.Use(rl.Handler)
|
||||
|
||||
@@ -38,7 +33,7 @@ func NewRouter(db *sqlx.DB) http.Handler {
|
||||
|
||||
r.Route("/api", func(r chi.Router) {
|
||||
registerAdmin(r, db)
|
||||
registerBugReports(r, db)
|
||||
registerBugReports(r, db, config.Load().Database)
|
||||
})
|
||||
|
||||
return r
|
||||
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func registerBugReports(r chi.Router, db *sqlx.DB) {
|
||||
func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) {
|
||||
r.Route("/bug-report", func(r chi.Router) {
|
||||
// API key only: submit a report and check count
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(apimw.APIKeyAuth(db))
|
||||
r.Use(httprate.LimitByIP(30, time.Minute))
|
||||
|
||||
r.Get("/count", handlers.GetReportsCount(db))
|
||||
r.Post("/", handlers.CreateBugReport(db))
|
||||
r.Get("/count", handlers.GetReportsCount(db, dbName))
|
||||
r.Post("/", handlers.CreateBugReport(db, dbName))
|
||||
})
|
||||
|
||||
// API key + admin key: full read/write access
|
||||
@@ -28,14 +28,14 @@ func registerBugReports(r chi.Router, db *sqlx.DB) {
|
||||
r.Use(apimw.AdminKeyAuth(db))
|
||||
r.Use(httprate.LimitByIP(30, time.Minute))
|
||||
|
||||
r.Get("/", handlers.GetAllBugReports(db))
|
||||
r.Get("/{id}", handlers.GetBugReportByID(db))
|
||||
r.Get("/{id}/status", handlers.GetReportStatusByID(db))
|
||||
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db))
|
||||
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db))
|
||||
r.Get("/{id}/download", handlers.GetBugReportZipById(db))
|
||||
r.Patch("/{id}/status", handlers.PatchBugReportStatus(db))
|
||||
r.Delete("/{id}", handlers.DeleteBugReportByID(db))
|
||||
r.Get("/", handlers.GetAllBugReports(db, dbName))
|
||||
r.Get("/{id}", handlers.GetBugReportByID(db, dbName))
|
||||
r.Get("/{id}/status", handlers.GetReportStatusByID(db, dbName))
|
||||
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db, dbName))
|
||||
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName))
|
||||
r.Get("/{id}/download", handlers.GetBugReportZipById(db, dbName))
|
||||
r.Patch("/{id}/status", handlers.PatchBugReportStatus(db, dbName))
|
||||
r.Delete("/{id}", handlers.DeleteBugReportByID(db, dbName))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package v2
|
||||
import (
|
||||
emlyMiddleware "emly-api-go/internal/middleware"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"emly-api-go/internal/config"
|
||||
"emly-api-go/internal/handlers"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -17,12 +17,7 @@ import (
|
||||
func NewRouter(db *sqlx.DB) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
|
||||
rl := emlyMiddleware.NewRateLimiter(
|
||||
5, // 5 req/sec per IP
|
||||
10, // burst fino a 10
|
||||
20, // ban dopo 20 violazioni
|
||||
15*time.Minute, // ban di 15 minuti
|
||||
)
|
||||
rl := emlyMiddleware.NewRateLimiter(config.Load())
|
||||
|
||||
r.Use(rl.Handler)
|
||||
|
||||
@@ -38,7 +33,7 @@ func NewRouter(db *sqlx.DB) http.Handler {
|
||||
|
||||
r.Route("/api", func(r chi.Router) {
|
||||
registerAdmin(r, db)
|
||||
registerBugReports(r, db)
|
||||
registerBugReports(r, db, config.Load().Database)
|
||||
})
|
||||
|
||||
return r
|
||||
|
||||
18
main.go
18
main.go
@@ -4,11 +4,11 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/httprate"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
@@ -24,6 +24,10 @@ func main() {
|
||||
// Load .env (ignored if not present in production)
|
||||
_ = godotenv.Load()
|
||||
|
||||
if name := os.Getenv("INSTANCE_NAME"); name != "" {
|
||||
log.SetPrefix("[" + name + "] ")
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
|
||||
db, err := database.Connect(cfg)
|
||||
@@ -38,7 +42,7 @@ func main() {
|
||||
}(db)
|
||||
|
||||
// Run conditional schema migrations
|
||||
if err := schema.Migrate(db, cfg.Database); err != nil {
|
||||
if err := schema.Migrate(db, cfg.Database, cfg.Driver); err != nil {
|
||||
log.Fatalf("schema migration failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -51,15 +55,7 @@ func main() {
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.Timeout(30 * time.Second))
|
||||
|
||||
// Global rate limit to 100 requests per minute
|
||||
r.Use(httprate.LimitByIP(100, time.Minute))
|
||||
|
||||
rl := emlyMiddleware.NewRateLimiter(
|
||||
5, // 5 req/sec per IP
|
||||
10, // burst fino a 10
|
||||
20, // ban dopo 20 violazioni
|
||||
30*time.Minute, // ban di 15 minuti
|
||||
)
|
||||
rl := emlyMiddleware.NewRateLimiter(cfg)
|
||||
|
||||
r.Use(rl.Handler)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user