implement admin key authentication and refactor API key handling
This commit is contained in:
@@ -8,7 +8,8 @@ import (
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
Port string
|
Port string
|
||||||
DSN string
|
DSN string
|
||||||
APIKeys []string
|
APIKey string
|
||||||
|
AdminKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
func Load() *Config {
|
func Load() *Config {
|
||||||
@@ -17,18 +18,30 @@ func Load() *Config {
|
|||||||
port = "8080"
|
port = "8080"
|
||||||
}
|
}
|
||||||
|
|
||||||
raw := os.Getenv("API_KEYS")
|
raw := os.Getenv("API_KEY")
|
||||||
var keys []string
|
var apiKey string
|
||||||
for _, k := range strings.Split(raw, ",") {
|
for _, k := range strings.Split(raw, ",") {
|
||||||
k = strings.TrimSpace(k)
|
k = strings.TrimSpace(k)
|
||||||
if k != "" {
|
if k != "" {
|
||||||
keys = append(keys, k)
|
apiKey = k
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
raw = os.Getenv("ADMIN_KEY")
|
||||||
|
var adminKey string
|
||||||
|
for _, k := range strings.Split(raw, ",") {
|
||||||
|
k = strings.TrimSpace(k)
|
||||||
|
if k != "" {
|
||||||
|
adminKey = k
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
Port: port,
|
Port: port,
|
||||||
DSN: os.Getenv("DB_DSN"),
|
DSN: os.Getenv("DB_DSN"),
|
||||||
APIKeys: keys,
|
APIKey: apiKey,
|
||||||
|
AdminKey: adminKey,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
package handlers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ExampleGet http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"message": "example GET"})
|
|
||||||
}
|
|
||||||
|
|
||||||
var ExamplePost http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
defer r.Body.Close()
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"message": "example POST",
|
|
||||||
"received": string(body),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
36
internal/middleware/adminKey.go
Normal file
36
internal/middleware/adminKey.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
|
||||||
|
"emly-api-go/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AdminKeyAuth(_ *sqlx.DB) func(http.Handler) http.Handler {
|
||||||
|
cfg := config.Load()
|
||||||
|
|
||||||
|
if len(cfg.AdminKey) == 0 {
|
||||||
|
log.Panic("API key or admin key are empty")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := make(map[string]struct{}, 1)
|
||||||
|
allowed[cfg.AdminKey] = struct{}{}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
key := r.Header.Get("X-Admin-Key")
|
||||||
|
if _, ok := allowed[key]; !ok {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized admin key"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
@@ -12,11 +13,14 @@ import (
|
|||||||
func APIKeyAuth(_ *sqlx.DB) func(http.Handler) http.Handler {
|
func APIKeyAuth(_ *sqlx.DB) func(http.Handler) http.Handler {
|
||||||
cfg := config.Load()
|
cfg := config.Load()
|
||||||
|
|
||||||
allowed := make(map[string]struct{}, len(cfg.APIKeys))
|
if len(cfg.APIKey) == 0 {
|
||||||
for _, k := range cfg.APIKeys {
|
log.Panic("API key or admin key are empty")
|
||||||
allowed[k] = struct{}{}
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowed := make(map[string]struct{}, 1)
|
||||||
|
allowed[cfg.APIKey] = struct{}{}
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
key := r.Header.Get("X-API-Key")
|
key := r.Header.Get("X-API-Key")
|
||||||
|
|||||||
39
main.go
39
main.go
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/go-chi/httprate"
|
"github.com/go-chi/httprate"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
|
|
||||||
"emly-api-go/internal/config"
|
"emly-api-go/internal/config"
|
||||||
@@ -27,27 +28,31 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("database connection failed: %v", err)
|
log.Fatalf("database connection failed: %v", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer func(db *sqlx.DB) {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("closing database failed: %v", err)
|
||||||
|
}
|
||||||
|
}(db)
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
// ── Global middleware ────────────────────────────────────────────────────
|
// Global middlewares
|
||||||
r.Use(middleware.RequestID)
|
r.Use(middleware.RequestID)
|
||||||
r.Use(middleware.RealIP)
|
r.Use(middleware.RealIP)
|
||||||
r.Use(middleware.Logger)
|
r.Use(middleware.Logger)
|
||||||
r.Use(middleware.Recoverer)
|
r.Use(middleware.Recoverer)
|
||||||
r.Use(middleware.Timeout(30 * time.Second))
|
r.Use(middleware.Timeout(30 * time.Second))
|
||||||
|
|
||||||
// ── Global rate-limit: 100 req / min per IP ──────────────────────────────
|
// Global rate limit to 100 requests per minute
|
||||||
r.Use(httprate.LimitByIP(100, time.Minute))
|
r.Use(httprate.LimitByIP(100, time.Minute))
|
||||||
|
|
||||||
// ── Public routes ────────────────────────────────────────────────────────
|
// Public routes (Not protected by any API Key)
|
||||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("emly-api-go"))
|
w.Write([]byte("emly-api-go"))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Route("/api/v1", func(r chi.Router) {
|
r.Route("/api/v1", func(r chi.Router) {
|
||||||
// Add a header called X-Server
|
|
||||||
r.Use(func(next http.Handler) http.Handler {
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Server", "emly-api-go")
|
w.Header().Set("X-Server", "emly-api-go")
|
||||||
@@ -58,17 +63,7 @@ func main() {
|
|||||||
// Health – public, no API key required
|
// Health – public, no API key required
|
||||||
r.Get("/health", handlers.Health(db))
|
r.Get("/health", handlers.Health(db))
|
||||||
|
|
||||||
// ── Protected routes: require valid API key ──────────────────────────
|
// ROUTE: Bug Reports - Protected via API Key
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(apimw.APIKeyAuth(db))
|
|
||||||
|
|
||||||
// Tighter rate-limit on protected group: 30 req / min per IP
|
|
||||||
r.Use(httprate.LimitByIP(30, time.Minute))
|
|
||||||
|
|
||||||
r.Get("/example", handlers.ExampleGet)
|
|
||||||
r.Post("/example", handlers.ExamplePost)
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Route("/bug-reports", func(r chi.Router) {
|
r.Route("/bug-reports", func(r chi.Router) {
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(apimw.APIKeyAuth(db))
|
r.Use(apimw.APIKeyAuth(db))
|
||||||
@@ -76,9 +71,19 @@ func main() {
|
|||||||
// Tighter rate-limit on protected group: 30 req / min per IP
|
// Tighter rate-limit on protected group: 30 req / min per IP
|
||||||
r.Use(httprate.LimitByIP(30, time.Minute))
|
r.Use(httprate.LimitByIP(30, time.Minute))
|
||||||
|
|
||||||
|
r.Get("/count", handlers.GetReportsCount(db))
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Group(func(r chi.Router) {
|
||||||
|
// More strict auth due to sensitive info
|
||||||
|
r.Use(apimw.APIKeyAuth(db))
|
||||||
|
r.Use(apimw.AdminKeyAuth(db))
|
||||||
|
|
||||||
|
// Tighter rate-limit on protected group: 30 req / min per IP
|
||||||
|
r.Use(httprate.LimitByIP(30, time.Minute))
|
||||||
|
|
||||||
r.Get("/", handlers.GetAllBugReports(db))
|
r.Get("/", handlers.GetAllBugReports(db))
|
||||||
r.Get("/{id}", handlers.GetBugReportByID(db))
|
r.Get("/{id}", handlers.GetBugReportByID(db))
|
||||||
r.Get("/count", handlers.GetReportsCount(db))
|
|
||||||
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db))
|
r.Get("/{id}/files", handlers.GetReportFilesByReportID(db))
|
||||||
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db))
|
r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db))
|
||||||
r.Get("/{id}/zip", handlers.GetBugReportZipById(db))
|
r.Get("/{id}/zip", handlers.GetBugReportZipById(db))
|
||||||
|
|||||||
Reference in New Issue
Block a user