122 lines
2.9 KiB
Go

package db
import (
"context"
"database/sql"
"fmt"
"io/ioutil"
"regexp"
"sort"
"time"
"git.example.com/example/goserv/assets/configfs"
"github.com/jmoiron/sqlx"
// pq injects itself into sql as 'postgres'
_ "github.com/lib/pq"
)
// Package-level connection state. Both values are assigned by Init.
var (
	// DB is a concurrency-safe db connection instance
	DB *sqlx.DB

	// firstDBURL is set by Init from its pgURL argument.
	firstDBURL PleaseDoubleCheckTheDatabaseURLDontDropProd
)
// Init initializes the database
func Init(pgURL string) error {
// https://godoc.org/github.com/lib/pq
firstDBURL = PleaseDoubleCheckTheDatabaseURLDontDropProd(pgURL)
dbtype := "postgres"
ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer done()
db, err := sql.Open(dbtype, pgURL)
if err := db.PingContext(ctx); nil != err {
return err
}
// basic stuff
f, err := configfs.Assets.Open("./postgres/init.sql")
if nil != err {
return err
}
sqlBytes, err := ioutil.ReadAll(f)
if nil != err {
return err
}
if _, err := db.ExecContext(ctx, string(sqlBytes)); nil != err {
return err
}
// project-specific stuff
f, err = configfs.Assets.Open("./postgres/tables.sql")
if nil != err {
return err
}
sqlBytes, err = ioutil.ReadAll(f)
if nil != err {
return err
}
if _, err := db.ExecContext(ctx, string(sqlBytes)); nil != err {
return err
}
DB = sqlx.NewDb(db, dbtype)
return nil
}
// PleaseDoubleCheckTheDatabaseURLDontDropProd is a database URL. Its type name
// is a friendly, hopefully helpful reminder to use it only in test files and
// never to point it at (and drop) the production database.
type PleaseDoubleCheckTheDatabaseURLDontDropProd string
// DropAllTables runs the embedded ./postgres/drop.sql script against DB. It
// is intended only for tests.
//
// As a guard against dropping production data, both dbURL and firstDBURL
// (the URL recorded by Init) must pass CanDropAllTables. It also returns an
// error in these cases:
//   - DB has not been initialized.
//   - The script cannot be opened or read.
//   - The script fails to execute within its 1-second deadline.
func DropAllTables(dbURL PleaseDoubleCheckTheDatabaseURLDontDropProd) error {
	if err := CanDropAllTables(string(dbURL)); nil != err {
		return err
	}
	if nil == DB {
		return fmt.Errorf("DropAllTables: DB is not initialized; call Init first")
	}
	// dbURL is only what the caller claims. The statements run on DB, which is
	// normally connected to firstDBURL, so that URL must pass the check too.
	if err := CanDropAllTables(string(firstDBURL)); nil != err {
		return err
	}

	// drop stuff
	const path = "./postgres/drop.sql"
	f, err := configfs.Assets.Open(path)
	if nil != err {
		return fmt.Errorf("opening %s: %w", path, err)
	}
	// The asset is only read, so a Close error cannot lose data.
	defer f.Close()

	sqlBytes, err := ioutil.ReadAll(f)
	if nil != err {
		return fmt.Errorf("reading %s: %w", path, err)
	}

	ctx, done := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second))
	defer done()
	if _, err := DB.ExecContext(ctx, string(sqlBytes)); nil != err {
		return fmt.Errorf("executing %s: %w", path, err)
	}
	return nil
}
// nonAlphaRE matches any character that is not an ASCII letter. It is
// compiled once at package scope instead of on every CanDropAllTables call.
var nonAlphaRE = regexp.MustCompile(`[^a-zA-Z]`)

// CanDropAllTables returns nil if dbURL contains the word "test" or "demo"
// delimited by non-letter characters (or the ends of the string). Otherwise
// it returns an error.
//
// Matching is case-sensitive and whole-word. "/test2/db_demo1" passes, but
// "testing" and "Test" do not.
func CanDropAllTables(dbURL string) error {
	// Runs of letters; empty strings appear between adjacent separators.
	words := nonAlphaRE.Split(dbURL, -1)
	sort.Strings(words)
	for _, needle := range []string{"test", "demo"} {
		// SearchStrings returns the index at which needle would be inserted,
		// which is its position when it is present.
		if i := sort.SearchStrings(words, needle); i < len(words) && words[i] == needle {
			return nil
		}
	}
	return fmt.Errorf(
		"test and demo database URLs must contain the word 'test' or 'demo' "+
			"separated by a non-alphabet character, such as /test2/db_demo1\n%q\n",
		dbURL,
	)
}