Major refactoring and code restructure

This commit is contained in:
David Allen 2024-03-10 20:19:29 -06:00
parent 72adbe1f0d
commit 6d63211d35
No known key found for this signature in database
GPG key ID: 1D2A29322FBB6FCB
10 changed files with 454 additions and 859 deletions

View file

@ -1,424 +0,0 @@
package opaal
import (
"crypto/rand"
"crypto/rsa"
"davidallendj/opaal/internal/oidc"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"time"
"github.com/davidallendj/go-utils/util"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
// TODO: change authorization code flow to use these instead
type AuthorizationCodeFlowEndpoints struct {
Login string
Token string
Identities string
TrustedIssuer string
Register string
}
func AuthorizationCodeWithConfig(config *Config, server *Server, client *Client, idp *oidc.IdentityProvider) error {
// check preconditions are met
err := verifyParams(config, server, client, idp)
if err != nil {
return err
}
// build the authorization URL to redirect user for social sign-in
state := config.Authentication.Flows["authorization_code"]["state"]
var authorizationUrl = client.BuildAuthorizationUrl(idp.Endpoints.Authorization, state)
// print the authorization URL for sharing
fmt.Printf("Login with identity provider:\n\n %s/login\n %s\n\n",
server.GetListenAddr(), authorizationUrl,
)
// automatically open browser to initiate login flow (only useful for testing and debugging)
if config.Options.OpenBrowser {
util.OpenUrl(authorizationUrl)
}
// authorize oauth client and listen for callback from provider
fmt.Printf("Waiting for authorization code redirect @%s/oidc/callback...\n", server.GetListenAddr())
code, err := server.WaitForAuthorizationCode(authorizationUrl, "")
if errors.Is(err, http.ErrServerClosed) {
fmt.Printf("\n=========================================\nServer closed.\n=========================================\n\n")
} else if err != nil {
return fmt.Errorf("failed to start server: %s", err)
}
// start up another server in background to listen for success or failures
d := StartListener(server)
// use code from response and exchange for bearer token (with ID token)
bearerToken, err := client.FetchTokenFromAuthenticationServer(
code,
idp.Endpoints.Token,
state,
)
if err != nil {
return fmt.Errorf("failed to fetch token from issuer: %v", err)
}
// fmt.Printf("%v\n", string(bearerToken))
// unmarshal data to get id_token and access_token
var data map[string]any
err = json.Unmarshal([]byte(bearerToken), &data)
if err != nil || data == nil {
return fmt.Errorf("failed to unmarshal token: %v", err)
}
// make sure we have an ID token
if data["id_token"] == nil {
return fmt.Errorf("no ID token found...aborting")
}
// extract ID token from bearer as JSON string for easy consumption
idToken := data["id_token"].(string)
idJwtSegments, err := util.DecodeJwt(idToken)
if err != nil {
fmt.Printf("failed to parse ID token: %v\n", err)
} else {
fmt.Printf("id_token: %v\n", idToken)
if config.Options.DecodeIdToken {
if err != nil {
fmt.Printf("failed to decode JWT: %v\n", err)
} else {
for i, segment := range idJwtSegments {
// don't print last segment (signatures)
if i == len(idJwtSegments)-1 {
break
}
fmt.Printf("%s\n", string(segment))
}
}
}
fmt.Println()
}
// extract the access token to get the scopes
accessToken := data["access_token"].(string)
accessJwtSegments, err := util.DecodeJwt(accessToken)
if err != nil || len(accessJwtSegments) <= 0 {
fmt.Printf("failed to parse access token: %v\n", err)
} else {
fmt.Printf("access_token (from identity provider): %v\n", accessToken)
if config.Options.DecodeIdToken {
if err != nil {
fmt.Printf("failed to decode JWT: %v\n", err)
} else {
for i, segment := range accessJwtSegments {
// don't print last segment (signatures)
if i == len(accessJwtSegments)-1 {
break
}
fmt.Printf("%s\n", string(segment))
}
}
}
fmt.Println()
}
if !config.Options.ForwardToken {
// TODO: implement our own JWT to send to Hydra
// 1. verify that the JWT from the issuer is valid
key, ok := idp.Jwks.Key(0)
if !ok {
return fmt.Errorf("no key found in key set")
}
parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKey(jwa.RS256, key))
if err != nil {
return fmt.Errorf("failed to parse ID token: %v", err)
}
_, err = jwt.ParseString(accessToken, jwt.WithKeySet(idp.Jwks))
if err != nil {
return fmt.Errorf("failed to parse access token: %v", err)
}
_, err = jws.Verify([]byte(idToken), jws.WithKeySet(idp.Jwks), jws.WithValidateKey(true))
if err != nil {
return fmt.Errorf("failed to verify JWT: %v", err)
}
// 2. create a new JWKS (or just JWK) to be verified
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return fmt.Errorf("failed to generate private RSA k-ey: %v", err)
}
privateJwk, err := jwk.FromRaw(privateKey)
if err != nil {
return fmt.Errorf("failed to create private JWK: %v", err)
}
publicJwk, err := jwk.PublicKeyOf(privateJwk)
if err != nil {
return fmt.Errorf("failed to create public JWK: %v", err)
}
publicJwk.Set("kid", uuid.New().String())
// 3. add opaal's server host as a trusted issuer with JWK
fmt.Printf("Attempting to add issuer to authorization server...\n")
res, err := client.AddTrustedIssuer(
config.Authorization.RequestUrls.TrustedIssuers,
server.Addr,
publicJwk,
"1",
time.Second*3600,
)
if err != nil {
return fmt.Errorf("failed to add trusted issuer: %v", err)
}
fmt.Printf("%v\n", string(res))
// 4. create a new JWT based on the claims from the identity provider and sign
payload := parsedIdToken.PrivateClaims()
payload["iss"] = server.Addr
payload["aud"] = []string{config.Authorization.RequestUrls.Token}
payload["iat"] = time.Now().Unix()
payload["nbf"] = time.Now().Unix()
payload["exp"] = time.Now().Add(time.Second * 3600).Unix()
payload["sub"] = "1"
payloadJson, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
newToken, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
if err != nil {
return fmt.Errorf("failed to sign token: %v", err)
}
// sig = rsasha256(b64urlencode(header) + "." + b64urlencode(payload))
// signature := util.EncodeBase64() + util.EncodeBase64() +
// 5. dynamically register new OAuth client and authorize it
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
res, err = client.RegisterOAuthClient(config.Authorization.RequestUrls.Register, []string{})
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
// extract the client info from response
var clientData map[string]any
err = json.Unmarshal(res, &clientData)
if err != nil {
return fmt.Errorf("failed to unmarshal client data: %v", err)
} else {
// check for error first
errJson := clientData["error"]
if errJson == nil {
client.Id = clientData["client_id"].(string)
client.Secret = clientData["client_secret"].(string)
} else {
// delete client and create again
fmt.Printf("Attempting to delete client...\n")
err := client.DeleteOAuthClient(config.Authorization.RequestUrls.Clients)
if err != nil {
return fmt.Errorf("failed to delete OAuth client: %v", err)
}
fmt.Printf("Attempting to re-create client...\n")
res, err := client.CreateOAuthClient(config.Authorization.RequestUrls.Clients, []string{})
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
}
}
// authorize the client
// fmt.Printf("Attempting to authorize client...\n")
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
// if err != nil {
// return fmt.Errorf("failed to authorize client: %v", err)
// }
// fmt.Printf("%v\n", string(res))
// 6. send JWT to authorization server and receive a access token
if config.Authorization.RequestUrls.Token != "" {
fmt.Printf("Fetching access token from authorization server...\n")
res, err := client.PerformTokenGrant(config.Authorization.RequestUrls.Token, string(newToken))
if err != nil {
return fmt.Errorf("failed to fetch access token: %v", err)
}
fmt.Printf("%s\n", res)
}
} else {
// extract the scope from access token claims
// var scope []string
// var accessJsonPayload map[string]any
// var accessJwtPayload []byte = accessJwtSegments[1]
// if accessJsonPayload != nil {
// err := json.Unmarshal(accessJwtPayload, &accessJsonPayload)
// if err != nil {
// return fmt.Errorf("failed to unmarshal JWT: %v", err)
// }
// scope = idJsonPayload["scope"].([]string)
// }
// create a new identity with identity and session manager if url is provided
// if config.RequestUrls.Identities != "" {
// fmt.Printf("Attempting to create a new identity...\n")
// err := client.CreateIdentity(config.RequestUrls.Identities, idToken)
// if err != nil {
// return fmt.Errorf("failed to create new identity: %v", err)
// }
// _, err = client.FetchIdentities(config.RequestUrls.Identities)
// if err != nil {
// return fmt.Errorf("failed to fetch identities: %v", err)
// }
// fmt.Printf("Created new identity successfully.\n\n")
// }
// extract the subject from ID token claims
var subject string
var audience []string
var idJsonPayload map[string]any
var idJwtPayload []byte = idJwtSegments[1]
if idJwtPayload != nil {
err := json.Unmarshal(idJwtPayload, &idJsonPayload)
if err != nil {
return fmt.Errorf("failed to unmarshal JWT: %v", err)
}
subject = idJsonPayload["sub"].(string)
audType := reflect.ValueOf(idJsonPayload["aud"])
switch audType.Kind() {
case reflect.String:
audience = append(audience, idJsonPayload["aud"].(string))
case reflect.Array:
audience = idJsonPayload["aud"].([]string)
}
} else {
return fmt.Errorf("failed to extract subject from ID token claims")
}
// fetch JWKS and add issuer to authentication server to submit ID token
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
err = idp.FetchJwks()
if err != nil {
return fmt.Errorf("failed to fetch JWK: %v", err)
} else {
fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
fmt.Printf("Attempting to add issuer to authorization server...\n")
res, err := client.AddTrustedIssuerWithIdentityProvider(
config.Authorization.RequestUrls.TrustedIssuers,
idp,
subject,
time.Duration(1000),
)
if err != nil {
return fmt.Errorf("failed to add trusted issuer: %v", err)
}
fmt.Printf("%v\n", string(res))
}
// add client ID to audience
audience = append(audience, client.Id)
audience = append(audience, "http://127.0.0.1:4444/oauth2/token")
// try and register a new client with authorization server
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
res, err := client.RegisterOAuthClient(config.Authorization.RequestUrls.Register, audience)
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
// extract the client info from response
var clientData map[string]any
err = json.Unmarshal(res, &clientData)
if err != nil {
return fmt.Errorf("failed to unmarshal client data: %v", err)
} else {
// check for error first
errJson := clientData["error"]
if errJson == nil {
client.Id = clientData["client_id"].(string)
client.Secret = clientData["client_secret"].(string)
} else {
// delete client and create again
fmt.Printf("Attempting to delete client...\n")
err := client.DeleteOAuthClient(config.Authorization.RequestUrls.Clients)
if err != nil {
return fmt.Errorf("failed to delete OAuth client: %v", err)
}
fmt.Printf("Attempting to re-create client...\n")
res, err := client.CreateOAuthClient(config.Authorization.RequestUrls.Clients, audience)
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
}
}
// authorize the client
// fmt.Printf("Attempting to authorize client...\n")
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
// if err != nil {
// return fmt.Errorf("failed to authorize client: %v", err)
// }
// fmt.Printf("%v\n", string(res))
// use ID token/user info to fetch access token from authentication server
if config.Authorization.RequestUrls.Token != "" {
fmt.Printf("Fetching access token from authorization server...\n")
res, err := client.PerformTokenGrant(config.Authorization.RequestUrls.Token, idToken)
if err != nil {
return fmt.Errorf("failed to fetch access token: %v", err)
}
fmt.Printf("%s\n", res)
}
}
var access_token []byte
d <- access_token
return nil
}
func verifyParams(config *Config, server *Server, client *Client, idp *oidc.IdentityProvider) error {
// make sure we have a valid server and client
if server == nil {
return fmt.Errorf("server not initialized or valid (server == nil)")
}
if client == nil {
return fmt.Errorf("client not initialized or valid (client == nil)")
}
if idp == nil {
return fmt.Errorf("identity provider not initialized or valid (idp == nil)")
}
// check if all appropriate parameters are set in config
if !HasRequiredConfigParams(config) {
return fmt.Errorf("required params not set correctly or missing")
}
return nil
}
func StartListener(server *Server) chan []byte {
d := make(chan []byte)
quit := make(chan bool)
go server.Serve(d)
go func() {
select {
case <-d:
fmt.Printf("got access token")
quit <- true
case <-quit:
close(d)
close(quit)
return
default:
}
}()
return d
}

