204 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			204 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package mockid
 | |
| 
 | |
| import (
 | |
| 	"crypto"
 | |
| 	"crypto/ecdsa"
 | |
| 	"crypto/rand"
 | |
| 	"crypto/rsa"
 | |
| 	"crypto/sha256"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"math/big"
 | |
| 	"net/http"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"strconv"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"git.coolaj86.com/coolaj86/go-mockid/xkeypairs"
 | |
| 	"git.rootprojects.org/root/keypairs"
 | |
| 	//jwt "github.com/dgrijalva/jwt-go"
 | |
| )
 | |
| 
 | |
// rndsrc is the randomness source handed to the signing calls in JOSESign.
// It defaults to crypto/rand.Reader; TestMain overwrites it (presumably with
// a reproducible reader for tests — confirm in the test file).
var rndsrc = io.Reader(rand.Reader)
 | |
| 
 | |
// PublicJWK is the JSON shape of a public JSON Web Key with "crv", "x" and
// "y" members (the EC key layout of RFC 7518 §6.2). "kid" and "kty" are
// omitted from the JSON when empty; "crv", "x" and "y" are always emitted.
// NOTE(review): nothing here enforces the contents of X and Y — presumably
// base64url-encoded curve coordinates; confirm against the code that fills them.
type PublicJWK struct {
	Crv   string `json:"crv"`
	KeyID string `json:"kid,omitempty"`
	Kty   string `json:"kty,omitempty"`
	X     string `json:"x"`
	Y     string `json:"y"`
}
 | |
| 
 | |
// InspectableToken bundles the pieces of a token — its public key, protected
// header, payload and signature — with a Verified flag and a list of error
// strings, for rendering as JSON.
//
// (*InspectableToken).MarshalJSON renders Public as a JWK via
// keypairs.MarshalJWKPublicKey under "jwk". Because that method has a pointer
// receiver, marshal a *InspectableToken to get that encoding; marshaling a
// plain value falls back to the default encoding of the interface's dynamic
// value.
// NOTE(review): Protected/Payload are presumably the decoded JWS header and
// claims, and Signature the base64url signature segment — confirm with the
// code that populates this struct.
type InspectableToken struct {
	Public    keypairs.PublicKey     `json:"jwk"`
	Protected map[string]interface{} `json:"protected"`
	Payload   map[string]interface{} `json:"payload"`
	Signature string                 `json:"signature"`
	Verified  bool                   `json:"verified"`
	Errors    []string               `json:"errors"`
}
 | |
| 
 | |
| func (t *InspectableToken) MarshalJSON() ([]byte, error) {
 | |
| 	pub := keypairs.MarshalJWKPublicKey(t.Public)
 | |
| 	header, _ := json.Marshal(t.Protected)
 | |
| 	payload, _ := json.Marshal(t.Payload)
 | |
| 	errs, _ := json.Marshal(t.Errors)
 | |
| 	return []byte(fmt.Sprintf(
 | |
| 		`{"jwk":%s,"protected":%s,"payload":%s,"signature":%q,"verified":%t,"errors":%s}`,
 | |
| 		pub, header, payload, t.Signature, t.Verified, errs,
 | |
| 	)), nil
 | |
| }
 | |
| 
 | |
// Package-level state: mailer defaults and salt are filled in by Init, and
// nonces is the registry used by the nonce handlers.
var (
	// defaultFrom and defaultReplyTo hold the MAILER_FROM and
	// MAILER_REPLY_TO environment values read by Init.
	defaultFrom, defaultReplyTo string

	// nonces maps each issued nonce string to the time.Time it was issued.
	// A sync.Map is usable from concurrent HTTP handlers without extra locking.
	nonces sync.Map

	// salt is the decoded SALT environment value, set by Init.
	salt []byte
)
 | |
| 
 | |
| func Init() {
 | |
| 	var err error
 | |
| 	salt64 := os.Getenv("SALT")
 | |
| 	salt, err = base64.RawURLEncoding.DecodeString(salt64)
 | |
| 	if len(salt64) < 22 || nil != err {
 | |
| 		panic("SALT must be set as 22+ character base64")
 | |
| 	}
 | |
| 	defaultFrom = os.Getenv("MAILER_FROM")
 | |
| 	defaultReplyTo = os.Getenv("MAILER_REPLY_TO")
 | |
| 	//nonces = make(map[string]int64)
 | |
| 	//nonCh = make(chan string)
 | |
| 
 | |
| 	/*
 | |
| 		  go func() {
 | |
| 		    for {
 | |
| 		      nonce := <- nonCh
 | |
| 			    nonces[nonce] = time.Now().Unix()
 | |
| 		    }
 | |
| 		  }()
 | |
| 	*/
 | |
| }
 | |
| 
 | |
