opaal/internal/flows/jwt_bearer.go

364 lines
11 KiB
Go

package flows
import (
"crypto/rand"
"crypto/rsa"
"davidallendj/opaal/internal/oauth"
"encoding/json"
"fmt"
"os"
"time"
"github.com/davidallendj/go-utils/cryptox"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
// JwtBearerFlowParams carries the inputs consumed by NewJwtBearerFlow and
// ForwardToken.
type JwtBearerFlowParams struct {
	// AccessToken and IdToken are the raw (serialized) tokens obtained from
	// the identity provider.
	AccessToken, IdToken string
	// IdentityProvider *oidc.IdentityProvider

	// TrustedIssuer is the issuer entry registered with the authorization
	// server; NewJwtBearerFlow fills in its PublicKey (and Scope on refresh).
	TrustedIssuer *oauth.TrustedIssuer
	// Client is the OAuth client used to talk to the authorization server.
	Client *oauth.Client
	// Audience, when non-empty, replaces the token endpoint as the "aud"
	// claim of the JWT generated by NewJwtBearerFlow.
	Audience []string
	// Refresh requests the offline_access scope (refresh tokens); Verbose
	// enables extra diagnostic output.
	Refresh, Verbose bool
	// KeyPath is the file holding the RSA private key used to sign the
	// generated JWT; a new key is generated (and saved there) when the file
	// cannot be read.
	KeyPath string
}
// JwtBearerFlowEndpoints holds the authorization-server URLs used by the JWT
// bearer flows.
type JwtBearerFlowEndpoints struct {
	// TrustedIssuers is where trusted-issuer entries are added, Token is the
	// token endpoint the jwt-bearer grant is sent to, Clients is used to
	// delete and re-create OAuth clients, and Register is used to register
	// new OAuth clients.
	TrustedIssuers, Token, Clients, Register string
}
// NewJwtBearerFlow runs the OAuth 2.0 JWT bearer grant against the
// authorization server described by eps and returns the access token it
// issues.
//
// The flow:
//  1. verifies the access token and ID token (when non-empty) against
//     client.Provider.KeySet and parses the ID token,
//  2. loads the RSA signing key from params.KeyPath, or generates a new
//     2048-bit key and tries to save it there (a write failure is reported
//     but not fatal),
//  3. registers opaal as a trusted issuer with the matching public JWK,
//  4. re-signs the ID token's private claims as a new JWT issued by
//     trustedIssuer.Issuer,
//  5. registers (or deletes and re-creates) an OAuth client allowed to use
//     the jwt-bearer grant, and
//  6. exchanges the signed JWT for an access token at eps.Token.
//
// params.Client, params.TrustedIssuer and eps.Token must be set. The flow
// mutates params.TrustedIssuer (PublicKey, Scope) and params.Client (Scope,
// Id, Secret). Raw server responses and the signed JWT are only printed when
// params.Verbose is set because they contain credentials.
func NewJwtBearerFlow(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) (string, error) {
	var (
		// idp = params.IdentityProvider
		accessToken   = params.AccessToken
		idToken       = params.IdToken
		client        = params.Client
		trustedIssuer = params.TrustedIssuer
		verbose       = params.Verbose
	)

	// pre-condition checks: fail before causing any side effects on the servers
	if client == nil {
		return "", fmt.Errorf("invalid client (client is nil)")
	}
	if trustedIssuer == nil {
		return "", fmt.Errorf("invalid trusted issuer (trusted issuer is nil)")
	}
	if eps.Token == "" {
		return "", fmt.Errorf("token endpoint not set")
	}

	// 1. verify that the JWTs from the issuer are valid using all keys
	if accessToken != "" {
		if _, err := jws.Verify([]byte(accessToken), jws.WithKeySet(client.Provider.KeySet), jws.WithValidateKey(true)); err != nil {
			return "", fmt.Errorf("failed to verify access token: %w", err)
		}
	}
	if idToken != "" {
		if _, err := jws.Verify([]byte(idToken), jws.WithKeySet(client.Provider.KeySet), jws.WithValidateKey(true)); err != nil {
			return "", fmt.Errorf("failed to verify ID token: %w", err)
		}
	}
	// The ID token's claims are the basis of the new JWT (step 4); parse it now
	// so an unusable token is rejected before anything is registered remotely.
	parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(client.Provider.KeySet))
	if err != nil {
		return "", fmt.Errorf("failed to parse ID token: %w", err)
	}

	// TODO: 2. Check if we are already registered as a trusted issuer with authorization server...
	// 3.a if not, create a new JWKS (or just JWK) to be verified
	var (
		keyPath    = params.KeyPath
		privateJwk jwk.Key
		publicJwk  jwk.Key
	)
	rawPrivateKey, err := os.ReadFile(keyPath)
	if err != nil {
		if verbose {
			fmt.Printf("failed to read private key...generating a new one.\n")
		}
		privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			return "", fmt.Errorf("failed to generate new RSA key: %w", err)
		}
		privateJwk, publicJwk, err = GenerateJwkKeyPairFromPrivateKey(privateKey) // FIXME: needs to pull correct version from cryptox
		if err != nil {
			return "", fmt.Errorf("failed to generate JWK pair from private key: %w", err)
		}
		// Save the new key to reuse later. This is private key material, so
		// restrict it to the owner (os.ModePerm would make it world-readable).
		if err := os.WriteFile(keyPath, cryptox.MarshalRSAPrivateKey(privateKey), 0o600); err != nil {
			fmt.Printf("failed to write private key to file: %v\n", err)
		}
	} else {
		privateKey, err := cryptox.GenerateRSAPrivateKey(rawPrivateKey)
		if err != nil {
			return "", fmt.Errorf("failed to generate RSA key from string: %w", err)
		}
		privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
		if err != nil {
			return "", fmt.Errorf("failed to generate JWK pair from private key: %w", err)
		}
	}

	// add more required claims and validate
	if err := publicJwk.Set("kid", uuid.New().String()); err != nil {
		return "", fmt.Errorf("failed to set public JWK key ID: %w", err)
	}
	if err := publicJwk.Set("use", "sig"); err != nil {
		return "", fmt.Errorf("failed to set public JWK usage: %w", err)
	}
	if err := publicJwk.Validate(); err != nil {
		return "", fmt.Errorf("failed to validate public JWK: %w", err)
	}
	trustedIssuer.PublicKey = publicJwk

	// add offline_access scope to enable refresh tokens
	if params.Refresh {
		trustedIssuer.Scope = append(trustedIssuer.Scope, "offline_access")
	}

	// 3.b ...and then, add opaal's server host as a trusted issuer with JWK
	if verbose {
		fmt.Printf("Attempting to add issuer to authorization server...\n")
	}
	res, err := client.AddTrustedIssuer(eps.TrustedIssuers, trustedIssuer)
	if err != nil {
		return "", fmt.Errorf("failed to add trusted issuer: %w", err)
	}
	if verbose {
		fmt.Printf("trusted issuer: %v\n", string(res))
	}
	// TODO: add trusted issuer to cache if successful

	// 4. create a new JWT based on the claims from the identity provider and sign
	now := time.Now()
	payload := parsedIdToken.PrivateClaims()
	payload["iss"] = trustedIssuer.Issuer
	payload["iat"] = now.Unix()
	payload["nbf"] = now.Unix()
	payload["exp"] = now.Add(16 * time.Hour).Unix()
	payload["sub"] = "opaal"
	// an explicit audience overrides the default of the token endpoint
	if len(params.Audience) > 0 {
		payload["aud"] = params.Audience
	} else {
		payload["aud"] = []string{eps.Token}
	}

	// include the offline_access scope if refresh tokens are enabled
	if params.Refresh {
		scope, err := appendOfflineAccessScope(payload["scope"])
		if err != nil {
			return "", err
		}
		payload["scope"] = scope
		// also include offline_access in client to make request
		client.Scope = append(client.Scope, "offline_access")
	}

	payloadJson, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("failed to marshal payload: %w", err)
	}
	// NOTE(review): jws.WithJSON() produces the JWS JSON serialization rather
	// than the compact form usually expected for a JWT bearer assertion —
	// confirm the authorization server accepts it.
	newJwt, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
	if err != nil {
		return "", fmt.Errorf("failed to sign token: %w", err)
	}

	// 5. dynamically register new OAuth client and authorize it to make jwt_bearer request
	fmt.Printf("Registering new OAuth2 client with authorization server...\n")
	res, err = client.RegisterOAuthClient(eps.Register, []oauth.GrantType{oauth.JwtBearer})
	if err != nil {
		return "", fmt.Errorf("failed to register client: %w", err)
	}
	if verbose {
		fmt.Printf("%v\n", string(res))
	}

	// extract the client info from response
	var clientData map[string]any
	if err := json.Unmarshal(res, &clientData); err != nil {
		return "", fmt.Errorf("failed to unmarshal client data: %w", err)
	}
	if clientData["error"] == nil {
		id, idOk := clientData["client_id"].(string)
		secret, secretOk := clientData["client_secret"].(string)
		if !idOk || !secretOk {
			return "", fmt.Errorf("client registration response is missing client credentials")
		}
		client.Id = id
		client.Secret = secret
	} else {
		// registration was rejected: delete the client and try to create it again
		fmt.Printf("Attempting to delete client...\n")
		if err := client.DeleteOAuthClient(eps.Clients); err != nil {
			return "", fmt.Errorf("failed to delete OAuth client: %w", err)
		}
		fmt.Printf("Attempting to re-create client...\n")
		createRes, err := client.CreateOAuthClient(eps.Clients, []oauth.GrantType{oauth.JwtBearer})
		if err != nil {
			return "", fmt.Errorf("failed to register client: %w", err)
		}
		// NOTE(review): client.Id/Secret are not updated from this response here;
		// confirm CreateOAuthClient sets them itself.
		if verbose {
			fmt.Printf("%v\n", string(createRes))
		}
	}
	// TODO: add OAuth client to cache if successfully
	// authorize the client
	// fmt.Printf("Attempting to authorize client...\n")
	// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
	// if err != nil {
	// 	return fmt.Errorf("failed to authorize client: %v", err)
	// }
	// fmt.Printf("%v\n", string(res))

	// 6. send JWT to authorization server and receive an access token
	fmt.Printf("Fetching access token from authorization server...\n")
	if verbose {
		fmt.Printf("jwt: %s\n", string(newJwt))
	}
	tokenRes, err := client.PerformJwtBearerTokenGrant(eps.Token, string(newJwt))
	if err != nil {
		return "", fmt.Errorf("failed to fetch access token: %w", err)
	}

	// extract token from response if there are no errors
	var data map[string]any
	if err := json.Unmarshal(tokenRes, &data); err != nil {
		return "", fmt.Errorf("failed to unmarshal response: %w", err)
	}
	if data["error"] != nil {
		return "", fmt.Errorf("the authorization server returned an error (%v): %v", data["error"], data["error_description"])
	}
	if verbose {
		fmt.Printf("%s\n", tokenRes)
	}
	accessTokenOut, ok := data["access_token"].(string)
	if !ok {
		return "", fmt.Errorf("token response does not contain an access token")
	}
	return accessTokenOut, nil
}