View file

@ -1,93 +0,0 @@
package opaal
import (
"net/http"
"net/http/cookiejar"
"slices"
"github.com/davidallendj/go-utils/mathx"
"golang.org/x/net/publicsuffix"
)
type Client struct {
http.Client
Id string `yaml:"id"`
Secret string `yaml:"secret"`
Name string `yaml:"name"`
Description string `yaml:"description"`
Issuer string `yaml:"issuer"`
RegistrationAccessToken string `yaml:"registration-access-token"`
RedirectUris []string `yaml:"redirect-uris"`
Scope []string `yaml:"scope"`
FlowId string
CsrfToken string
}
func NewClient() *Client {
return &Client{}
}
func NewClientWithConfig(config *Config) *Client {
// make sure config is valid
if config == nil {
return nil
}
// make sure we have at least one client
clients := config.Authentication.Clients
if len(clients) <= 0 {
return nil
}
// use the first client found by default
return &Client{
Id: clients[0].Id,
Secret: clients[0].Secret,
Name: clients[0].Name,
Issuer: clients[0].Issuer,
Scope: clients[0].Scope,
RedirectUris: clients[0].RedirectUris,
}
}
func NewClientWithConfigByIndex(config *Config, index int) *Client {
size := len(config.Authentication.Clients)
index = mathx.Clamp(index, 0, size)
return nil
}
func NewClientWithConfigByName(config *Config, name string) *Client {
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
return c.Name == name
})
if index >= 0 {
return &config.Authentication.Clients[index]
}
return nil
}
func NewClientWithConfigByProvider(config *Config, issuer string) *Client {
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
return c.Issuer == issuer
})
if index >= 0 {
return &config.Authentication.Clients[index]
}
return nil
}
func NewClientWithConfigById(config *Config, id string) *Client {
index := slices.IndexFunc(config.Authentication.Clients, func(c Client) bool {
return c.Id == id
})
if index >= 0 {
return &config.Authentication.Clients[index]
}
return nil
}
func (client *Client) ClearCookies() {
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
client.Jar = jar
}