| func GenToken(host string, privkey keypairs.PrivateKey, query url.Values) (string, string, string) {
 | |
| 	thumbprint := keypairs.ThumbprintPublicKey(keypairs.NewPublicKey(privkey.Public()))
 | |
| 	// TODO keypairs.Alg(key)
 | |
| 	alg := "ES256"
 | |
| 	switch privkey.(type) {
 | |
| 	case *rsa.PrivateKey:
 | |
| 		alg = "RS256"
 | |
| 	}
 | |
| 	protected := fmt.Sprintf(`{"typ":"JWT","alg":%q,"kid":"%s"}`, alg, thumbprint)
 | |
| 	protected64 := base64.RawURLEncoding.EncodeToString([]byte(protected))
 | |
| 
 | |
| 	exp, err := xkeypairs.ParseDuration(query.Get("exp"))
 | |
| 	if nil != err {
 | |
| 		// cryptic error code
 | |
| 		// TODO propagate error
 | |
| 		exp = 422
 | |
| 	}
 | |
| 
 | |
| 	payload := fmt.Sprintf(
 | |
| 		`{"iss":"%s/","sub":"dummy","exp":%s}`,
 | |
| 		host, strconv.FormatInt(time.Now().Add(time.Duration(exp)*time.Second).Unix(), 10),
 | |
| 	)
 | |
| 	payload64 := base64.RawURLEncoding.EncodeToString([]byte(payload))
 | |
| 
 | |
| 	hash := sha256.Sum256([]byte(fmt.Sprintf(`%s.%s`, protected64, payload64)))
 | |
| 	sig := JOSESign(privkey, hash[:])
 | |
| 	sig64 := base64.RawURLEncoding.EncodeToString(sig)
 | |
| 	token := fmt.Sprintf("%s.%s.%s\n", protected64, payload64, sig64)
 | |
| 	return protected, payload, token
 | |
| }
 | |
| 
 | |
| func JOSESign(privkey keypairs.PrivateKey, hash []byte) []byte {
 | |
| 	var sig []byte
 | |
| 
 | |
| 	switch k := privkey.(type) {
 | |
| 	case *rsa.PrivateKey:
 | |
| 		panic("TODO: implement rsa sign")
 | |
| 	case *ecdsa.PrivateKey:
 | |
| 		r, s, _ := ecdsa.Sign(rndsrc, k, hash[:])
 | |
| 		rb := r.Bytes()
 | |
| 		fmt.Println("debug:")
 | |
| 		fmt.Println(r, s)
 | |
| 		for len(rb) < 32 {
 | |
| 			rb = append([]byte{0}, rb...)
 | |
| 		}
 | |
| 		sb := s.Bytes()
 | |
| 		for len(rb) < 32 {
 | |
| 			sb = append([]byte{0}, sb...)
 | |
| 		}
 | |
| 		sig = append(rb, sb...)
 | |
| 	}
 | |
| 	return sig
 | |
| }
 | |
| 
 | |
| // TODO: move to keypairs
 | |
| 
 | |
| func JOSEVerify(pubkey keypairs.PublicKey, hash []byte, sig []byte) bool {
 | |
| 
 | |
| 	switch pub := pubkey.Key().(type) {
 | |
| 	case *rsa.PublicKey:
 | |
| 		// TODO keypairs.Size(key) to detect key size ?
 | |
| 		//alg := "SHA256"
 | |
| 		// TODO: this hasn't been tested yet
 | |
| 		if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, hash, sig); nil != err {
 | |
| 			return false
 | |
| 		}
 | |
| 		return true
 | |
| 	case *ecdsa.PublicKey:
 | |
| 		r := &big.Int{}
 | |
| 		r.SetBytes(sig[0:32])
 | |
| 		s := &big.Int{}
 | |
| 		s.SetBytes(sig[32:])
 | |
| 		fmt.Println("debug: sig len:", len(sig))
 | |
| 		fmt.Println("debug: r, s:", r, s)
 | |
| 		return ecdsa.Verify(pub, hash, r, s)
 | |
| 	default:
 | |
| 		panic("impossible condition: non-rsa/non-ecdsa key")
 | |
| 		return false
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func issueNonce(w http.ResponseWriter, r *http.Request) {
 | |
| 	b := make([]byte, 16)
 | |
| 	_, _ = rand.Read(b)
 | |
| 	nonce := base64.RawURLEncoding.EncodeToString(b)
 | |
| 	//nonCh <- nonce
 | |
| 	nonces.Store(nonce, time.Now())
 | |
| 
 | |
| 	w.Header().Set("Replay-Nonce", nonce)
 | |
| }
 | |
| 
 | |
| func requireNonce(next http.HandlerFunc) http.HandlerFunc {
 | |
| 	return func(w http.ResponseWriter, r *http.Request) {
 | |
| 		nonce := r.Header.Get("Replay-Nonce")
 | |
| 		// TODO expire nonces every so often
 | |
| 		//t := nonces[nonce]
 | |
| 		var t time.Time
 | |
| 		tmp, ok := nonces.Load(nonce)
 | |
| 		if ok {
 | |
| 			t = tmp.(time.Time)
 | |
| 		}
 | |
| 		if !ok || time.Now().Sub(t) > 15*time.Minute {
 | |
| 			http.Error(
 | |
| 				w,
 | |
| 				`{ "error": "invalid or expired nonce", "error_code": "ENONCE" }`,
 | |
| 				http.StatusBadRequest,
 | |
| 			)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		//delete(nonces, nonce)
 | |
| 		nonces.Delete(nonce)
 | |
| 		issueNonce(w, r)
 | |
| 
 | |
| 		next(w, r)
 | |
| 	}
 | |
| }
 |