Files
api-golang/internal/database/schema/migrator.go
Flavio Fois fa1f65baf7
Some checks failed
Build & Publish Docker Image / build-and-push (push) Failing after 33s
add support for SQLite as an alternative database backend
Implement SQLite support using the pure Go `modernc.org/sqlite` driver and update the migration system to handle driver-specific schemas. Users can now choose between MySQL and SQLite by setting the `DB_DRIVER` environment variable.
2026-03-29 17:46:27 +02:00

318 lines
8.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package schema
import (
"embed"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/jmoiron/sqlx"
)
// migrationsFS embeds the per-driver schema directories "mysql" and "sqlite".
// Files read from it in this package: <driver>/init.sql,
// <driver>/migrations/tasks.json and <driver>/migrations/<task sql_file>.
//go:embed mysql sqlite
var migrationsFS embed.FS
type (
	// taskFile is the top-level shape of <driver>/migrations/tasks.json.
	taskFile struct {
		Tasks []task `json:"tasks"`
	}

	// task is one migration step. SQLFile is resolved relative to
	// <driver>/migrations/ and is executed when any of Conditions is met.
	task struct {
		ID          string      `json:"id"`
		SQLFile     string      `json:"sql_file"`
		Description string      `json:"description"`
		Conditions  []condition `json:"conditions"`
	}

	// condition is a schema predicate evaluated before a task runs.
	condition struct {
		// Type is one of "column_not_exists", "index_not_exists",
		// "column_exists", "index_exists", "table_not_exists", "table_exists".
		Type   string `json:"type"`
		Table  string `json:"table"`
		Column string `json:"column,omitempty"` // used by the column_* types
		Index  string `json:"index,omitempty"`  // used by the index_* types
	}
)
// Migrate brings the database schema up to date for the given driver.
//
// It first ensures the base schema exists: <driver>/init.sql is run when the
// schema is empty or when any expected table is missing. It then applies, in
// file order, every task from <driver>/migrations/tasks.json whose conditions
// require it. It stops at the first error.
//
// dbName is only used by the MySQL information_schema queries. driver selects
// the embedded directory ("mysql" or "sqlite") and the dialect of the
// condition checks; those checks treat anything other than "sqlite" as MySQL.
func Migrate(db *sqlx.DB, dbName string, driver string) error {
	if err := ensureBaseSchema(db, dbName, driver); err != nil {
		return err
	}
	tasks, err := loadMigrationTasks(driver)
	if err != nil {
		return err
	}
	for _, t := range tasks {
		if err := applyMigrationTask(db, dbName, driver, t); err != nil {
			return err
		}
	}
	return nil
}

// ensureBaseSchema runs init.sql when the schema is empty or incomplete.
func ensureBaseSchema(db *sqlx.DB, dbName, driver string) error {
	empty, err := schemaIsEmpty(db, dbName, driver)
	if err != nil {
		return fmt.Errorf("schema: check empty: %w", err)
	}
	if empty {
		log.Println("[migrate] empty schema detected running init.sql")
		return runInitSQL(db, driver)
	}

	log.Println("[migrate] checking if tables exist")
	tableNames := []string{"bug_reports", "bug_report_files", "rate_limit_hwid", "user", "session"}
	found := 0
	for _, tableName := range tableNames {
		ok, err := tableExists(db, dbName, tableName, driver)
		if err != nil {
			return fmt.Errorf("schema: check table %s: %w", tableName, err)
		}
		if !ok {
			log.Printf("[migrate] warning: expected table %s not found", tableName)
			continue
		}
		found++
	}
	if found == len(tableNames) {
		log.Println("[migrate] all expected tables found skipping init.sql")
		return nil
	}

	// NOTE(review): init.sql is re-run against a partially populated schema here.
	// That only succeeds if its statements tolerate already-existing objects
	// (e.g. CREATE TABLE IF NOT EXISTS). Confirm against the embedded init.sql files.
	log.Printf("[migrate] warning: expected %d tables, found %d running init.sql", len(tableNames), found)
	return runInitSQL(db, driver)
}

// loadMigrationTasks reads and decodes <driver>/migrations/tasks.json.
func loadMigrationTasks(driver string) ([]task, error) {
	raw, err := migrationsFS.ReadFile(driver + "/migrations/tasks.json")
	if err != nil {
		return nil, fmt.Errorf("schema: read tasks.json: %w", err)
	}
	var tf taskFile
	if err := json.Unmarshal(raw, &tf); err != nil {
		return nil, fmt.Errorf("schema: parse tasks.json: %w", err)
	}
	return tf.Tasks, nil
}