View file

@ -1,156 +0,0 @@
package db
import (
"davidallendj/opaal/internal/oidc"
"fmt"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
)
func CreateIdentityProvidersIfNotExists(path string) (*sqlx.DB, error) {
schema := `
CREATE TABLE IF NOT EXISTS identity_providers (
issuer TEXT NOT NULL,
authorization_endpoint TEXT,
token_endpoint TEXT,
revocation_endpoint TEXT,
introspection_endpoint TEXT,
userinfo_endpoint TEXT,
jwks_uri TEXT,
response_types_supported TEXT,
response_modes_supported TEXT,
grant_types_supported TEXT,
token_endpoint_auth_methods_supported TEXT,
subject_types_supported TEXT,
id_token_signing_alg_values_supported TEXT,
claim_types_supported TEXT,
claims_supported TEXT,
jwks TEXT,
PRIMARY KEY (issuer)
);
`
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("could not open database: %v", err)
}
db.MustExec(schema)
return db, nil
}
func InsertIdentityProviders(path string, providers *[]oidc.IdentityProvider) error {
if providers == nil {
return fmt.Errorf("states == nil")
}
// create database if it doesn't already exist
db, err := CreateIdentityProvidersIfNotExists(path)
if err != nil {
return err
}
// insert all probe states into db
tx := db.MustBegin()
for _, state := range *providers {
sql := `INSERT OR REPLACE INTO identity_providers
(
issuer,
authorization_endpoint,
token_endpoint,
revocation_endpoint,
introspection_endpoint,
userinfo_endpoint,
jwks_uri,
response_types_supported,
response_modes_supported,
grant_types_supported,
token_endpoint_auth_methods_supported,
subject_types_supported,
id_token_signing_alg_values_supported,
claim_types_supported,
claims_supported,
jwks
)
VALUES
(
:issuer,
:authorization_endpoint,
:token_endpoint,
:revocation_endpoint,
:introspection_endpoint,
:userinfo_endpoint,
:jwks_uri,
:response_types_supported,
:response_modes_supported,
:grant_types_supported,
:token_endpoint_auth_methods_supported,
:subject_types_supported,
:id_token_signing_alg_values_supported,
:claim_types_supported,
:claims_supported,
:jwks
);`
_, err := tx.NamedExec(sql, &state)
if err != nil {
fmt.Printf("could not execute transaction: %v\n", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("could not commit transaction: %v", err)
}
return nil
}
func GetIdentityProvider(path string, issuer string) (*oidc.IdentityProvider, error) {
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("could not open database: %v", err)
}
results := &oidc.IdentityProvider{}
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC LIMIT 1;")
if err != nil {
return nil, fmt.Errorf("could not retrieve probes: %v", err)
}
return results, nil
}
func GetIdentityProviders(path string) ([]oidc.IdentityProvider, error) {
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return nil, fmt.Errorf("could not open database: %v", err)
}
results := []oidc.IdentityProvider{}
err = db.Select(&results, "SELECT * FROM magellan_scanned_ports ORDER BY host ASC, port ASC;")
if err != nil {
return nil, fmt.Errorf("could not retrieve probes: %v", err)
}
return results, nil
}
func DeleteIdentityProviders(path string, results *[]oidc.IdentityProvider) error {
if results == nil {
return fmt.Errorf("no probe results found")
}
db, err := sqlx.Open("sqlite3", path)
if err != nil {
return fmt.Errorf("could not open database: %v", err)
}
tx := db.MustBegin()
for _, state := range *results {
sql := `DELETE FROM identity_providers WHERE host = :issuer;`
_, err := tx.NamedExec(sql, &state)
if err != nil {
fmt.Printf("could not execute transaction: %v\n", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("could not commit transaction: %v", err)
}
return nil
}

View file

@ -1,6 +1,7 @@
package opaal
package flows
import (
"davidallendj/opaal/internal/oauth"
"fmt"
)
@ -15,9 +16,9 @@ type ClientCredentialsFlowEndpoints struct {
Token string
}
func ClientCredentials(eps ClientCredentialsFlowEndpoints, client *Client) error {
func NewClientCredentialsFlow(eps ClientCredentialsFlowEndpoints, client *oauth.Client) error {
// register a new OAuth 2 client with authorization srever
_, err := client.CreateOAuthClient(eps.Create, nil)
_, err := client.CreateOAuthClient(eps.Create)
if err != nil {
return fmt.Errorf("failed to register OAuth client: %v", err)
}
@ -37,12 +38,3 @@ func ClientCredentials(eps ClientCredentialsFlowEndpoints, client *Client) error
fmt.Printf("token: %v\n", string(res))
return nil
}
func ClientCredentialsWithConfig(config *Config, client *Client) error {
eps := ClientCredentialsFlowEndpoints{
Create: config.Authorization.RequestUrls.Clients,
Authorize: config.Authorization.RequestUrls.Authorize,
Token: config.Authorization.RequestUrls.Token,
}
return ClientCredentials(eps, client)
}

View file

@ -0,0 +1,319 @@
package flows
import (
"crypto/rand"
"crypto/rsa"
"davidallendj/opaal/internal/oauth"
"davidallendj/opaal/internal/oidc"
"encoding/json"
"fmt"
"os"
"time"
"github.com/davidallendj/go-utils/cryptox"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
)
type JwtBearerFlowParams struct {
AccessToken string
IdToken string
IdentityProvider *oidc.IdentityProvider
TrustedIssuer *oauth.TrustedIssuer
Client *oauth.Client
Verbose bool
KeyPath string
}
type JwtBearerEndpoints struct {
TrustedIssuers string
Token string
Clients string
Register string
}
func NewJwtBearerFlow(eps JwtBearerEndpoints, params JwtBearerFlowParams) (string, error) {
// 1. verify that the JWT from the issuer is valid using all keys
var (
idp = params.IdentityProvider
accessToken = params.AccessToken
idToken = params.IdToken
client = params.Client
trustedIssuer = params.TrustedIssuer
verbose = params.Verbose
)
if accessToken != "" {
_, err := jws.Verify([]byte(accessToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
if err != nil {
return "", fmt.Errorf("failed to verify access token: %v", err)
}
}
if idToken != "" {
_, err := jws.Verify([]byte(idToken), jws.WithKeySet(idp.KeySet), jws.WithValidateKey(true))
if err != nil {
return "", fmt.Errorf("failed to verify ID token: %v", err)
}
}
// 2. Check if we are already registered as a trusted issuer with authorization server...
// 3.a if not, create a new JWKS (or just JWK) to be verified
var (
keyPath string = params.KeyPath
privateJwk jwk.Key
publicJwk jwk.Key
)
rawPrivateKey, err := os.ReadFile(keyPath)
if err != nil {
if verbose {
fmt.Printf("failed to read private key...generating a new one.\n")
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", fmt.Errorf("failed to generate new RSA key: %v", err)
}
privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
if err != nil {
return "", fmt.Errorf("failed to generate JWK pair from private key: %v", err)
}
// save new key to key path to reuse later
b := cryptox.MarshalRSAPrivateKey(privateKey)
err = os.WriteFile(keyPath, b, os.ModePerm)
if err != nil {
fmt.Printf("failed to write private key to file: %v", err)
}
} else {
privateKey, err := cryptox.GenerateRSAPrivateKey(rawPrivateKey)
if err != nil {
return "", fmt.Errorf("failed to generate RSA key from string: %v", err)
}
privateJwk, publicJwk, err = cryptox.GenerateJwkKeyPairFromPrivateKey(privateKey)
if err != nil {
return "", fmt.Errorf("failed to generate JWK pair from private key: %v", err)
}
}
publicJwk.Set("kid", uuid.New().String())
publicJwk.Set("use", "sig")
if err := publicJwk.Validate(); err != nil {
return "", fmt.Errorf("failed to validate public JWK: %v", err)
}
trustedIssuer.PublicKey = publicJwk
// 3.b ...and then, add opaal's server host as a trusted issuer with JWK
if verbose {
fmt.Printf("Attempting to add issuer to authorization server...\n")
}
res, err := client.AddTrustedIssuer(
eps.TrustedIssuers,
trustedIssuer,
)
if err != nil {
return "", fmt.Errorf("failed to add trusted issuer: %v", err)
}
fmt.Printf("trusted issuer: %v\n", string(res))
// TODO: add trusted issuer to cache if successful
// 4. create a new JWT based on the claims from the identity provider and sign
parsedIdToken, err := jwt.ParseString(idToken, jwt.WithKeySet(idp.KeySet))
if err != nil {
return "", fmt.Errorf("failed to parse ID token: %v", err)
}
payload := parsedIdToken.PrivateClaims()
payload["iss"] = trustedIssuer.Issuer
payload["aud"] = []string{eps.Token}
payload["iat"] = time.Now().Unix()
payload["nbf"] = time.Now().Unix()
payload["exp"] = time.Now().Add(time.Second * 3600).Unix()
payload["sub"] = "opaal"
payloadJson, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %v", err)
}
fmt.Printf("payload: %v\n", string(payloadJson))
newJwt, err := jws.Sign(payloadJson, jws.WithJSON(), jws.WithKey(jwa.RS256, privateJwk))
if err != nil {
return "", fmt.Errorf("failed to sign token: %v", err)
}
// 5. dynamically register new OAuth client and authorize it to make jwt_bearer request
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
res, err = client.RegisterOAuthClient(eps.Register)
if err != nil {
return "", fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
// extract the client info from response
var clientData map[string]any
err = json.Unmarshal(res, &clientData)
if err != nil {
return "", fmt.Errorf("failed to unmarshal client data: %v", err)
} else {
// check for error first
errJson := clientData["error"]
if errJson == nil {
client.Id = clientData["client_id"].(string)
client.Secret = clientData["client_secret"].(string)
} else {
// delete client and try to create again
fmt.Printf("Attempting to delete client...\n")
err := client.DeleteOAuthClient(eps.Clients)
if err != nil {
return "", fmt.Errorf("failed to delete OAuth client: %v", err)
}
fmt.Printf("Attempting to re-create client...\n")
res, err := client.CreateOAuthClient(eps.Clients)
if err != nil {
return "", fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
}
}
// TODO: add OAuth client to cache if successfully
// authorize the client
// fmt.Printf("Attempting to authorize client...\n")
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
// if err != nil {
// return fmt.Errorf("failed to authorize client: %v", err)
// }
// fmt.Printf("%v\n", string(res))
// 6. send JWT to authorization server and receive a access token
if eps.Token != "" {
fmt.Printf("Fetching access token from authorization server...\n")
res, err := client.PerformTokenGrant(eps.Token, string(newJwt))
if err != nil {
return "", fmt.Errorf("failed to fetch access token: %v", err)
}
// extract token from response if there are no errors
var data map[string]any
err = json.Unmarshal(res, &data)
if err != nil {
return "", fmt.Errorf("failed to unmarshal response: %v", err)
}
if data["error"] != nil {
return "", fmt.Errorf("the authorization server returned an error (%v): %v", data["error"], data["error_description"])
}
fmt.Printf("%s\n", res)
err = json.Unmarshal(res, &data)
if err != nil {
return "", fmt.Errorf("failed to unmarshal access token: %v", err)
}
return data["access_token"].(string), nil
} else {
return "", fmt.Errorf("token endpoint not set")
}
return string(res), nil
}
func ForwardToken(eps JwtBearerEndpoints, params JwtBearerFlowParams) error {
var (
client = params.Client
idToken = params.IdToken
idp = params.IdentityProvider
verbose = params.Verbose
)
// fetch JWKS and add issuer to authentication server to submit ID token
if verbose {
fmt.Printf("Fetching JWKS from authentication server for verification...\n")
}
err := idp.FetchJwks()
if err != nil {
return fmt.Errorf("failed to fetch JWK: %v", err)
} else {
if verbose {
fmt.Printf("Successfully retrieved JWK from authentication server.\n\n")
fmt.Printf("Attempting to add issuer to authorization server...\n")
}
ti := &oauth.TrustedIssuer{
Issuer: idp.Issuer,
Subject: "1",
ExpiresAt: time.Now().Add(time.Second * 3600),
}
res, err := client.AddTrustedIssuer(
eps.TrustedIssuers,
ti,
)
if err != nil {
return fmt.Errorf("failed to add trusted issuer: %v", err)
}
if verbose {
fmt.Printf("%v\n", string(res))
}
}
// try and register a new client with authorization server
if verbose {
fmt.Printf("Registering new OAuth2 client with authorization server...\n")
}
res, err := client.RegisterOAuthClient(eps.Register)
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
if verbose {
fmt.Printf("%v\n", string(res))
}
// extract the client info from response
var clientData map[string]any
err = json.Unmarshal(res, &clientData)
if err != nil {
return fmt.Errorf("failed to unmarshal client data: %v", err)
} else {
// check for error first
errJson := clientData["error"]
if errJson == nil {
client.Id = clientData["client_id"].(string)
client.Secret = clientData["client_secret"].(string)
} else {
// delete client and create again
fmt.Printf("Attempting to delete client...\n")
err := client.DeleteOAuthClient(eps.Clients)
if err != nil {
return fmt.Errorf("failed to delete OAuth client: %v", err)
}
fmt.Printf("Attempting to re-create client...\n")
res, err := client.CreateOAuthClient(eps.Clients)
if err != nil {
return fmt.Errorf("failed to register client: %v", err)
}
fmt.Printf("%v\n", string(res))
}
}
// authorize the client
// fmt.Printf("Attempting to authorize client...\n")
// res, err = client.AuthorizeOAuthClient(config.Authorization.RequestUrls.Authorize)
// if err != nil {
// return fmt.Errorf("failed to authorize client: %v", err)
// }
// fmt.Printf("%v\n", string(res))
// use ID token/user info to fetch access token from authentication server
if eps.Token != "" {
if verbose {
fmt.Printf("Fetching access token from authorization server...\n")
}
res, err := client.PerformTokenGrant(eps.Token, idToken)
if err != nil {
return fmt.Errorf("failed to fetch access token: %v", err)
}
if verbose {
fmt.Printf("%s\n", res)
}
} else {
return fmt.Errorf("token endpoint is not set")
}
return nil
}

View file

@ -1,4 +1,4 @@
package opaal
package oauth
import (
"bytes"

View file

@ -1,65 +1,41 @@
package opaal
package oauth
import (
"bytes"
"davidallendj/opaal/internal/oidc"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/url"
"slices"
"strings"
"time"
"github.com/davidallendj/go-utils/httpx"
"github.com/davidallendj/go-utils/util"
"github.com/lestrrat-go/jwx/v2/jwk"
"golang.org/x/net/publicsuffix"
)
func (client *Client) AddTrustedIssuer(url string, issuer string, key jwk.Key, subject string, expires time.Duration) ([]byte, error) {
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
quotedScopes := make([]string, len(client.Scope))
for i, s := range client.Scope {
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
}
jwkstr, err := json.Marshal(key)
if err != nil {
return nil, fmt.Errorf("failed to marshal JWK: %v", err)
}
// NOTE: Can also include "jwks_uri" instead
data := []byte(fmt.Sprintf("{"+
"\"allow_any_subject\": false,"+
"\"issuer\": \"%s\","+
"\"subject\": \"%s\","+
"\"expires_at\": \"%v\","+
"\"jwk\": %v,"+
"\"scope\": [ %s ]"+
"}", issuer, subject, time.Now().Add(expires).Format(time.RFC3339), string(jwkstr), strings.Join(quotedScopes, ",")))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(data))
// req.Header.Add("X-CSRF-Token", client.CsrfToken.Value)
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
req.Header.Add("Content-Type", "application/json")
// req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", idToken))
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to do request: %v", err)
}
defer res.Body.Close()
return io.ReadAll(res.Body)
type Client struct {
http.Client
Id string `db:"id" yaml:"id"`
Secret string `db:"secret" yaml:"secret"`
Name string `db:"name" yaml:"name"`
Description string `db:"description" yaml:"description"`
Issuer string `db:"issuer" yaml:"issuer"`
RegistrationAccessToken string `db:"registration_access_token" yaml:"registration-access-token"`
RedirectUris []string `db:"redirect_uris" yaml:"redirect-uris"`
Scope []string `db:"scope" yaml:"scope"`
Audience []string `db:"audience" yaml:"audience"`
FlowId string
CsrfToken string
}
func (client *Client) AddTrustedIssuerWithIdentityProvider(url string, idp *oidc.IdentityProvider, subject string, expires time.Duration) ([]byte, error) {
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
key, ok := idp.Jwks.Key(0)
if !ok {
return nil, fmt.Errorf("no keys found in key set")
}
return client.AddTrustedIssuer(url, idp.Issuer, key, subject, expires)
func NewClient() *Client {
return &Client{}
}
func (client *Client) ClearCookies() {
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
client.Jar = jar
}
func (client *Client) IsOAuthClientRegistered(clientUrl string) (bool, error) {
@ -107,9 +83,9 @@ func (client *Client) GetOAuthClient(clientUrl string) error {
return nil
}
func (client *Client) CreateOAuthClient(registerUrl string, audience []string) ([]byte, error) {
func (client *Client) CreateOAuthClient(registerUrl string) ([]byte, error) {
// hydra endpoint: POST /clients
audience = util.QuoteArrayStrings(audience)
audience := util.QuoteArrayStrings(client.Audience)
body := httpx.Body(fmt.Sprintf(`{
"client_id": "%s",
"client_name": "%s",
@ -151,9 +127,12 @@ func (client *Client) CreateOAuthClient(registerUrl string, audience []string) (
return b, err
}
func (client *Client) RegisterOAuthClient(registerUrl string, audience []string) ([]byte, error) {
func (client *Client) RegisterOAuthClient(registerUrl string) ([]byte, error) {
// hydra endpoint: POST /oauth2/register
audience = util.QuoteArrayStrings(audience)
if registerUrl == "" {
return nil, fmt.Errorf("no URL provided")
}
audience := util.QuoteArrayStrings(client.Audience)
body := httpx.Body(fmt.Sprintf(`{
"client_name": "opaal",
"token_endpoint_auth_method": "client_secret_post",

View file

@ -1,4 +1,4 @@
package opaal
package oauth
import (
"fmt"

99
internal/oauth/trusted.go Normal file
View file

@ -0,0 +1,99 @@
package oauth
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/davidallendj/go-utils/httpx"
"github.com/lestrrat-go/jwx/v2/jwk"
)
type TrustedIssuer struct {
Id string `db:"id" yaml:"id"`
AllowAnySubject bool `db:"allow_any_subject" yaml:"allow-any-subject"`
ExpiresAt time.Time `db:"expires_at" yaml:"expires-at"`
Issuer string `db:"issuer" yaml:"issuer"`
PublicKey jwk.Key `db:"public_key" yaml:"public-key"`
Scope []string `db:"scope" yaml:"scope"`
Subject string `db:"subject" yaml:"subject"`
}
func NewTrustedIssuer() *TrustedIssuer {
return &TrustedIssuer{
AllowAnySubject: false,
ExpiresAt: time.Now().Add(time.Hour),
Scope: []string{"openid"},
Subject: "1",
}
}
func (ti *TrustedIssuer) IsTrustedIssuerValid() bool {
err := ti.PublicKey.Validate()
return ti.Issuer != "" && err == nil && ti.Subject != ""
}
func ParseString(b []byte) (*TrustedIssuer, error) {
// take data from JSON to populate fields
ti := &TrustedIssuer{}
data := map[string]any{}
json.Unmarshal(b, &data)
return ti, nil
}
func (client *Client) ListTrustedIssuers(url string) ([]TrustedIssuer, error) {
// hydra endpoint: GET /admin/trust/grants/jwt-bearer/issuers
_, b, err := httpx.MakeHttpRequest(url, http.MethodGet, nil, nil)
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
// unmarshal results into TrustedIssuers objects
trustedIssuers := []TrustedIssuer{}
err = json.Unmarshal(b, &trustedIssuers)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal JSON: %v", err)
}
return trustedIssuers, nil
}
func (client *Client) AddTrustedIssuer(url string, ti *TrustedIssuer) ([]byte, error) {
// hydra endpoint: POST /admin/trust/grants/jwt-bearer/issuers
quotedScopes := make([]string, len(client.Scope))
for i, s := range client.Scope {
quotedScopes[i] = fmt.Sprintf("\"%s\"", s)
}
// NOTE: Can also include "jwks_uri" instead of "jwk"
body := map[string]any{
"allow_any_subject": ti.AllowAnySubject,
"issuer": ti.Issuer,
"expires_at": ti.ExpiresAt,
"jwk": ti.PublicKey,
"scope": client.Scope,
}
if !ti.AllowAnySubject {
body["subject"] = ti.Subject
}
b, err := json.Marshal(body)
fmt.Printf("request: %v\n", string(b))
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %v", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(b))
if err != nil {
return nil, fmt.Errorf("failed to make request: %v", err)
}
req.Header.Add("Content-Type", "application/json")
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to do request: %v", err)
}
defer res.Body.Close()
return io.ReadAll(res.Body)
}

View file

@ -1,121 +0,0 @@
package opaal
import (
"fmt"
"net/http"
"os"
"strings"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/chi/v5"
"github.com/nikolalohinski/gonja/v2"
"github.com/nikolalohinski/gonja/v2/exec"
)
type Server struct {
*http.Server
Host string `yaml:"host"`
Port int `yaml:"port"`
Callback string `yaml:"callback"`
}
func NewServerWithConfig(conf *Config) *Server {
host := conf.Server.Host
port := conf.Server.Port
server := &Server{
Server: &http.Server{
Addr: fmt.Sprintf("%s:%d", host, port),
},
Host: host,
Port: port,
}
return server
}
func (s *Server) SetListenAddr(host string, port int) {
s.Addr = s.GetListenAddr()
}
func (s *Server) GetListenAddr() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}
func (s *Server) WaitForAuthorizationCode(loginUrl string, callback string) (string, error) {
// check if callback is set
if callback == "" {
callback = "/oidc/callback"
}
var code string
r := chi.NewRouter()
r.Use(middleware.RedirectSlashes)
r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusSeeOther)
})
r.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
// show login page with notice to redirect
template, err := gonja.FromFile("pages/index.html")
if err != nil {
panic(err)
}
data := exec.NewContext(map[string]interface{}{
"loginUrl": loginUrl,
})
if err = template.Execute(w, data); err != nil { // Prints: Hello Bob!
panic(err)
}
})
r.HandleFunc(callback, func(w http.ResponseWriter, r *http.Request) {
// get the code from the OIDC provider
if r != nil {
code = r.URL.Query().Get("code")
fmt.Printf("Authorization code: %v\n", code)
}
http.Redirect(w, r, "/redirect", http.StatusSeeOther)
})
r.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
err := s.Close()
if err != nil {
fmt.Printf("failed to close server: %v\n", err)
}
})
s.Handler = r
return code, s.ListenAndServe()
}
func (s *Server) Serve(data chan []byte) error {
output, ok := <-data
if !ok {
return fmt.Errorf("failed to receive data")
}
fmt.Printf("Received data: %v\n", string(output))
// http.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
// })
r := chi.NewRouter()
r.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Serving success page.")
successPage, err := os.ReadFile("pages/success.html")
if err != nil {
fmt.Printf("failed to load success page: %v\n", err)
}
successPage = []byte(strings.ReplaceAll(string(successPage), "{{access_token}}", string(output)))
w.Write(successPage)
})
r.HandleFunc("/error", func(w http.ResponseWriter, r *http.Request) {
fmt.Printf("Serving error page.")
errorPage, err := os.ReadFile("pages/success.html")
if err != nil {
fmt.Printf("failed to load success page: %v\n", err)
}
// errorPage = []byte(strings.ReplaceAll(string(errorPage), "{{access_token}}", output))
w.Write(errorPage)
})
s.Handler = r
return s.ListenAndServe()
}