165 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			165 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package api
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io/ioutil"
 | |
| 	"log"
 | |
| 	mathrand "math/rand"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"strings"
 | |
| 	"testing"
 | |
| 
 | |
| 	"git.example.com/example/goserv/internal/db"
 | |
| 	"git.rootprojects.org/root/keypairs"
 | |
| 	"git.rootprojects.org/root/keypairs/keyfetch"
 | |
| 
 | |
| 	"github.com/go-chi/chi"
 | |
| )
 | |
| 
 | |
| var srv *httptest.Server
 | |
| 
 | |
| var testKey keypairs.PrivateKey
 | |
| var testPub keypairs.PublicKey
 | |
| var testWhitelist keyfetch.Whitelist
 | |
| 
 | |
| func init() {
 | |
| 	// In tests it's nice to get the same "random" values, every time
 | |
| 	RandReader = testReader{}
 | |
| 	mathrand.Seed(0)
 | |
| }
 | |
| 
 | |
| func TestMain(m *testing.M) {
 | |
| 	connStr := needsTestDB(m)
 | |
| 	if strings.Contains(connStr, "@localhost/") || strings.Contains(connStr, "@localhost:") {
 | |
| 		connStr += "?sslmode=disable"
 | |
| 	} else {
 | |
| 		connStr += "?sslmode=required"
 | |
| 	}
 | |
| 
 | |
| 	if err := db.Init(connStr); nil != err {
 | |
| 		log.Fatal("db connection error", err)
 | |
| 		return
 | |
| 	}
 | |
| 	if err := db.DropAllTables(db.PleaseDoubleCheckTheDatabaseURLDontDropProd(connStr)); nil != err {
 | |
| 		log.Fatal(err)
 | |
| 	}
 | |
| 	if err := db.Init(connStr); nil != err {
 | |
| 		log.Fatal("db connection error", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	testKey = keypairs.NewDefaultPrivateKey()
 | |
| 	testPub = keypairs.NewPublicKey(testKey.Public())
 | |
| 	r := chi.NewRouter()
 | |
| 	srv = httptest.NewServer(Init(testPub, r))
 | |
| 	testWhitelist, err = keyfetch.NewWhitelist(nil, []string{srv.URL})
 | |
| 	if nil != err {
 | |
| 		log.Fatal("bad whitelist", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	os.Exit(m.Run())
 | |
| }
 | |
| 
 | |
| // public APIs
 | |
| 
 | |
| func Test_Public_Ping(t *testing.T) {
 | |
| 	if err := testPing("public"); nil != err {
 | |
| 		t.Fatal(err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // test types
 | |
| 
 | |
// testReader satisfies io.Reader by drawing bytes from the global math/rand
// source. init in this file seeds that source with a fixed value, so the
// output repeats from run to run.
type testReader struct{}

// Read fills p with bytes from math/rand. math/rand.Read always returns
// len(p) and a nil error.
func (testReader) Read(p []byte) (int, error) {
	count, readErr := mathrand.Read(p)
	return count, readErr
}
 | |
| 
 | |
| func testPing(which string) error {
 | |
| 	urlstr := fmt.Sprintf("/api/%s/ping", which)
 | |
| 	res, err := testReq("GET", urlstr, "", nil, 200)
 | |
| 	if nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	data := map[string]interface{}{}
 | |
| 	if err := json.NewDecoder(res.Body).Decode(&data); nil != err {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if success, ok := data["success"].(bool); !ok || !success {
 | |
| 		log.Printf("Bad Response\n\tURL:%s\n\tBody:\n%#v", urlstr, data)
 | |
| 		return errors.New("bad response: missing success")
 | |
| 	}
 | |
| 
 | |
| 	if ppid, _ := data["ppid"].(string); "" != ppid {
 | |
| 		return fmt.Errorf("the effective user ID isn't what it should be: %q != %q", ppid, "")
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func testReq(method, pathname string, jwt string, payload []byte, expectedStatus int) (*http.Response, error) {
 | |
| 	client := srv.Client()
 | |
| 	urlstr, _ := url.Parse(srv.URL + pathname)
 | |
| 
 | |
| 	if "" == method {
 | |
| 		method = "GET"
 | |
| 	}
 | |
| 
 | |
| 	req := &http.Request{
 | |
| 		Method: method,
 | |
| 		URL:    urlstr,
 | |
| 		Body:   ioutil.NopCloser(bytes.NewReader(payload)),
 | |
| 		Header: http.Header{},
 | |
| 	}
 | |
| 
 | |
| 	if len(jwt) > 0 {
 | |
| 		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt))
 | |
| 	}
 | |
| 	res, err := client.Do(req)
 | |
| 	if nil != err {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	if expectedStatus > 0 {
 | |
| 		if expectedStatus != res.StatusCode {
 | |
| 			data, _ := ioutil.ReadAll(res.Body)
 | |
| 			log.Printf("Bad Response: %d\n\tURL:%s\n\tBody:\n%s", res.StatusCode, urlstr, string(data))
 | |
| 			return nil, fmt.Errorf("bad status code: %d", res.StatusCode)
 | |
| 		}
 | |
| 	}
 | |
| 	return res, nil
 | |
| }
 | |
| 
 | |
// needsTestDB returns the value of TEST_DATABASE_URL. If the variable is unset
// or empty, it prints setup instructions and exits the test binary via
// log.Fatal. m is currently unused.
func needsTestDB(m *testing.M) string {
	const setupHelp = `no connection string defined

You must set TEST_DATABASE_URL to run db tests.

You may find this helpful:

    psql 'postgres://postgres:postgres@localhost:5432/postgres'

	DROP DATABASE IF EXISTS postgres_test;
	CREATE DATABASE postgres_test;
	\q

Then your test database URL will be

    export TEST_DATABASE_URL=postgres://postgres:postgres@localhost:5432/postgres_test
`

	if dsn := os.Getenv("TEST_DATABASE_URL"); dsn != "" {
		return dsn
	}
	log.Fatal(setupHelp)
	return "" // unreachable: log.Fatal exits, but Go requires a return here
}
 |