diff --git a/.env.example b/.env.example index ee70683..3112d43 100644 --- a/.env.example +++ b/.env.example @@ -33,3 +33,12 @@ RL_AUTH_MAX_REQS=100 RL_AUTH_WINDOW=1m RL_AUTH_MAX_FAILS=20 RL_AUTH_BAN_DUR=5m + +# Cloudflare R2 Storage +USE_S3_COMPATIBLE_STORAGE=false +CF_ACCOUNT_ID=your-cloudflare-account-id +CF_R2_ACCESS_KEY_ID=your-r2-access-key-id +CF_R2_SECRET_ACCESS_KEY=your-r2-secret-access-key +CF_R2_BUCKET_NAME=your-bucket-name +CF_R2_REGION=auto +CF_R2_ENDPOINT=https://your-endpoint.r2.cloudflarestorage.com \ No newline at end of file diff --git a/go.mod b/go.mod index 909c0d0..f14bcf6 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,11 @@ module emly-api-go go 1.26 require ( + github.com/aws/aws-sdk-go-v2 v1.41.7 + github.com/aws/aws-sdk-go-v2/config v1.32.18 + github.com/aws/aws-sdk-go-v2/credentials v1.19.17 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.22.19 + github.com/aws/aws-sdk-go-v2/service/s3 v1.101.0 github.com/go-chi/chi/v5 v5.2.4 github.com/go-chi/httprate v0.14.1 github.com/go-sql-driver/mysql v1.8.1 @@ -10,7 +15,23 @@ require ( golang.org/x/crypto v0.49.0 ) -require golang.org/x/sys v0.42.0 // indirect +require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.15 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.23 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect + github.com/aws/smithy-go v1.25.1 // indirect + golang.org/x/sys v0.42.0 // indirect +) require ( filippo.io/edwards25519 v1.1.1 // indirect diff --git a/go.sum b/go.sum index 0822857..f9316ac 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,43 @@ filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8= +github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY= +github.com/aws/aws-sdk-go-v2/config v1.32.18 h1:Hcia46bxhGgF3BaSnG8nSNCWmqTK6bj9xN9/FJ3WK6Q= +github.com/aws/aws-sdk-go-v2/config v1.32.18/go.mod h1:zEjCAYmxqDadH1WX8CdBvmLKhUEUVFgKRQG38zjDmrY= +github.com/aws/aws-sdk-go-v2/credentials v1.19.17 h1:gP2nkGsS+KMvF/jfFz2Vv2qiiOqWKyPACSzPsqHgoW8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.17/go.mod h1:Bsew3S/moG5iT77giPj1q8wb/s0RE5/QfH+ASjYtuQc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.22.19 h1:VH0xfFwHfPYhu+EcxyCcw3VTZskpbA+/s0pTXwhSsL8= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.22.19/go.mod h1:S/XkAXcnCpzwsjC9EU0BakuvreXfSTUADHb7rC7jvaQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.15 h1:ieLCO1JxUWuxTZ1cRd0GAaeX7O6cIxnwk7tc1LsQhC4= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.15/go.mod h1:e3IzZvQ3kAWNykvE0Tr0RDZCMFInMvhku3qNpcIQXhM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.23 h1:03xatSQO4+AM1lTAbnRg5OK528EUg744nW7F73U8DKw= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.23/go.mod h1:M8l3mwgx5ToK7wot2sBBce/ojzgnPzZXUV445gTSyE8= +github.com/aws/aws-sdk-go-v2/service/s3 v1.101.0 h1:etqBTKY581iwLL/H/S2sVgk3C9lAsTJFeXWFDsDcWOU= +github.com/aws/aws-sdk-go-v2/service/s3 v1.101.0/go.mod h1:L2dcoOgS2VSgbPLvpak2NyUPsO1TBN7M45Z4H7DlRc4= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0 h1:nDARhv/oF55bcxF7rCI/4PDxOKnVXVWwDuDwCs2I2SQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.0/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk= +github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio= +github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI= +github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/go-chi/chi/v5 v5.2.4 h1:WtFKPHwlywe8Srng8j2BhOD9312j9cGUxG1SP4V2cR4= diff --git a/internal/config/config.go b/internal/config/config.go index 7222c00..7027f7d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "os" + "regexp" "strconv" "strings" "sync" @@ -19,16 +20,28 @@ type RateLimitConfig struct { AuthBanDur time.Duration } +type R2Config struct { + AccountID string + AccessKeyID string + SecretAccessKey string + BucketName string + Region string + Endpoint string +} + type Config struct { - Port string - DSN string - Database string - APIKey string - AdminKey string - MaxOpenConns int - MaxIdleConns int - ConnMaxLifetime int - RateLimit RateLimitConfig + Port string + DSN string + Database string + APIKey string + AdminKey string + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime int + UpdatesEnabled bool + UseS3CompatibleStorage bool + RateLimit RateLimitConfig + R2 R2Config } var ( @@ -84,20 +97,39 @@ func load() *Config { if dbName == "" { panic("DATABASE_NAME environment variable is required") } + dbNameRegex := regexp.MustCompile("^[a-zA-Z0-9_]+$") + // Test the regex against the dbName, otherwise panic to prevent potential SQL injection + validDbName, err := regexp.Match(dbNameRegex.String(), []byte(dbName)) + if err != nil { + panic("failed to validate database name: " + err.Error()) + } + if !validDbName { + panic("invalid database name: must match regex " + dbNameRegex.String()) + } if os.Getenv("DB_DSN") == "" { panic("DB_DSN environment variable is required") } return &Config{ - Port: port, - DSN: os.Getenv("DB_DSN"), - Database: dbName, - APIKey: apiKey, - AdminKey: adminKey, - MaxOpenConns: maxOpenConns, - MaxIdleConns: maxIdleConns, - ConnMaxLifetime: connMaxLifetime, + Port: port, + DSN: os.Getenv("DB_DSN"), + Database: dbName, + APIKey: apiKey, + AdminKey: adminKey, + MaxOpenConns: maxOpenConns, + MaxIdleConns: maxIdleConns, + ConnMaxLifetime: connMaxLifetime, + UpdatesEnabled: strings.ToLower(strings.TrimSpace(os.Getenv("UPDATES_ENABLED"))) == "true", + UseS3CompatibleStorage: strings.ToLower(strings.TrimSpace(os.Getenv("USE_S3_COMPATIBLE_STORAGE"))) == "true", + R2: R2Config{ + AccountID: os.Getenv("CF_ACCOUNT_ID"), + AccessKeyID: os.Getenv("CF_R2_ACCESS_KEY_ID"), + SecretAccessKey: os.Getenv("CF_R2_SECRET_ACCESS_KEY"), + BucketName: os.Getenv("CF_R2_BUCKET_NAME"), + Region: envString("CF_R2_REGION", "auto"), + Endpoint: os.Getenv("CF_R2_ENDPOINT"), + }, RateLimit: RateLimitConfig{ UnauthMaxReqs: envInt("RL_UNAUTH_MAX_REQS", 10), UnauthWindow: envDuration("RL_UNAUTH_WINDOW", 5*time.Minute), @@ -111,6 +143,13 @@ func load() *Config { } } +func envString(key, fallback string) string { + if s := os.Getenv(key); s != "" { + return s + } + return fallback +} + func envInt(key string, fallback int) int { if s := os.Getenv(key); s != "" { if n, err := strconv.Atoi(s); err == nil { diff --git a/internal/database/schema/migrations/tasks.json b/internal/database/schema/migrations/tasks.json index 8be7841..820a019 100644 --- a/internal/database/schema/migrations/tasks.json +++ b/internal/database/schema/migrations/tasks.json @@ -19,5 +19,13 @@ { "type": "column_not_exists", "table": "user", "column": "enabled" } ] } + ,{ + "id": "3_updates", + "sql_file": "3_updates.sql", + "description": "Create update_releases table for API-managed software updates.", + "conditions": [ + { "type": "table_not_exists", "table": "update_releases" } + ] + } ] } \ No newline at end of file diff --git a/internal/handlers/admin_auth.route.go b/internal/handlers/admin_auth.route.go index f05d147..fd6a187 100644 --- a/internal/handlers/admin_auth.route.go +++ b/internal/handlers/admin_auth.route.go @@ -198,8 +198,8 @@ func ValidateSession(db *sqlx.DB) http.HandlerFunc { sessionID, ) if err != nil { - jsonError(w, http.StatusUnauthorized, "invalid session") - log.Fatalf("Database error during session validation: %v", err) + log.Printf("[AUTH] Database error during session validation: %v", err) + jsonError(w, http.StatusInternalServerError, "internal server error") return } diff --git a/internal/handlers/bug_report.route.go b/internal/handlers/bug_report.route.go index 33a29b6..f0ac024 100644 --- a/internal/handlers/bug_report.route.go +++ b/internal/handlers/bug_report.route.go @@ -3,6 +3,7 @@ package handlers import ( "archive/zip" "bytes" + "context" "database/sql" "embed" "encoding/json" @@ -21,6 +22,8 @@ import ( "github.com/jmoiron/sqlx" "emly-api-go/internal/models" + "emly-api-go/internal/storage" + "emly-api-go/internal/timing" ) //go:embed templates/report.txt.tmpl @@ -41,12 +44,13 @@ var fileRoles = []struct { {"config", models.FileRoleConfig, "application/json"}, } -func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc { +func CreateBugReport(db *sqlx.DB, dbName string, s3conn *storage.S3Connector) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(32 << 20); err != nil { jsonError(w, http.StatusBadRequest, "invalid multipart form: "+err.Error()) return } + timing.Mark(r.Context(), "parse_form") name := r.FormValue("name") email := r.FormValue("email") @@ -84,6 +88,7 @@ func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_insert_report") reportID, err := result.LastInsertId() if err != nil { @@ -120,7 +125,7 @@ func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc { log.Printf("[BUGREPORT] File uploaded: role=%s size=%d bytes", fr.role, len(data)) - _, err = db.ExecContext(r.Context(), + fileResult, err := db.ExecContext(r.Context(), fmt.Sprintf("INSERT INTO %s.bug_report_files (report_id, file_role, filename, mime_type, file_size, data) VALUES (?, ?, ?, ?, ?, ?)", dbName), reportID, fr.role, filename, mimeType, len(data), data, ) @@ -128,6 +133,25 @@ func CreateBugReport(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_insert_file_"+string(fr.role)) + + if s3conn != nil { + fileID, err := fileResult.LastInsertId() + if err != nil { + log.Printf("[S3] could not get file insert id for report %d role %s: %v", reportID, fr.role, err) + } else { + s3Key := fmt.Sprintf("emly-api-files/bug-reports/%d/files/%s", reportID, filename) + if _, err := s3conn.UploadFile( + context.Background(), s3Key, + bytes.NewReader(data), mimeType, + map[string]string{"filename": filename, "id": strconv.FormatInt(fileID, 10)}, + ); err != nil { + log.Printf("[S3] upload failed for key %s: %v", s3Key, err) + } else { + timing.Mark(r.Context(), "s3_upload_file_"+string(fr.role)) + } + } + } } log.Printf("[BUGREPORT] Created successfully with id=%d", reportID) @@ -177,6 +201,7 @@ func GetAllBugReports(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_count") mainQuery := fmt.Sprintf(` SELECT br.*, COUNT(bf.id) as file_count @@ -193,6 +218,7 @@ func GetAllBugReports(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_select") jsonOK(w, map[string]interface{}{ "data": reports, @@ -280,7 +306,7 @@ func GetReportFilesByReportID(db *sqlx.DB, dbName string) http.HandlerFunc { } } -func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc { +func GetBugReportZipByID(db *sqlx.DB, dbName string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := chi.URLParam(r, "id") if id == "" { @@ -298,12 +324,14 @@ func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_fetch_report") var files []models.BugReportFile if err := db.SelectContext(r.Context(), &files, fmt.Sprintf("SELECT * FROM %s.bug_report_files WHERE report_id = ?", dbName), id); err != nil { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_fetch_files") var sysInfoStr string if len(report.SystemInfo) > 0 && string(report.SystemInfo) != "null" { @@ -353,6 +381,7 @@ func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "zip_build") w.Header().Set("Content-Type", "application/zip") w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"report-%d.zip\"", report.ID)) @@ -364,7 +393,7 @@ func GetBugReportZipById(db *sqlx.DB, dbName string) http.HandlerFunc { } } -func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc { +func GetReportFileByFileID(db *sqlx.DB, dbName string, s3conn *storage.S3Connector) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { reportId := chi.URLParam(r, "id") if reportId == "" { @@ -377,6 +406,44 @@ func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc { return } + var filename string + if err := db.GetContext(r.Context(), &filename, fmt.Sprintf("SELECT filename FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName), reportId, fileId); err != nil { + jsonError(w, http.StatusInternalServerError, err.Error()) + return + } + timing.Mark(r.Context(), "db_fetch_filename_by_id") + + // Try S3 first. + if s3conn != nil { + s3Key := fmt.Sprintf("emly-api-files/bug-reports/%s/files/%s", reportId, filename) + rc, info, err := s3conn.GetFile(r.Context(), s3Key) + if err == nil { + defer rc.Close() + timing.Mark(r.Context(), "s3_hit") + log.Println("[S3] cache hit for key", s3Key) + + mimeType := info.ContentType + if mimeType == "" { + mimeType = "application/octet-stream" + } + filename := info.Metadata["filename"] + if filename == "" { + filename = fileId + } + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", "attachment; filename=\""+filename+"\"") + _, _ = io.Copy(w, rc) + return + } + if storage.IsNotFound(err) { + log.Printf("[S3] file %s not found on s3", fileId) + } + if !storage.IsNotFound(err) { + log.Printf("[S3] unexpected error fetching key %s: %v", s3Key, err) + } + } + + // Fallback: query DB. var file models.BugReportFile err := db.GetContext(r.Context(), &file, fmt.Sprintf("SELECT filename, mime_type, data FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName), reportId, fileId) if errors.Is(err, sql.ErrNoRows) { @@ -387,6 +454,25 @@ func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc { jsonError(w, http.StatusInternalServerError, err.Error()) return } + timing.Mark(r.Context(), "db_select") + + // Lazy-upload to S3 so future requests are served from there. + if s3conn != nil { + s3Key := fmt.Sprintf("emly-api-files/bug-reports/%s/files/%s", reportId, fileId) + dataCopy := make([]byte, len(file.Data)) + copy(dataCopy, file.Data) + mime := file.MimeType + fname := file.Filename + go func() { + if _, err := s3conn.UploadFile( + context.Background(), s3Key, + bytes.NewReader(dataCopy), mime, + map[string]string{"filename": fname}, + ); err != nil { + log.Printf("[S3] lazy upload failed for key %s: %v", s3Key, err) + } + }() + } mimeType := file.MimeType if mimeType == "" { @@ -394,10 +480,7 @@ func GetReportFileByFileID(db *sqlx.DB, dbName string) http.HandlerFunc { } w.Header().Set("Content-Type", mimeType) w.Header().Set("Content-Disposition", "attachment; filename=\""+file.Filename+"\"") - _, err = w.Write(file.Data) - if err != nil { - return - } + _, _ = w.Write(file.Data) } } diff --git a/internal/middleware/timing.go b/internal/middleware/timing.go new file mode 100644 index 0000000..34e5954 --- /dev/null +++ b/internal/middleware/timing.go @@ -0,0 +1,62 @@ +package middleware + +import ( + "fmt" + "log" + "net/http" + "strings" + "time" + + "emly-api-go/internal/timing" +) + +// Timing is a middleware that measures per-request step durations. +// +// It injects a *timing.Timer into the request context so that handlers can +// record named checkpoints with timing.Mark(r.Context(), "step_name"). +// After the handler returns, it logs a single line of the form: +// +// [TIMING] METHOD /path step1=1.2ms step2=18ms total=20ms +// +// Each step duration is measured from the previous checkpoint (or request +// start for the first one), so the values add up to the total. +func Timing(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, t := timing.NewContext(r.Context()) + next.ServeHTTP(w, r.WithContext(ctx)) + + total := time.Since(t.Start) + checkpoints := t.Checkpoints() + + if len(checkpoints) == 0 { + // No checkpoints: just log the total so every request is visible. + log.Printf("[TIMING] %s %s total=%s", r.Method, r.URL.Path, round(total)) + return + } + + parts := make([]string, 0, len(checkpoints)+1) + prev := t.Start + for _, cp := range checkpoints { + parts = append(parts, fmt.Sprintf("%s=%s", cp.Name, round(cp.At.Sub(prev)))) + prev = cp.At + } + // Remainder after the last checkpoint. + if tail := total - prev.Sub(t.Start); tail > 0 { + parts = append(parts, fmt.Sprintf("response=%s", round(tail))) + } + parts = append(parts, fmt.Sprintf("total=%s", round(total))) + + log.Printf("[TIMING] %s %s %s", r.Method, r.URL.Path, strings.Join(parts, " ")) + }) +} + +func round(d time.Duration) string { + switch { + case d < time.Microsecond: + return fmt.Sprintf("%dns", d.Nanoseconds()) + case d < time.Millisecond: + return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1e3) + default: + return fmt.Sprintf("%.2fms", float64(d.Nanoseconds())/1e6) + } +} diff --git a/internal/routes/routes.go b/internal/routes/routes.go index a8aff3e..6d9b18b 100644 --- a/internal/routes/routes.go +++ b/internal/routes/routes.go @@ -1,20 +1,18 @@ package routes import ( - v2 "emly-api-go/internal/routes/v2" "net/http" v1 "emly-api-go/internal/routes/v1" + v2 "emly-api-go/internal/routes/v2" + "emly-api-go/internal/storage" "github.com/go-chi/chi/v5" "github.com/jmoiron/sqlx" ) // RegisterAll mounts every versioned API onto the root router. -// To add a new API version, create internal/routes/v2 and add: -// -// r.Mount("/v2", v2.NewRouter(db)) -func RegisterAll(r chi.Router, db *sqlx.DB) { +func RegisterAll(r chi.Router, db *sqlx.DB, s3conn *storage.S3Connector) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("emly-api-go")) if err != nil { @@ -22,6 +20,6 @@ func RegisterAll(r chi.Router, db *sqlx.DB) { } }) - r.Mount("/v1", v1.NewRouter(db)) - r.Mount("/v2", v2.NewRouter(db)) + r.Mount("/v1", v1.NewRouter(db, s3conn)) + r.Mount("/v2", v2.NewRouter(db, s3conn)) } diff --git a/internal/routes/v1/bug_reports.go b/internal/routes/v1/bug_reports.go index 3776421..cc163a7 100644 --- a/internal/routes/v1/bug_reports.go +++ b/internal/routes/v1/bug_reports.go @@ -5,13 +5,14 @@ import ( "time" "emly-api-go/internal/handlers" + "emly-api-go/internal/storage" "github.com/go-chi/chi/v5" "github.com/go-chi/httprate" "github.com/jmoiron/sqlx" ) -func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) { +func registerBugReports(r chi.Router, db *sqlx.DB, dbName string, s3conn *storage.S3Connector) { r.Route("/bug-reports", func(r chi.Router) { // API key only: submit a report and check count r.Group(func(r chi.Router) { @@ -19,7 +20,7 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) { r.Use(httprate.LimitByIP(30, time.Minute)) r.Get("/count", handlers.GetReportsCount(db, dbName)) - r.Post("/", handlers.CreateBugReport(db, dbName)) + r.Post("/", handlers.CreateBugReport(db, dbName, s3conn)) }) // API key + admin key: full read/write access @@ -32,10 +33,10 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) { r.Get("/{id}", handlers.GetBugReportByID(db, dbName)) r.Get("/{id}/status", handlers.GetReportStatusByID(db, dbName)) r.Get("/{id}/files", handlers.GetReportFilesByReportID(db, dbName)) - r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName)) - r.Get("/{id}/download", handlers.GetBugReportZipById(db, dbName)) + r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName, s3conn)) + r.Get("/{id}/download", handlers.GetBugReportZipByID(db, dbName)) r.Patch("/{id}/status", handlers.PatchBugReportStatus(db, dbName)) r.Delete("/{id}", handlers.DeleteBugReportByID(db, dbName)) }) }) -} \ No newline at end of file +} diff --git a/internal/routes/v1/v1.go b/internal/routes/v1/v1.go index 1c72e18..13eaf1b 100644 --- a/internal/routes/v1/v1.go +++ b/internal/routes/v1/v1.go @@ -6,15 +6,14 @@ import ( "emly-api-go/internal/config" "emly-api-go/internal/handlers" + "emly-api-go/internal/storage" "github.com/go-chi/chi/v5" "github.com/jmoiron/sqlx" ) // NewRouter returns a chi.Router with all /v1 routes mounted. -// Add new API versions by creating an analogous package (e.g. v2) and -// mounting it alongside this one in internal/routes/routes.go. -func NewRouter(db *sqlx.DB) http.Handler { +func NewRouter(db *sqlx.DB, s3conn *storage.S3Connector) http.Handler { r := chi.NewRouter() rl := emlyMiddleware.NewRateLimiter(config.Load()) @@ -33,7 +32,7 @@ func NewRouter(db *sqlx.DB) http.Handler { r.Route("/api", func(r chi.Router) { registerAdmin(r, db) - registerBugReports(r, db, config.Load().Database) + registerBugReports(r, db, config.Load().Database, s3conn) }) return r diff --git a/internal/routes/v2/bug_reports.go b/internal/routes/v2/bug_reports.go index b16988d..58e8efb 100644 --- a/internal/routes/v2/bug_reports.go +++ b/internal/routes/v2/bug_reports.go @@ -5,13 +5,14 @@ import ( "time" "emly-api-go/internal/handlers" + "emly-api-go/internal/storage" "github.com/go-chi/chi/v5" "github.com/go-chi/httprate" "github.com/jmoiron/sqlx" ) -func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) { +func registerBugReports(r chi.Router, db *sqlx.DB, dbName string, s3conn *storage.S3Connector) { r.Route("/bug-report", func(r chi.Router) { // API key only: submit a report and check count r.Group(func(r chi.Router) { @@ -19,7 +20,7 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) { r.Use(httprate.LimitByIP(30, time.Minute)) r.Get("/count", handlers.GetReportsCount(db, dbName)) - r.Post("/", handlers.CreateBugReport(db, dbName)) + r.Post("/", handlers.CreateBugReport(db, dbName, s3conn)) }) // API key + admin key: full read/write access @@ -32,8 +33,8 @@ func registerBugReports(r chi.Router, db *sqlx.DB, dbName string) { r.Get("/{id}", handlers.GetBugReportByID(db, dbName)) r.Get("/{id}/status", handlers.GetReportStatusByID(db, dbName)) r.Get("/{id}/files", handlers.GetReportFilesByReportID(db, dbName)) - r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName)) - r.Get("/{id}/download", handlers.GetBugReportZipById(db, dbName)) + r.Get("/{id}/files/{file_id}", handlers.GetReportFileByFileID(db, dbName, s3conn)) + r.Get("/{id}/download", handlers.GetBugReportZipByID(db, dbName)) r.Patch("/{id}/status", handlers.PatchBugReportStatus(db, dbName)) r.Delete("/{id}", handlers.DeleteBugReportByID(db, dbName)) }) diff --git a/internal/routes/v2/v2.go b/internal/routes/v2/v2.go index dcf4579..4401a15 100644 --- a/internal/routes/v2/v2.go +++ b/internal/routes/v2/v2.go @@ -6,15 +6,14 @@ import ( "emly-api-go/internal/config" "emly-api-go/internal/handlers" + "emly-api-go/internal/storage" "github.com/go-chi/chi/v5" "github.com/jmoiron/sqlx" ) -// NewRouter returns a chi.Router with all /v1 routes mounted. -// Add new API versions by creating an analogous package (e.g. v2) and -// mounting it alongside this one in internal/routes/routes.go. -func NewRouter(db *sqlx.DB) http.Handler { +// NewRouter returns a chi.Router with all /v2 routes mounted. +func NewRouter(db *sqlx.DB, s3conn *storage.S3Connector) http.Handler { r := chi.NewRouter() rl := emlyMiddleware.NewRateLimiter(config.Load()) @@ -33,7 +32,8 @@ func NewRouter(db *sqlx.DB) http.Handler { r.Route("/api", func(r chi.Router) { registerAdmin(r, db) - registerBugReports(r, db, config.Load().Database) + registerBugReports(r, db, config.Load().Database, s3conn) + registerUpdates(r, db, config.Load()) }) return r diff --git a/internal/storage/migrateFiles.go b/internal/storage/migrateFiles.go new file mode 100644 index 0000000..a2c6b3d --- /dev/null +++ b/internal/storage/migrateFiles.go @@ -0,0 +1,122 @@ +package storage + +import ( + "bytes" + "context" + "database/sql" + "emly-api-go/internal/models" + "errors" + "fmt" + "log" + "sync" + "time" + + "github.com/jmoiron/sqlx" +) + +func MigrateReportFilesToS3(db *sqlx.DB, s3conn *S3Connector, dbName string) error { + var wg sync.WaitGroup + errCh := make(chan error, 128) // buffer ragionevole + reportsRows, err := db.Query("SELECT id, created_at, updated_at FROM emly_bugreports_dev.bug_reports ORDER BY created_at DESC") + if err != nil { + return err + } + defer reportsRows.Close() + + var totalReports, totalFiles, skipped, uploaded int + + for reportsRows.Next() { + var reportId int + var createdAt, updatedAt time.Time + + if err := reportsRows.Scan( + &reportId, &createdAt, &updatedAt, + ); err != nil { + return err + } + totalReports++ + log.Printf("[migrate] processing report %d", reportId) + + filesRows, err := db.Query( + "SELECT id, report_id, filename FROM emly_bugreports_dev.bug_report_files WHERE report_id = ?", + reportId, + ) + + if err != nil { + return err + } + + for filesRows.Next() { + var fileID int + var fileReportID int + var fileName string + if err := filesRows.Scan(&fileID, &fileReportID, &fileName); err != nil { + filesRows.Close() + return err + } + + var file models.BugReportFile + err := db.GetContext(context.Background(), &file, fmt.Sprintf("SELECT filename, mime_type, data FROM %s.bug_report_files WHERE report_id = ? AND id = ?", dbName), reportId, fileID) + if errors.Is(err, sql.ErrNoRows) { + log.Printf("[migrate] report %d / file %d: not found in bug_report_files, skipping", reportId, fileID) + skipped++ + continue + } + if err != nil { + filesRows.Close() + return fmt.Errorf("report %d / file %d: %w", reportId, fileID, err) + } + + if s3conn != nil { + s3Key := fmt.Sprintf("emly-api-files/bug-reports/%d/files/%s", reportId, fileName) + dataCopy := make([]byte, len(file.Data)) + copy(dataCopy, file.Data) + mime := file.MimeType + fname := file.Filename + totalFiles++ + log.Printf("[migrate] report %d / file %d (%s, %d bytes): uploading to s3://%s", reportId, fileID, fname, len(dataCopy), s3Key) + wg.Add(1) + go func(key, mimeType, filename string, payload []byte, rid, fid int) { + defer wg.Done() + + _, upErr := s3conn.UploadFile( + context.Background(), + key, + bytes.NewReader(payload), + mimeType, + map[string]string{"filename": filename}, + ) + if upErr != nil { + errCh <- fmt.Errorf("report %d / file %d (%s): %w", rid, fid, key, upErr) + log.Printf("[migrate] [ERROR] upload failed for s3://%s: %v", key, upErr) + return + } + log.Printf("[migrate] upload complete: s3://%s", key) + }(s3Key, mime, fname, dataCopy, reportId, fileID) + + uploaded++ + } + } + + if err := filesRows.Close(); err != nil { + return err + } + } + + wg.Wait() + close(errCh) + + var uploadErrCount int + for e := range errCh { + uploadErrCount++ + log.Printf("[migrate] [ERROR] %v", e) + } + + log.Printf("[migrate] done — reports: %d, files queued: %d, skipped: %d, upload errors: %d", + totalReports, uploaded, skipped, uploadErrCount) + + if uploadErrCount > 0 { + return fmt.Errorf("migration completed with %d upload errors", uploadErrCount) + } + return nil +} diff --git a/internal/storage/s3connector.go b/internal/storage/s3connector.go new file mode 100644 index 0000000..8883c2f --- /dev/null +++ b/internal/storage/s3connector.go @@ -0,0 +1,322 @@ +package storage + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "time" + + "emly-api-go/internal/config" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +type S3Connector struct { + client *s3.Client + uploader *manager.Uploader + downloader *manager.Downloader + bucket string +} + +type FileInfo struct { + Key string + Size int64 + LastModified time.Time + ETag string + ContentType string + Metadata map[string]string +} + +// IsNotFound reports whether err represents a missing object (404 / NoSuchKey). +func IsNotFound(err error) bool { + if err == nil { + return false + } + var nsk *types.NoSuchKey + if errors.As(err, &nsk) { + return true + } + // Fallback for S3-compatible stores (e.g. Cloudflare R2) that surface + // the error code via the generic APIError interface. + var ae interface{ ErrorCode() string } + if errors.As(err, &ae) { + switch ae.ErrorCode() { + case "NoSuchKey", "NotFound", "404": + return true + } + } + return false +} + +type FolderInfo struct { + Prefix string +} + +func NewCloudflareR2Connector(cfg config.R2Config) (*S3Connector, error) { + if cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" || cfg.BucketName == "" { + return nil, fmt.Errorf("missing required R2 config fields (CF_R2_ACCESS_KEY_ID, CF_R2_SECRET_ACCESS_KEY, CF_R2_BUCKET_NAME)") + } + + endpoint := cfg.Endpoint + if endpoint == "" { + if cfg.AccountID == "" { + return nil, fmt.Errorf("either CF_R2_ENDPOINT or CF_ACCOUNT_ID must be set") + } + endpoint = fmt.Sprintf("https://%s.r2.cloudflarestorage.com", cfg.AccountID) + } + + region := cfg.Region + if region == "" { + region = "auto" + } + + awsCfg, err := awsconfig.LoadDefaultConfig(context.TODO(), + awsconfig.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")), + awsconfig.WithRegion(region), + ) + if err != nil { + return nil, fmt.Errorf("failed to load R2 config: %w", err) + } + + client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + }) + + return &S3Connector{ + client: client, + uploader: manager.NewUploader(client), + downloader: manager.NewDownloader(client), + bucket: cfg.BucketName, + }, nil +} + +// Ping verifies connectivity by calling HeadBucket on the configured bucket. +func (c *S3Connector) Ping(ctx context.Context) error { + _, err := c.client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: aws.String(c.bucket), + }) + if err != nil { + return fmt.Errorf("R2 ping failed for bucket %q: %w", c.bucket, err) + } + return nil +} + +// UploadFile uploads body to key in the bucket and returns the public URL. +// metadata is optional; pass nil if not needed. +func (c *S3Connector) UploadFile(ctx context.Context, key string, body io.Reader, contentType string, metadata map[string]string) (string, error) { + result, err := c.uploader.Upload(ctx, &s3.PutObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + Body: body, + ContentType: aws.String(contentType), + Metadata: metadata, + }) + if err != nil { + return "", fmt.Errorf("upload %q: %w", key, err) + } + return result.Location, nil +} + +// GetFile returns the object body at key. Caller must close it. +func (c *S3Connector) GetFile(ctx context.Context, key string) (io.ReadCloser, *FileInfo, error) { + out, err := c.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, nil, fmt.Errorf("get %q: %w", key, err) + } + + info := &FileInfo{ + Key: key, + Size: aws.ToInt64(out.ContentLength), + ETag: strings.Trim(aws.ToString(out.ETag), `"`), + ContentType: aws.ToString(out.ContentType), + Metadata: out.Metadata, + } + if out.LastModified != nil { + info.LastModified = *out.LastModified + } + + return out.Body, info, nil +} + +// DownloadFile downloads key into dst and returns bytes written. +func (c *S3Connector) DownloadFile(ctx context.Context, key string, dst io.WriterAt) (int64, error) { + n, err := c.downloader.Download(ctx, dst, &s3.GetObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + }) + if err != nil { + return 0, fmt.Errorf("download %q: %w", key, err) + } + return n, nil +} + +// DeleteFile deletes the object at key. +func (c *S3Connector) DeleteFile(ctx context.Context, key string) error { + _, err := c.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + }) + if err != nil { + return fmt.Errorf("delete %q: %w", key, err) + } + return nil +} + +// DeleteFiles deletes up to 1000 objects in one request. +func (c *S3Connector) DeleteFiles(ctx context.Context, keys []string) error { + if len(keys) == 0 { + return nil + } + objects := make([]types.ObjectIdentifier, len(keys)) + for i, k := range keys { + objects[i] = types.ObjectIdentifier{Key: aws.String(k)} + } + _, err := c.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{ + Bucket: aws.String(c.bucket), + Delete: &types.Delete{Objects: objects, Quiet: aws.Bool(true)}, + }) + if err != nil { + return fmt.Errorf("batch delete: %w", err) + } + return nil +} + +// RenameFile copies src to dst then deletes src (R2 has no native rename). +func (c *S3Connector) RenameFile(ctx context.Context, srcKey, dstKey string) error { + _, err := c.client.CopyObject(ctx, &s3.CopyObjectInput{ + Bucket: aws.String(c.bucket), + CopySource: aws.String(c.bucket + "/" + srcKey), + Key: aws.String(dstKey), + }) + if err != nil { + return fmt.Errorf("copy %q → %q: %w", srcKey, dstKey, err) + } + return c.DeleteFile(ctx, srcKey) +} + +// ListFiles returns all objects directly under prefix (non-recursive). +func (c *S3Connector) ListFiles(ctx context.Context, prefix string) ([]FileInfo, error) { + prefix = normalizePrefix(prefix) + + var files []FileInfo + pager := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{ + Bucket: aws.String(c.bucket), + Prefix: aws.String(prefix), + Delimiter: aws.String("/"), + }) + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("list files under %q: %w", prefix, err) + } + for _, obj := range page.Contents { + key := aws.ToString(obj.Key) + if strings.HasSuffix(key, "/") { + continue // skip folder placeholders + } + fi := FileInfo{ + Key: key, + Size: aws.ToInt64(obj.Size), + ETag: strings.Trim(aws.ToString(obj.ETag), `"`), + } + if obj.LastModified != nil { + fi.LastModified = *obj.LastModified + } + files = append(files, fi) + } + } + return files, nil +} + +// ListFolders returns the immediate sub-folders under prefix. +func (c *S3Connector) ListFolders(ctx context.Context, prefix string) ([]FolderInfo, error) { + prefix = normalizePrefix(prefix) + + var folders []FolderInfo + pager := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{ + Bucket: aws.String(c.bucket), + Prefix: aws.String(prefix), + Delimiter: aws.String("/"), + }) + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("list folders under %q: %w", prefix, err) + } + for _, cp := range page.CommonPrefixes { + folders = append(folders, FolderInfo{Prefix: aws.ToString(cp.Prefix)}) + } + } + return folders, nil +} + +// CreateFolder writes a zero-byte placeholder object to make the folder visible. +func (c *S3Connector) CreateFolder(ctx context.Context, folderPath string) error { + key := normalizePrefix(folderPath) + if key == "" { + return fmt.Errorf("folder path cannot be empty") + } + _, err := c.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + ContentLength: aws.Int64(0), + }) + if err != nil { + return fmt.Errorf("create folder %q: %w", key, err) + } + return nil +} + +// DeleteFolder removes all objects under folderPath in batches of 1000. +func (c *S3Connector) DeleteFolder(ctx context.Context, folderPath string) error { + prefix := normalizePrefix(folderPath) + + var keys []string + pager := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{ + Bucket: aws.String(c.bucket), + Prefix: aws.String(prefix), + }) + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + return fmt.Errorf("list for delete %q: %w", prefix, err) + } + for _, obj := range page.Contents { + keys = append(keys, aws.ToString(obj.Key)) + } + } + + for i := 0; i < len(keys); i += 1000 { + end := i + 1000 + if end > len(keys) { + end = len(keys) + } + if err := c.DeleteFiles(ctx, keys[i:end]); err != nil { + return err + } + } + return nil +} + +// normalizePrefix ensures prefix ends with "/" (returns "" for root). +func normalizePrefix(p string) string { + p = strings.TrimPrefix(p, "/") + if p == "" { + return "" + } + if !strings.HasSuffix(p, "/") { + p += "/" + } + return p +} diff --git a/internal/timing/timing.go b/internal/timing/timing.go new file mode 100644 index 0000000..d37e9e6 --- /dev/null +++ b/internal/timing/timing.go @@ -0,0 +1,50 @@ +package timing + +import ( + "context" + "time" +) + +type contextKey struct{} + +// Checkpoint is a named point in time recorded during request processing. +type Checkpoint struct { + Name string + At time.Time +} + +// Timer records the request start time and named checkpoints. +type Timer struct { + Start time.Time + checkpoints []Checkpoint +} + +// Mark records a checkpoint with the given name. +func (t *Timer) Mark(name string) { + t.checkpoints = append(t.checkpoints, Checkpoint{Name: name, At: time.Now()}) +} + +// Checkpoints returns all recorded checkpoints in order. +func (t *Timer) Checkpoints() []Checkpoint { + return t.checkpoints +} + +// NewContext attaches a new Timer to ctx and returns both. +func NewContext(ctx context.Context) (context.Context, *Timer) { + t := &Timer{Start: time.Now()} + return context.WithValue(ctx, contextKey{}, t), t +} + +// FromContext retrieves the Timer from ctx, or nil if not present. +func FromContext(ctx context.Context) *Timer { + t, _ := ctx.Value(contextKey{}).(*Timer) + return t +} + +// Mark records a checkpoint in the Timer stored in ctx, if any. +// It is a no-op when ctx carries no Timer (e.g. in tests). +func Mark(ctx context.Context, name string) { + if t := FromContext(ctx); t != nil { + t.Mark(name) + } +} diff --git a/main.go b/main.go index 7e54239..af8f84d 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "log" "net/http" @@ -16,6 +17,7 @@ import ( "emly-api-go/internal/database" "emly-api-go/internal/database/schema" "emly-api-go/internal/routes" + "emly-api-go/internal/storage" emlyMiddleware "emly-api-go/internal/middleware" ) @@ -46,6 +48,37 @@ func main() { log.Fatalf("schema migration failed: %v", err) } + var s3conn *storage.S3Connector + if cfg.UseS3CompatibleStorage { + conn, err := storage.NewCloudflareR2Connector(cfg.R2) + if err != nil { + log.Fatalf("R2 connector init failed: %v", err) + } + if err := conn.Ping(context.Background()); err != nil { + log.Fatalf("R2 connection test failed: %v", err) + } + log.Printf("R2 storage connected (bucket: %s)", cfg.R2.BucketName) + s3conn = conn + } + + argsWithoutProg := os.Args[1:] + for _, arg := range argsWithoutProg { + log.Printf("arg: %s", arg) + if arg == "--migrate-files" { + if cfg.UseS3CompatibleStorage && s3conn != nil { + log.Printf("migrate report files from db to s3...") + if err := storage.MigrateReportFilesToS3(db, s3conn, cfg.Database); err != nil { + log.Fatalf("migrating report files failed: %v", err) + } + log.Printf("migrate report files from db to s3 completed successfully") + continue + } else { + log.Printf("migrate report files from db to s3 skipped (R2 not enabled)") + } + + } + } + r := chi.NewRouter() // Global middlewares @@ -54,12 +87,13 @@ func main() { r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Use(middleware.Timeout(30 * time.Second)) + r.Use(emlyMiddleware.Timing) rl := emlyMiddleware.NewRateLimiter(cfg) r.Use(rl.Handler) - routes.RegisterAll(r, db) + routes.RegisterAll(r, db, s3conn) addr := fmt.Sprintf(":%s", cfg.Port) log.Printf("server listening on %s", addr)