// appendOfflineAccessScope returns the "scope" claim value v with
// "offline_access" added. A missing (nil) claim becomes
// []string{"offline_access"}. Claims decoded from JSON arrive as []any or as a
// space-delimited string rather than []string, so those shapes are handled
// too; any other type is reported as an error instead of panicking.
func appendOfflineAccessScope(v any) (any, error) {
	const offlineAccess = "offline_access"
	switch scope := v.(type) {
	case nil:
		return []string{offlineAccess}, nil
	case []string:
		return append(scope, offlineAccess), nil
	case []any:
		return append(scope, offlineAccess), nil
	case string:
		if scope == "" {
			return offlineAccess, nil
		}
		return scope + " " + offlineAccess, nil
	default:
		return nil, fmt.Errorf("unsupported type %T for \"scope\" claim", v)
	}
}
// ForwardToken submits the identity provider's ID token (params.IdToken)
// unchanged to the authorization server using the JWT bearer grant.
//
// It fetches the provider's JWKS, adds a trusted-issuer entry for
// client.Provider.Issuer (subject "1", expiring in one hour), registers (or
// deletes and re-creates) an OAuth client allowed to use the jwt-bearer grant,
// and finally sends the ID token to eps.Token. The token response is not
// returned; it is printed only when params.Verbose is set. An OAuth error in
// the token response is reported as an error.
//
// params.Client and eps.Token must be set.
func ForwardToken(eps JwtBearerFlowEndpoints, params JwtBearerFlowParams) error {
	var (
		client  = params.Client
		idToken = params.IdToken
		// idp = params.IdentityProvider
		verbose = params.Verbose
	)

	// pre-condition checks: fail before causing any side effects on the servers
	if client == nil {
		return fmt.Errorf("invalid client (client is nil)")
	}
	if eps.Token == "" {
		return fmt.Errorf("token endpoint is not set")
	}

	// fetch JWKS and add issuer to authentication server to submit ID token
	if verbose {
		fmt.Printf("Fetching JWKS from authentication server for verification...\n")
	}
	if err := client.Provider.FetchJwks(); err != nil {
		return fmt.Errorf("failed to fetch JWK: %w", err)
	}
	if verbose {
		fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
		fmt.Printf("Attempting to add issuer to authorization server...\n")
	}
	ti := &oauth.TrustedIssuer{
		Issuer:    client.Provider.Issuer,
		Subject:   "1",
		ExpiresAt: time.Now().Add(time.Hour),
	}
	tiRes, err := client.AddTrustedIssuer(eps.TrustedIssuers, ti)
	if err != nil {
		return fmt.Errorf("failed to add trusted issuer: %w", err)
	}
	if verbose {
		fmt.Printf("%v\n", string(tiRes))
	}

	// try and register a new client with authorization server
	if verbose {
		fmt.Printf("Registering new OAuth2 client with authorization server...\n")
	}
	regRes, err := client.RegisterOAuthClient(eps.Register, []oauth.GrantType{oauth.JwtBearer})
	if err != nil {
		return fmt.Errorf("failed to register client: %w", err)
	}
	if verbose {
		fmt.Printf("%v\n", string(regRes))
	}

	// extract the client info from response
	var clientData map[string]any
	if err := json.Unmarshal(regRes, &clientData); err != nil {
		return fmt.Errorf("failed to unmarshal client data: %w", err)
	}
	if clientData["error"] == nil {
		id, idOk := clientData["client_id"].(string)
		secret, secretOk := clientData["client_secret"].(string)
		if !idOk || !secretOk {
			return fmt.Errorf("client registration response is missing client credentials")
		}
		client.Id = id
		client.Secret = secret
	} else {
		// registration was rejected: delete the client and create it again
		fmt.Printf("Attempting to delete client...\n")
		if err := client.DeleteOAuthClient(eps.Clients); err != nil {
			return fmt.Errorf("failed to delete OAuth client: %w", err)
		}
		fmt.Printf("Attempting to re-create client...\n")
		createRes, err := client.CreateOAuthClient(eps.Clients, []oauth.GrantType{oauth.JwtBearer})
		if err != nil {
			return fmt.Errorf("failed to register client: %w", err)
		}
		// NOTE(review): client.Id/Secret are not updated from this response here;
		// confirm CreateOAuthClient sets them itself.
		if verbose {
			fmt.Printf("%v\n", string(createRes))
		}
	}

	// authorize the client
	// fmt.Printf("Attempting to authorize client...\n")
	// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
	// if err != nil {
	// 	return fmt.Errorf("failed to authorize client: %v", err)
	// }
	// fmt.Printf("%v\n", string(res))

	// use ID token/user info to fetch access token from authorization server
	if verbose {
		fmt.Printf("Fetching access token from authorization server...\n")
	}
	tokenRes, err := client.PerformJwtBearerTokenGrant(eps.Token, idToken)
	if err != nil {
		return fmt.Errorf("failed to fetch access token: %w", err)
	}
	if verbose {
		fmt.Printf("%s\n", tokenRes)
	}

	// surface an OAuth error response instead of silently reporting success
	var data map[string]any
	if err := json.Unmarshal(tokenRes, &data); err != nil {
		return fmt.Errorf("failed to unmarshal response: %w", err)
	}
	if data["error"] != nil {
		return fmt.Errorf("the authorization server returned an error (%v): %v", data["error"], data["error_description"])
	}
	return nil
}
// GenerateJwkKeyPairFromPrivateKey wraps privateKey in a private JWK and
// derives the matching public JWK from it.
//
// A nil privateKey is rejected up front instead of being passed to
// jwk.FromRaw. Errors from the jwk package are wrapped with %w so callers can
// inspect them with errors.Is / errors.As.
func GenerateJwkKeyPairFromPrivateKey(privateKey *rsa.PrivateKey) (jwk.Key, jwk.Key, error) {
	if privateKey == nil {
		return nil, nil, fmt.Errorf("failed to create private JWK: private key is nil")
	}
	privateJwk, err := jwk.FromRaw(privateKey)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create private JWK: %w", err)
	}
	publicJwk, err := jwk.PublicKeyOf(privateJwk)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create public JWK: %w", err)
	}
	return privateJwk, publicJwk, nil
}