122 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			122 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package db
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"regexp"
 | |
| 	"sort"
 | |
| 	"time"
 | |
| 
 | |
| 	"git.example.com/example/goserv/assets/configfs"
 | |
| 	"github.com/jmoiron/sqlx"
 | |
| 
 | |
| 	// pq injects itself into sql as 'postgres'
 | |
| 	_ "github.com/lib/pq"
 | |
| )
 | |
| 
 | |
// DB is a concurrency-safe db connection instance. It is nil until Init
// completes successfully, at which point Init assigns it.
var DB *sqlx.DB

// firstDBURL holds the URL passed to the most recent call to Init. Init
// overwrites it at the start of every call, including calls that later fail.
var firstDBURL PleaseDoubleCheckTheDatabaseURLDontDropProd
 | |
| 
 | |
| // Init initializes the database
 | |
| func Init(pgURL string) error {
 | |
| 	// https://godoc.org/github.com/lib/pq
 | |
| 
 | |
| 	firstDBURL = PleaseDoubleCheckTheDatabaseURLDontDropProd(pgURL)
 | |
| 	dbtype := "postgres"
 | |
| 
 | |
| 	ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
 | |
| 	defer done()
 | |
| 	db, err := sql.Open(dbtype, pgURL)
 | |
| 	if err := db.PingContext(ctx); nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// basic stuff
 | |
| 	f, err := configfs.Assets.Open("./postgres/init.sql")
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 	sqlBytes, err := ioutil.ReadAll(f)
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 	if _, err := db.ExecContext(ctx, string(sqlBytes)); nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// project-specific stuff
 | |
| 	f, err = configfs.Assets.Open("./postgres/tables.sql")
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 	sqlBytes, err = ioutil.ReadAll(f)
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 	if _, err := db.ExecContext(ctx, string(sqlBytes)); nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	DB = sqlx.NewDb(db, dbtype)
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// PleaseDoubleCheckTheDatabaseURLDontDropProd is a plain string holding a
// database URL. Its deliberately alarming name is just a friendly, hopefully
// helpful reminder to only use this in test files, and to not drop the
// production database.
type PleaseDoubleCheckTheDatabaseURLDontDropProd string
 | |
| 
 | |
| // DropAllTables runs drop.sql, which is intended only for tests
 | |
| func DropAllTables(dbURL PleaseDoubleCheckTheDatabaseURLDontDropProd) error {
 | |
| 	if err := CanDropAllTables(string(dbURL)); nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// drop stuff
 | |
| 	f, err := configfs.Assets.Open("./postgres/drop.sql")
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 	sqlBytes, err := ioutil.ReadAll(f)
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 	ctx, done := context.WithDeadline(context.Background(), time.Now().Add(1*time.Second))
 | |
| 	defer done()
 | |
| 	if _, err := DB.ExecContext(ctx, string(sqlBytes)); nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// CanDropAllTables checks whether dbURL is marked as a disposable database.
//
// The URL is cut into runs of ASCII letters (every other character acts as a
// separator), and the check passes when one of those runs is exactly "test"
// or "demo". Matching is case-sensitive, and longer words that merely contain
// a marker, such as "contest" or "testing", do not count.
//
// It returns nil when a marker is found and a descriptive error quoting dbURL
// otherwise.
func CanDropAllTables(dbURL string) error {
	words := regexp.MustCompile(`[^a-zA-Z]`).Split(dbURL, -1)

	// Sorting allows each marker to be located with a binary search.
	sort.Strings(words)

	for _, marker := range []string{"test", "demo"} {
		// SearchStrings yields the insertion point for marker, which is the
		// position of marker itself when it is present (and may equal
		// len(words) when it is not).
		pos := sort.SearchStrings(words, marker)
		if pos < len(words) && words[pos] == marker {
			return nil
		}
	}

	return fmt.Errorf(
		"test and demo database URLs must contain the word 'test' or 'demo' "+
			"separated by a non-alphabet character, such as /test2/db_demo1\n%q\n",
		dbURL,
	)
}
 |