// applyMigrationTask runs t's SQL file statement by statement if shouldRun
// reports that any of its conditions is met. A task with an empty condition
// list is therefore skipped.
func applyMigrationTask(db *sqlx.DB, dbName, driver string, t task) error {
	needed, err := shouldRun(db, dbName, t.Conditions, driver)
	if err != nil {
		return fmt.Errorf("schema: evaluate conditions for %s: %w", t.ID, err)
	}
	if !needed {
		log.Printf("[migrate] skip %s conditions already met", t.ID)
		return nil
	}
	sqlBytes, err := migrationsFS.ReadFile(driver + "/migrations/" + t.SQLFile)
	if err != nil {
		return fmt.Errorf("schema: read %s: %w", t.SQLFile, err)
	}
	// Statements are executed individually. There is no surrounding
	// transaction, so a failure part-way leaves earlier statements applied.
	for _, stmt := range splitStatements(string(sqlBytes)) {
		if _, err := db.Exec(stmt); err != nil {
			return fmt.Errorf("schema: exec %s: %w\nSQL: %s", t.ID, err, stmt)
		}
	}
	log.Printf("[migrate] applied %s %s", t.ID, t.Description)
	return nil
}
// runInitSQL executes the embedded <driver>/init.sql script one statement at
// a time, in order, returning the first read or execution error.
func runInitSQL(db *sqlx.DB, driver string) error {
	script, readErr := migrationsFS.ReadFile(driver + "/init.sql")
	if readErr != nil {
		return fmt.Errorf("schema: read init.sql: %w", readErr)
	}
	statements := splitStatements(string(script))
	for idx := range statements {
		if _, execErr := db.Exec(statements[idx]); execErr != nil {
			return fmt.Errorf("schema: exec init.sql: %w\nSQL: %s", execErr, statements[idx])
		}
	}
	log.Println("[migrate] init.sql applied base schema created")
	return nil
}
// ---------- Condition evaluator ----------
// shouldRun reports whether a task must be applied. The conditions are
// combined with logical OR and evaluated in order. Evaluation stops at the
// first condition that is met (true) or the first error. An empty condition
// list yields false.
func shouldRun(db *sqlx.DB, dbName string, conds []condition, driver string) (bool, error) {
	for i := range conds {
		met, err := evaluate(db, dbName, conds[i], driver)
		switch {
		case err != nil:
			return false, err
		case met:
			return true, nil
		}
	}
	return false, nil
}
// evaluate checks a single condition against the live schema.
//
// The *_exists types return true when the table/column/index is present, and
// the *_not_exists types return the negation. Every condition needs Table.
// Column conditions also need Column, and index conditions also need Index.
// A missing required field is reported as an error. Otherwise the lookup
// would query for an empty name, which never matches and so silently skips
// (or always runs) the task.
func evaluate(db *sqlx.DB, dbName string, c condition, driver string) (bool, error) {
	var (
		exists bool
		err    error
	)
	switch c.Type {
	case "column_exists", "column_not_exists":
		if c.Table == "" || c.Column == "" {
			return false, fmt.Errorf("condition %s: table and column are required", c.Type)
		}
		exists, err = columnExists(db, dbName, c.Table, c.Column, driver)
	case "index_exists", "index_not_exists":
		if c.Table == "" || c.Index == "" {
			return false, fmt.Errorf("condition %s: table and index are required", c.Type)
		}
		exists, err = indexExists(db, dbName, c.Table, c.Index, driver)
	case "table_exists", "table_not_exists":
		if c.Table == "" {
			return false, fmt.Errorf("condition %s: table is required", c.Type)
		}
		exists, err = tableExists(db, dbName, c.Table, driver)
	default:
		return false, fmt.Errorf("unknown condition type: %s", c.Type)
	}
	// The boolean is meaningless when err != nil; callers check err first.
	if strings.HasSuffix(c.Type, "_not_exists") {
		return !exists, err
	}
	return exists, err
}
// ---------- MySQL condition checks ----------
// columnExistsMySQL reports whether dbName.table has a column named column,
// by counting matching rows in information_schema.COLUMNS.
func columnExistsMySQL(db *sqlx.DB, dbName, table, column string) (bool, error) {
	const query = `SELECT COUNT(*) FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?`
	var matches int
	err := db.Get(&matches, query, dbName, table, column)
	return matches > 0, err
}
// indexExistsMySQL reports whether dbName.table has an index named index,
// by counting matching rows in information_schema.STATISTICS.
func indexExistsMySQL(db *sqlx.DB, dbName, table, index string) (bool, error) {
	const query = `SELECT COUNT(*) FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND INDEX_NAME = ?`
	var matches int
	err := db.Get(&matches, query, dbName, table, index)
	return matches > 0, err
}
// tableExistsMySQL reports whether information_schema.TABLES lists table in
// schema dbName.
func tableExistsMySQL(db *sqlx.DB, dbName, table string) (bool, error) {
	const query = `SELECT COUNT(*) FROM information_schema.TABLES
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?`
	var matches int
	err := db.Get(&matches, query, dbName, table)
	return matches > 0, err
}
// schemaIsEmptyMySQL reports whether schema dbName contains no tables
// according to information_schema.TABLES.
func schemaIsEmptyMySQL(db *sqlx.DB, dbName string) (bool, error) {
	const query = `SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = ?`
	var tables int
	err := db.Get(&tables, query, dbName)
	return tables == 0, err
}
// ---------- SQLite condition checks ----------
// columnExistsSQLite reports whether table has a column named column.
//
// It queries the pragma_table_info table-valued function, which is available
// since SQLite 3.16.0. The table name is passed as a bound argument instead of
// being formatted into the SQL text. That keeps it consistent with the other
// checks and means a name containing a quote cannot break or alter the query.
func columnExistsSQLite(db *sqlx.DB, table, column string) (bool, error) {
	var count int
	err := db.Get(&count,
		`SELECT COUNT(*) FROM pragma_table_info(?) WHERE name = ?`,
		table, column)
	return count > 0, err
}
// indexExistsSQLite reports whether sqlite_master records an index named
// index on table.
func indexExistsSQLite(db *sqlx.DB, table, index string) (bool, error) {
	const query = `SELECT COUNT(*) FROM sqlite_master WHERE type='index' AND tbl_name=? AND name=?`
	var found int
	err := db.Get(&found, query, table, index)
	return found > 0, err
}
// tableExistsSQLite reports whether sqlite_master records a table named table.
func tableExistsSQLite(db *sqlx.DB, table string) (bool, error) {
	const query = `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`
	var found int
	err := db.Get(&found, query, table)
	return found > 0, err
}
// schemaIsEmptySQLite reports whether sqlite_master lists no tables at all.
func schemaIsEmptySQLite(db *sqlx.DB) (bool, error) {
	const query = `SELECT COUNT(*) FROM sqlite_master WHERE type='table'`
	var tables int
	err := db.Get(&tables, query)
	return tables == 0, err
}
// ---------- Driver-dispatched wrappers ----------
// columnExists dispatches the column lookup to the SQLite implementation when
// driver is "sqlite", and to the MySQL one otherwise. dbName is ignored for
// SQLite.
func columnExists(db *sqlx.DB, dbName, table, column, driver string) (bool, error) {
	switch driver {
	case "sqlite":
		return columnExistsSQLite(db, table, column)
	default:
		return columnExistsMySQL(db, dbName, table, column)
	}
}
// indexExists dispatches the index lookup to the SQLite implementation when
// driver is "sqlite", and to the MySQL one otherwise. dbName is ignored for
// SQLite.
func indexExists(db *sqlx.DB, dbName, table, index, driver string) (bool, error) {
	switch driver {
	case "sqlite":
		return indexExistsSQLite(db, table, index)
	default:
		return indexExistsMySQL(db, dbName, table, index)
	}
}
// tableExists dispatches the table lookup to the SQLite implementation when
// driver is "sqlite", and to the MySQL one otherwise. dbName is ignored for
// SQLite.
func tableExists(db *sqlx.DB, dbName, table, driver string) (bool, error) {
	switch driver {
	case "sqlite":
		return tableExistsSQLite(db, table)
	default:
		return tableExistsMySQL(db, dbName, table)
	}
}
// schemaIsEmpty dispatches the "no tables yet" check to the SQLite
// implementation when driver is "sqlite", and to the MySQL one otherwise.
// dbName is ignored for SQLite.
func schemaIsEmpty(db *sqlx.DB, dbName, driver string) (bool, error) {
	switch driver {
	case "sqlite":
		return schemaIsEmptySQLite(db)
	default:
		return schemaIsEmptyMySQL(db, dbName)
	}
}
// splitStatements splits a SQL script into the individual statements that are
// passed one at a time to Exec. Statements are separated by ";" at the top
// level only. The returned strings are trimmed and carry no ";".
//
// It is a lightweight scanner, not a full SQL parser:
//   - Quoted text ('...', "...", `...`) is copied verbatim, so semicolons and
//     keywords inside string literals or quoted identifiers are ignored. A
//     doubled quote ('it''s') works. Backslash escapes (MySQL's \') do not.
//   - Comments (-- to end of line, /* ... */) are copied verbatim, and their
//     contents are ignored too. MySQL "#" comments are not recognised.
//   - BEGIN...END and CASE...END nesting is tracked, so semicolons inside
//     trigger bodies do not end the statement. END IF, END LOOP, END WHILE
//     and END REPEAT close constructs that are not tracked and leave the
//     nesting unchanged. END CASE closes a tracked CASE.
//   - A chunk containing only whitespace and ordinary comments is dropped,
//     because MySQL rejects such a query as empty. MySQL executable comments
//     (/*! ... */) count as content and are kept.
func splitStatements(sql string) []string {
	var (
		out        []string
		buf        strings.Builder
		depth      int  // open BEGIN/CASE blocks still awaiting their END
		hasContent bool // buf holds more than whitespace and comments
	)
	isSpace := func(b byte) bool { return strings.IndexByte(" \t\r\n\f\v", b) >= 0 }
	flush := func() {
		if stmt := strings.TrimSpace(buf.String()); hasContent && stmt != "" {
			out = append(out, stmt)
		}
		buf.Reset()
		hasContent = false
	}

	n := len(sql)
	for i := 0; i < n; {
		c := sql[i]
		switch {
		case c == '\'' || c == '"' || c == '`':
			// A doubled quote closes and immediately reopens the literal, so
			// it needs no special handling. An unterminated quote runs to EOF.
			j := i + 1
			for j < n && sql[j] != c {
				j++
			}
			if j < n {
				j++ // include the closing quote
			}
			buf.WriteString(sql[i:j])
			hasContent = true
			i = j

		case c == '-' && i+1 < n && sql[i+1] == '-':
			// Line comment: copy it up to, but not including, the newline.
			j := i
			for j < n && sql[j] != '\n' {
				j++
			}
			buf.WriteString(sql[i:j])
			i = j

		case c == '/' && i+1 < n && sql[i+1] == '*':
			// Block comment. An unterminated one runs to EOF.
			j := n
			if k := strings.Index(sql[i+2:], "*/"); k >= 0 {
				j = i + 2 + k + 2
			}
			if i+2 < n && sql[i+2] == '!' {
				hasContent = true // MySQL executes /*! ... */ bodies
			}
			buf.WriteString(sql[i:j])
			i = j

		case isIdentStart(c):
			// Collect a whole identifier token to detect block keywords.
			j := i
			for j < n && isIdentChar(sql[j]) {
				j++
			}
			buf.WriteString(sql[i:j])
			hasContent = true
			switch strings.ToUpper(sql[i:j]) {
			case "BEGIN", "CASE":
				depth++
			case "END":
				// Peek at the following word to recognise END IF/LOOP/... .
				k := j
				for k < n && isSpace(sql[k]) {
					k++
				}
				m := k
				for m < n && isIdentChar(sql[m]) {
					m++
				}
				switch strings.ToUpper(sql[k:m]) {
				case "IF", "LOOP", "WHILE", "REPEAT":
					// Closes a construct that was never counted, so the depth
					// is unchanged. Consume the qualifier.
					buf.WriteString(sql[j:m])
					j = m
				case "CASE":
					// Consume CASE so it is not counted as a new opener.
					buf.WriteString(sql[j:m])
					j = m
					if depth > 0 {
						depth--
					}
				default:
					if depth > 0 {
						depth--
					}
				}
			}
			i = j

		case c == ';' && depth == 0:
			flush()
			i++

		default:
			buf.WriteByte(c)
			if !isSpace(c) {
				hasContent = true
			}
			i++
		}
	}
	flush()
	return out
}

// isIdentStart reports whether c (ASCII only) can begin an unquoted identifier.
func isIdentStart(c byte) bool {
	return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
}

// isIdentChar reports whether c (ASCII only) can continue an unquoted identifier.
func isIdentChar(c byte) bool {
	return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_'
}